|
18 | 18 | from ._lib._utils._helpers import asarrays |
19 | 19 | from ._lib._utils._typing import Array, DType |
20 | 20 |
|
21 | | -__all__ = ["isclose", "nan_to_num", "one_hot", "pad"] |
| 21 | +__all__ = ["expand_dims", "isclose", "nan_to_num", "one_hot", "pad"] |
| 22 | + |
| 23 | + |
| 24 | +def expand_dims( |
| 25 | + a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None |
| 26 | +) -> Array: |
| 27 | + """ |
| 28 | + Expand the shape of an array. |
| 29 | +
|
| 30 | + Insert (a) new axis/axes that will appear at the position(s) specified by |
| 31 | + `axis` in the expanded array shape. |
| 32 | +
|
| 33 | + This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*. |
| 34 | + Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays. |
| 35 | +
|
| 36 | + Parameters |
| 37 | + ---------- |
| 38 | + a : array |
| 39 | + Array to have its shape expanded. |
| 40 | + axis : int or tuple of ints, optional |
| 41 | + Position(s) in the expanded axes where the new axis (or axes) is/are placed. |
| 42 | + If multiple positions are provided, they should be unique (note that a position |
| 43 | + given by a positive index could also be referred to by a negative index - |
| 44 | + that will also result in an error). |
| 45 | + Default: ``(0,)``. |
| 46 | + xp : array_namespace, optional |
| 47 | + The standard-compatible namespace for `a`. Default: infer. |
| 48 | +
|
| 49 | + Returns |
| 50 | + ------- |
| 51 | + array |
| 52 | + `a` with an expanded shape. |
| 53 | +
|
| 54 | + Examples |
| 55 | + -------- |
| 56 | + >>> import array_api_strict as xp |
| 57 | + >>> import array_api_extra as xpx |
| 58 | + >>> x = xp.asarray([1, 2]) |
| 59 | + >>> x.shape |
| 60 | + (2,) |
| 61 | +
|
| 62 | + The following is equivalent to ``x[xp.newaxis, :]`` or ``x[xp.newaxis]``: |
| 63 | +
|
| 64 | + >>> y = xpx.expand_dims(x, axis=0, xp=xp) |
| 65 | + >>> y |
| 66 | + Array([[1, 2]], dtype=array_api_strict.int64) |
| 67 | + >>> y.shape |
| 68 | + (1, 2) |
| 69 | +
|
| 70 | + The following is equivalent to ``x[:, xp.newaxis]``: |
| 71 | +
|
| 72 | + >>> y = xpx.expand_dims(x, axis=1, xp=xp) |
| 73 | + >>> y |
| 74 | + Array([[1], |
| 75 | + [2]], dtype=array_api_strict.int64) |
| 76 | + >>> y.shape |
| 77 | + (2, 1) |
| 78 | +
|
| 79 | + ``axis`` may also be a tuple: |
| 80 | +
|
| 81 | + >>> y = xpx.expand_dims(x, axis=(0, 1), xp=xp) |
| 82 | + >>> y |
| 83 | + Array([[[1, 2]]], dtype=array_api_strict.int64) |
| 84 | +
|
| 85 | + >>> y = xpx.expand_dims(x, axis=(2, 0), xp=xp) |
| 86 | + >>> y |
| 87 | + Array([[[1], |
| 88 | + [2]]], dtype=array_api_strict.int64) |
| 89 | + """ |
| 90 | + if xp is None: |
| 91 | + xp = array_namespace(a) |
| 92 | + |
| 93 | + if not isinstance(axis, tuple): |
| 94 | + axis = (axis,) |
| 95 | + ndim = a.ndim + len(axis) |
| 96 | + if axis != () and (min(axis) < -ndim or max(axis) >= ndim): |
| 97 | + err_msg = ( |
| 98 | + f"a provided axis position is out of bounds for array of dimension {a.ndim}" |
| 99 | + ) |
| 100 | + raise IndexError(err_msg) |
| 101 | + axis = tuple(dim % ndim for dim in axis) |
| 102 | + if len(set(axis)) != len(axis): |
| 103 | + err_msg = "Duplicate dimensions specified in `axis`." |
| 104 | + raise ValueError(err_msg) |
| 105 | + |
| 106 | + if is_numpy_namespace(xp) or is_dask_namespace(xp) or is_jax_namespace(xp): |
| 107 | + return xp.expand_dims(a, axis=axis) |
| 108 | + |
| 109 | + return _funcs.expand_dims(a, axis=axis, xp=xp) |
22 | 110 |
|
23 | 111 |
|
24 | 112 | def isclose( |
|
0 commit comments