From 661b5312bc369f726af8126de40304dbca849b8f Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 8 Jan 2026 20:15:34 +0000 Subject: [PATCH] ENH: cupy: add a workaround for cp.searchorted 2nd argument Array API 2025.12 allows python scalars for the x2 argument of `searchsorted`. CuPy only supports python scalars for x2 from CuPy 14.0. Until this is the minimum supported version, array-api-compat needs a workaround. --- array_api_compat/cupy/_aliases.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index f91805f2..7b7bfda6 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -149,6 +149,24 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Arra return tuple(cp.meshgrid(*arrays, indexing=indexing)) +# Match https://github.com/cupy/cupy/pull/9512/ until cupy v14 is the minimum +# supported version +def searchsorted( + x1: Array, + x2: Array | int | float, + /, + *, + side: Literal['left', 'right'] = 'left', + sorter: Array | None = None +) -> Array: + if not isinstance(x2, cp.ndarray): + if not isinstance(x2, int | float | complex): + raise NotImplementedError( + 'Only python scalars or ndarrays are supported for x2') + x2 = cp.asarray(x2) + return cp.searchsorted(x1, x2, side, sorter) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): @@ -172,7 +190,9 @@ def meshgrid(*arrays: Array, indexing: Literal['xy', 'ij'] = 'xy') -> tuple[Arra 'bitwise_invert', 'bitwise_right_shift', 'bool', 'concat', 'count_nonzero', 'pow', 'sign', 'ceil', 'floor', 'trunc', 'take_along_axis', - 'broadcast_arrays', 'meshgrid'] + 'broadcast_arrays', 'meshgrid', + 'searchsorted', +] def __dir__() -> list[str]: