Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions jax/_src/numpy/index_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,14 @@ class _Ogrid:
Multiple slices can be used to create sparse grids of indices:

>>> jnp.ogrid[:2, :3]
[Array([[0],
(Array([[0],
[1]], dtype=int32),
Array([[0, 1, 2]], dtype=int32)]
Array([[0, 1, 2]], dtype=int32),)
"""

def __getitem__(
self, key: slice | tuple[slice, ...]
) -> Array | list[Array]:
) -> Array | tuple[Array, ...]:
if isinstance(key, slice):
return _make_1d_grid_from_slice(key, op_name="ogrid")
output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name="ogrid") for k in key)
Expand Down
44 changes: 22 additions & 22 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,7 +976,7 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] =
def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10,
range: Sequence[None | Array | Sequence[ArrayLike]] | None = None,
weights: ArrayLike | None = None,
density: bool | None = None) -> tuple[Array, list[Array]]:
density: bool | None = None) -> tuple[Array, tuple[Array, ...]]:
"""Compute an N-dimensional histogram.

JAX implementation of :func:`numpy.histogramdd`.
Expand Down Expand Up @@ -1079,7 +1079,7 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10,
for norm in ix_(*dedges):
hist /= norm

return hist, bin_edges_by_dim
return hist, tuple(bin_edges_by_dim)


@export
Expand Down Expand Up @@ -2994,7 +2994,7 @@ def broadcast_shapes(*shapes):


@export
def broadcast_arrays(*args: ArrayLike) -> list[Array]:
def broadcast_arrays(*args: ArrayLike) -> tuple[Array, ...]:
"""Broadcast arrays to a common shape.

JAX implementation of :func:`numpy.broadcast_arrays`. JAX uses NumPy-style
Expand Down Expand Up @@ -3031,7 +3031,7 @@ def broadcast_arrays(*args: ArrayLike) -> list[Array]:
.. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html
"""
args = util.ensure_arraylike_tuple("broadcast_arrays", args)
return util._broadcast_arrays(*args)
return tuple(util._broadcast_arrays(*args))


@export
Expand Down Expand Up @@ -5097,17 +5097,17 @@ def block(arrays: ArrayLike | list[ArrayLike]) -> Array:


@overload
def atleast_1d() -> list[Array]:
def atleast_1d() -> tuple[()]:
...
@overload
def atleast_1d(x: ArrayLike, /) -> Array:
...
@overload
def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]:
def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> tuple[Array, ...]:
...
@export
@api.jit
def atleast_1d(*arys: ArrayLike) -> Array | list[Array]:
def atleast_1d(*arys: ArrayLike) -> Array | tuple[Array, ...]:
"""Convert inputs to arrays with at least 1 dimension.

JAX implementation of :func:`numpy.atleast_1d`.
Expand Down Expand Up @@ -5139,30 +5139,30 @@ def atleast_1d(*arys: ArrayLike) -> Array | list[Array]:
Array([0, 1, 2, 3], dtype=int32)

Multiple arguments can be passed to the function at once, in which
case a list of results is returned:
case a tuple of results is returned:

>>> jnp.atleast_1d(x, y)
[Array([1.], dtype=float32), Array([0, 1, 2, 3], dtype=int32)]
(Array([1.], dtype=float32), Array([0, 1, 2, 3], dtype=int32),)
"""
util.check_arraylike("atleast_1d", *arys, emit_warning=True)
if len(arys) == 1:
return array(arys[0], copy=False, ndmin=1)
else:
return [array(arr, copy=False, ndmin=1) for arr in arys]
return tuple(array(arr, copy=False, ndmin=1) for arr in arys)


@overload
def atleast_2d() -> list[Array]:
def atleast_2d() -> tuple[()]:
...
@overload
def atleast_2d(x: ArrayLike, /) -> Array:
...
@overload
def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]:
def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> tuple[Array, ...]:
...
@export
@api.jit
def atleast_2d(*arys: ArrayLike) -> Array | list[Array]:
def atleast_2d(*arys: ArrayLike) -> Array | tuple[Array, ...]:
"""Convert inputs to arrays with at least 2 dimensions.

JAX implementation of :func:`numpy.atleast_2d`.
Expand Down Expand Up @@ -5202,31 +5202,31 @@ def atleast_2d(*arys: ArrayLike) -> Array | list[Array]:
[1., 1., 1.]], dtype=float32)

Multiple arguments can be passed to the function at once, in which
case a list of results is returned:
case a tuple of results is returned:

>>> jnp.atleast_2d(x, y)
[Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32)]
(Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32),)
"""
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
util.check_arraylike("atleast_2d", *arys, emit_warning=True)
if len(arys) == 1:
return array(arys[0], copy=False, ndmin=2)
else:
return [array(arr, copy=False, ndmin=2) for arr in arys]
return tuple(array(arr, copy=False, ndmin=2) for arr in arys)


