From be2d491b136b5fd896e93bef8228fad10b87449c Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 8 Jan 2026 20:15:34 +0000 Subject: [PATCH] ENH: cupy: add a workaround for cp.searchorted 2nd argument Array API 2025.12 allows python scalars for the x2 argument of `searchsorted`. CuPy only supports python scalars for x2 from CuPy 14.0. Until this is the minimum supported version, array-api-compat needs a workaround. --- array_api_compat/cupy/_aliases.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 2e512fc8..3858b9aa 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -1,6 +1,7 @@ from __future__ import annotations from builtins import bool as py_bool +from typing import Literal import cupy as cp @@ -139,6 +140,24 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: return cp.take_along_axis(x, indices, axis=axis) +def searchsorted( + x1: Array, + x2: Array | int | float, + /, + *, + side: Literal['left', 'right'] = 'left', + sorter: Array | None = None +) -> Array: + # Match https://github.com/cupy/cupy/pull/9512/ until cupy v14 is the minimum + # supported version + if not isinstance(x2, cp.ndarray): + if not isinstance(x2, int | float | complex): + raise NotImplementedError( + 'Only python scalars or ndarrays are supported for x2') + x2 = cp.asarray(x2) + return cp.searchsorted(x1, x2, side, sorter) + + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): @@ -161,7 +180,9 @@ def take_along_axis(x: Array, indices: Array, /, *, axis: int = -1) -> Array: 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'bool', 'concat', 'count_nonzero', 'pow', 'sign', - 'ceil', 'floor', 'trunc', 'take_along_axis'] + 'ceil', 'floor', 'trunc', 'take_along_axis', + 'searchsorted', +] def __dir__() -> list[str]: