diff --git a/jax/_src/numpy/index_tricks.py b/jax/_src/numpy/index_tricks.py index b34be9e3223e..ff3b1caa5448 100644 --- a/jax/_src/numpy/index_tricks.py +++ b/jax/_src/numpy/index_tricks.py @@ -119,14 +119,13 @@ class _Ogrid: Multiple slices can be used to create sparse grids of indices: >>> jnp.ogrid[:2, :3] - [Array([[0], - [1]], dtype=int32), - Array([[0, 1, 2]], dtype=int32)] + (Array([[0], + [1]], dtype=int32), Array([[0, 1, 2]], dtype=int32)) """ def __getitem__( self, key: slice | tuple[slice, ...] - ) -> Array | list[Array]: + ) -> Array | tuple[Array, ...]: if isinstance(key, slice): return _make_1d_grid_from_slice(key, op_name="ogrid") output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name="ogrid") for k in key) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 96d100e5534b..179b104b5664 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -976,7 +976,7 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] = def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, range: Sequence[None | Array | Sequence[ArrayLike]] | None = None, weights: ArrayLike | None = None, - density: bool | None = None) -> tuple[Array, list[Array]]: + density: bool | None = None) -> tuple[Array, tuple[Array, ...]]: """Compute an N-dimensional histogram. JAX implementation of :func:`numpy.histogramdd`. @@ -1079,7 +1079,7 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, for norm in ix_(*dedges): hist /= norm - return hist, bin_edges_by_dim + return hist, tuple(bin_edges_by_dim) @export @@ -2994,7 +2994,7 @@ def broadcast_shapes(*shapes): @export -def broadcast_arrays(*args: ArrayLike) -> list[Array]: +def broadcast_arrays(*args: ArrayLike) -> tuple[Array, ...]: """Broadcast arrays to a common shape. JAX implementation of :func:`numpy.broadcast_arrays`. JAX uses NumPy-style @@ -3015,7 +3015,7 @@ def broadcast_arrays(*args: ArrayLike) -> list[Array]: >>> x = jnp.arange(3) >>> y = jnp.int32(1) >>> jnp.broadcast_arrays(x, y) - [Array([0, 1, 2], dtype=int32), Array([1, 1, 1], dtype=int32)] + (Array([0, 1, 2], dtype=int32), Array([1, 1, 1], dtype=int32)) >>> x = jnp.array([[1, 2, 3]]) >>> y = jnp.array([[10], @@ -5097,17 +5097,17 @@ def block(arrays: ArrayLike | list[ArrayLike]) -> Array: @overload -def atleast_1d() -> list[Array]: +def atleast_1d() -> tuple[()]: ... @overload def atleast_1d(x: ArrayLike, /) -> Array: ... @overload -def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: +def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> tuple[Array, ...]: ... @export @api.jit -def atleast_1d(*arys: ArrayLike) -> Array | list[Array]: +def atleast_1d(*arys: ArrayLike) -> Array | tuple[Array, ...]: """Convert inputs to arrays with at least 1 dimension. JAX implementation of :func:`numpy.atleast_1d`. @@ -5139,30 +5139,30 @@ def atleast_1d(*arys: ArrayLike) -> Array | list[Array]: Array([0, 1, 2, 3], dtype=int32) Multiple arguments can be passed to the function at once, in which - case a list of results is returned: + case a tuple of results is returned: >>> jnp.atleast_1d(x, y) - [Array([1.], dtype=float32), Array([0, 1, 2, 3], dtype=int32)] + (Array([1.], dtype=float32), Array([0, 1, 2, 3], dtype=int32)) """ util.check_arraylike("atleast_1d", *arys, emit_warning=True) if len(arys) == 1: return array(arys[0], copy=False, ndmin=1) else: - return [array(arr, copy=False, ndmin=1) for arr in arys] + return tuple(array(arr, copy=False, ndmin=1) for arr in arys) @overload -def atleast_2d() -> list[Array]: +def atleast_2d() -> tuple[()]: ... @overload def atleast_2d(x: ArrayLike, /) -> Array: ... @overload -def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: +def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> tuple[Array, ...]: ... @export @api.jit -def atleast_2d(*arys: ArrayLike) -> Array | list[Array]: +def atleast_2d(*arys: ArrayLike) -> Array | tuple[Array, ...]: """Convert inputs to arrays with at least 2 dimensions. JAX implementation of :func:`numpy.atleast_2d`. @@ -5202,31 +5202,31 @@ def atleast_2d(*arys: ArrayLike) -> Array | list[Array]: [1., 1., 1.]], dtype=float32) Multiple arguments can be passed to the function at once, in which - case a list of results is returned: + case a tuple of results is returned: >>> jnp.atleast_2d(x, y) - [Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32)] + (Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32)) """ # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("atleast_2d", *arys, emit_warning=True) if len(arys) == 1: return array(arys[0], copy=False, ndmin=2) else: - return [array(arr, copy=False, ndmin=2) for arr in arys] + return tuple(array(arr, copy=False, ndmin=2) for arr in arys) @overload -def atleast_3d() -> list[Array]: +def atleast_3d() -> tuple[()]: ... @overload def atleast_3d(x: ArrayLike, /) -> Array: ... @overload -def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: +def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> tuple[Array, ...]: ... @export @api.jit -def atleast_3d(*arys: ArrayLike) -> Array | list[Array]: +def atleast_3d(*arys: ArrayLike) -> Array | tuple[Array, ...]: """Convert inputs to arrays with at least 3 dimensions. JAX implementation of :func:`numpy.atleast_3d`. @@ -5289,7 +5289,7 @@ def atleast_3d(*arys: ArrayLike) -> Array | list[Array]: arr = lax.expand_dims(arr, dimensions=(2,)) return arr else: - return [atleast_3d(arr) for arr in arys] + return tuple(atleast_3d(arr) for arr in arys) @export @@ -6044,7 +6044,7 @@ def _arange_dynamic( @export def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, - indexing: str = 'xy') -> list[Array]: + indexing: str = 'xy') -> tuple[Array, ...]: """Construct N-dimensional grid arrays from N 1-dimensional vectors. JAX implementation of :func:`numpy.meshgrid`. @@ -6120,7 +6120,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, output = [lax.broadcast_in_dim(a, _a_shape(i, a), (i,)) for i, a, in enumerate(args)] if indexing == "xy" and len(args) >= 2: output[0], output[1] = output[1], output[0] - return output + return tuple(output) @export diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 53a78cf6c8ac..5a38e1d7051a 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -242,15 +242,15 @@ def promote_args_inexact(fun_name: str, *args: ArrayLike) -> list[Array]: @api.jit(inline=True) -def _broadcast_arrays(*args: ArrayLike) -> list[Array]: +def _broadcast_arrays(*args: ArrayLike) -> tuple[Array, ...]: """Like Numpy's broadcast_arrays but doesn't return views.""" avals = [core.shaped_abstractify(arg) for arg in args] shapes = [a.shape for a in avals] if not shapes or all(core.definitely_equal_shape(shapes[0], s) for s in shapes): - return [lax.asarray(arg) for arg in args] + return tuple(lax.asarray(arg) for arg in args) result_shape = lax.broadcast_shapes(*shapes) result_sharding = lax.broadcast_shardings(*avals) - return [_broadcast_to(arg, result_shape, result_sharding) for arg in args] + return tuple(_broadcast_to(arg, result_shape, result_sharding) for arg in args) def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 2c0ce335441b..d9cd2378fd82 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -201,25 +201,25 @@ def atan(x: ArrayLike, /) -> Array: ... def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ... def atanh(x: ArrayLike, /) -> Array: ... @overload -def atleast_1d() -> list[Array]: ... +def atleast_1d() -> tuple[()]: ... @overload def atleast_1d(x: ArrayLike, /) -> Array: ... @overload -def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... +def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> tuple[Array, ...]: ... @overload -def atleast_2d() -> list[Array]: ... +def atleast_2d() -> tuple[()]: ... @overload def atleast_2d(x: ArrayLike, /) -> Array: ... @overload -def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... +def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> tuple[Array, ...]: ... @overload -def atleast_3d() -> list[Array]: ... +def atleast_3d() -> tuple[()]: ... @overload def atleast_3d(x: ArrayLike, /) -> Array: ... @overload -def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... +def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> tuple[Array, ...]: ... @overload def average(a: ArrayLike, axis: _Axis = ..., weights: ArrayLike | None = ..., @@ -247,7 +247,7 @@ def blackman(M: int) -> Array: ... def block(arrays: ArrayLike | Sequence[ArrayLike] | Sequence[Sequence[ArrayLike]]) -> Array: ... bool: Any bool_: Any -def broadcast_arrays(*args: ArrayLike) -> list[Array]: ... +def broadcast_arrays(*args: ArrayLike) -> tuple[Array, ...]: ... @overload def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ... @@ -543,7 +543,7 @@ def histogramdd( range: Sequence[None | Array | Sequence[ArrayLike]] | None = ..., weights: ArrayLike | None = ..., density: builtins.bool | None = ..., -) -> tuple[Array, list[Array]]: ... +) -> tuple[Array, tuple[Array, ...]]: ... def hsplit( ary: ArrayLike, indices_or_sections: int | ArrayLike ) -> list[Array]: ... @@ -675,7 +675,7 @@ def median(a: ArrayLike, axis: int | tuple[int, ...] | None = ..., out: None = ..., overwrite_input: builtins.bool = ..., keepdims: builtins.bool = ...) -> Array: ... def meshgrid(*xi: ArrayLike, copy: builtins.bool = ..., sparse: builtins.bool = ..., - indexing: str = ...) -> list[Array]: ... + indexing: str = ...) -> tuple[Array, ...]: ... mgrid: _Mgrid def min(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ...,