@overload
def atleast_3d() -> list[Array]:
def atleast_3d() -> tuple[()]:
...
@overload
def atleast_3d(x: ArrayLike, /) -> Array:
...
@overload
def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]:
def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> tuple[Array, ...]:
...
@export
@api.jit
def atleast_3d(*arys: ArrayLike) -> Array | list[Array]:
def atleast_3d(*arys: ArrayLike) -> Array | tuple[Array, ...]:
"""Convert inputs to arrays with at least 3 dimensions.

JAX implementation of :func:`numpy.atleast_3d`.
Expand Down Expand Up @@ -5289,7 +5289,7 @@ def atleast_3d(*arys: ArrayLike) -> Array | list[Array]:
arr = lax.expand_dims(arr, dimensions=(2,))
return arr
else:
return [atleast_3d(arr) for arr in arys]
return tuple(atleast_3d(arr) for arr in arys)


@export
Expand Down Expand Up @@ -6044,7 +6044,7 @@ def _arange_dynamic(

@export
def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
indexing: str = 'xy') -> list[Array]:
indexing: str = 'xy') -> tuple[Array, ...]:
"""Construct N-dimensional grid arrays from N 1-dimensional vectors.

JAX implementation of :func:`numpy.meshgrid`.
Expand Down Expand Up @@ -6120,7 +6120,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
output = [lax.broadcast_in_dim(a, _a_shape(i, a), (i,)) for i, a, in enumerate(args)]
if indexing == "xy" and len(args) >= 2:
output[0], output[1] = output[1], output[0]
return output
return tuple(output)


@export
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/numpy/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,15 +242,15 @@ def promote_args_inexact(fun_name: str, *args: ArrayLike) -> list[Array]:


@api.jit(inline=True)
def _broadcast_arrays(*args: ArrayLike) -> list[Array]:
def _broadcast_arrays(*args: ArrayLike) -> tuple[Array, ...]:
"""Like Numpy's broadcast_arrays but doesn't return views."""
avals = [core.shaped_abstractify(arg) for arg in args]
shapes = [a.shape for a in avals]
if not shapes or all(core.definitely_equal_shape(shapes[0], s) for s in shapes):
return [lax.asarray(arg) for arg in args]
return tuple(lax.asarray(arg) for arg in args)
result_shape = lax.broadcast_shapes(*shapes)
result_sharding = lax.broadcast_shardings(*avals)
return [_broadcast_to(arg, result_shape, result_sharding) for arg in args]
return tuple(_broadcast_to(arg, result_shape, result_sharding) for arg in args)


def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None
Expand Down
18 changes: 9 additions & 9 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -201,25 +201,25 @@ def atan(x: ArrayLike, /) -> Array: ...
def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def atanh(x: ArrayLike, /) -> Array: ...
@overload
def atleast_1d() -> list[Array]: ...
def atleast_1d() -> tuple[()]: ...
@overload
def atleast_1d(x: ArrayLike, /) -> Array: ...
@overload
def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ...
def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> tuple[Array, ...]: ...

@overload
def atleast_2d() -> list[Array]: ...
def atleast_2d() -> tuple[()]: ...
@overload
def atleast_2d(x: ArrayLike, /) -> Array: ...
@overload
def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ...
def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> tuple[Array, ...]: ...

@overload
def atleast_3d() -> list[Array]: ...
def atleast_3d() -> tuple[()]: ...
@overload
def atleast_3d(x: ArrayLike, /) -> Array: ...
@overload
def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ...
def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> tuple[Array, ...]: ...

@overload
def average(a: ArrayLike, axis: _Axis = ..., weights: ArrayLike | None = ...,
Expand Down Expand Up @@ -247,7 +247,7 @@ def blackman(M: int) -> Array: ...
def block(arrays: ArrayLike | Sequence[ArrayLike] | Sequence[Sequence[ArrayLike]]) -> Array: ...
bool: Any
bool_: Any
def broadcast_arrays(*args: ArrayLike) -> list[Array]: ...
def broadcast_arrays(*args: ArrayLike) -> tuple[Array, ...]: ...

@overload
def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ...
Expand Down Expand Up @@ -543,7 +543,7 @@ def histogramdd(
range: Sequence[None | Array | Sequence[ArrayLike]] | None = ...,
weights: ArrayLike | None = ...,
density: builtins.bool | None = ...,
) -> tuple[Array, list[Array]]: ...
) -> tuple[Array, tuple[Array, ...]]: ...
def hsplit(
ary: ArrayLike, indices_or_sections: int | ArrayLike
) -> list[Array]: ...
Expand Down Expand Up @@ -675,7 +675,7 @@ def median(a: ArrayLike, axis: int | tuple[int, ...] | None = ...,
out: None = ..., overwrite_input: builtins.bool = ...,
keepdims: builtins.bool = ...) -> Array: ...
def meshgrid(*xi: ArrayLike, copy: builtins.bool = ..., sparse: builtins.bool = ...,
indexing: str = ...) -> list[Array]: ...
indexing: str = ...) -> tuple[Array, ...]: ...
mgrid: _Mgrid
def min(a: ArrayLike, axis: _Axis = ..., out: None = ...,
keepdims: builtins.bool = ..., initial: ArrayLike | None = ...,
Expand Down
Loading