Skip to content

Commit b79e00a

Browse files
authored
Merge pull request #676 from numpy/perf/linalg._linalg
2 parents 3ae31eb + d0960f1 commit b79e00a

1 file changed

Lines changed: 54 additions & 69 deletions

File tree

src/numpy-stubs/linalg/_linalg.pyi

Lines changed: 54 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from collections.abc import Iterable, Sequence
2-
from typing import Any, Generic, Literal as L, NamedTuple, SupportsIndex as CanIndex, SupportsInt, TypeAlias, overload
2+
from typing import Any, Generic, Literal as L, NamedTuple, SupportsIndex, SupportsInt, TypeAlias, overload
33
from typing_extensions import TypeVar
44

55
import _numtype as _nt
66
import numpy as np
77
from numpy._core.fromnumeric import matrix_transpose
8-
from numpy._core.numeric import vecdot
8+
from numpy._core.umath import vecdot
99
from numpy._globals import _NoValueType
1010
from numpy._typing import DTypeLike, _DTypeLike as _ToDType
1111

@@ -65,27 +65,6 @@ _InexactT_co = TypeVar("_InexactT_co", bound=np.inexact, default=Any, covariant=
6565
_FloatingNDT_co = TypeVar("_FloatingNDT_co", bound=np.floating | _nt.Array[np.floating], default=Any, covariant=True)
6666
_InexactNDT_co = TypeVar("_InexactNDT_co", bound=np.inexact | _nt.Array[np.inexact], default=Any, covariant=True)
6767

68-
_AnyNumberT = TypeVar(
69-
"_AnyNumberT",
70-
np.int8,
71-
np.int16,
72-
np.int32,
73-
np.int64,
74-
np.long,
75-
np.ulong,
76-
np.uint8,
77-
np.uint16,
78-
np.uint32,
79-
np.uint64,
80-
np.float16,
81-
np.float32,
82-
np.float64,
83-
np.longdouble,
84-
np.complex64,
85-
np.complex128,
86-
np.clongdouble,
87-
)
88-
8968
###
9069

9170
_Option: TypeAlias = _T | _NoValueType
@@ -94,7 +73,7 @@ _False: TypeAlias = L[False]
9473
_True: TypeAlias = L[True]
9574

9675
_Tuple2: TypeAlias = tuple[_T, _T]
97-
_ToInt: TypeAlias = SupportsInt | CanIndex
76+
_ToInt: TypeAlias = SupportsInt | SupportsIndex
9877

9978
_Ax2: TypeAlias = _ToInt | _Tuple2[_ToInt]
10079
_Axes: TypeAlias = Iterable[int]
@@ -320,9 +299,11 @@ _NegInt: TypeAlias = L[-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -
320299

321300
#
322301
@overload # workaround for microsoft/pyright#10232
323-
def matrix_power(a: _nt.CastsArray[np.float64, _nt.NeitherShape], n: CanIndex) -> _Array2ND[np.float64]: ...
302+
def matrix_power(a: _nt.CastsArray[np.float64, _nt.NeitherShape], n: SupportsIndex) -> _Array2ND[np.float64]: ...
324303
@overload # workaround for microsoft/pyright#10232
325-
def matrix_power(a: _nt.CastsWithArray[np.float64, _NumberT, _nt.NeitherShape], n: CanIndex) -> _Array2ND[_NumberT]: ...
304+
def matrix_power(
305+
a: _nt.CastsWithArray[np.float64, _NumberT, _nt.NeitherShape], n: SupportsIndex
306+
) -> _Array2ND[_NumberT]: ...
326307
@overload
327308
def matrix_power(a: _nt.CanLenArray[_NumberT, _Shape2NDT], n: _PosInt) -> _nt.Array[_NumberT, _Shape2NDT]: ...
328309
@overload
@@ -332,19 +313,19 @@ def matrix_power(a: _nt.ToInt_1nd, n: _PosInt) -> _Array2ND[np.intp]: ...
332313
@overload
333314
def matrix_power(a: _nt.CoInteger_1nd, n: _NegInt) -> _Array2ND[np.float64]: ...
334315
@overload
335-
def matrix_power(a: _nt.ToFloat64_1nd, n: CanIndex) -> _Array2ND[np.float64]: ...
316+
def matrix_power(a: _nt.ToFloat64_1nd, n: SupportsIndex) -> _Array2ND[np.float64]: ...
336317
@overload
337-
def matrix_power(a: _nt.ToComplex128_1nd, n: CanIndex) -> _Array2ND[np.complex128]: ...
318+
def matrix_power(a: _nt.ToComplex128_1nd, n: SupportsIndex) -> _Array2ND[np.complex128]: ...
338319
@overload
339-
def matrix_power(a: _nt._ToArray_1nd[_Inexact32T], n: CanIndex) -> _Array2ND[_Inexact32T]: ...
320+
def matrix_power(a: _nt._ToArray_1nd[_Inexact32T], n: SupportsIndex) -> _Array2ND[_Inexact32T]: ...
340321
@overload
341-
def matrix_power(a: _nt.ToObject_1nd, n: CanIndex) -> _Array2ND[np.object_]: ...
322+
def matrix_power(a: _nt.ToObject_1nd, n: SupportsIndex) -> _Array2ND[np.object_]: ...
342323
@overload
343324
def matrix_power(a: _nt.ToUInteger_1nd, n: _PosInt) -> _Array2ND[np.unsignedinteger]: ...
344325
@overload
345326
def matrix_power(a: _nt.ToInteger_1nd, n: _PosInt) -> _Array2ND[np.integer]: ...
346327
@overload
347-
def matrix_power(a: _nt.CoComplex_1nd | _nt.ToObject_1nd, n: CanIndex) -> _Array2ND[Any]: ...
328+
def matrix_power(a: _nt.CoComplex_1nd | _nt.ToObject_1nd, n: SupportsIndex) -> _Array2ND[Any]: ...
348329

349330
#
350331
@overload
@@ -394,9 +375,9 @@ def outer(x1: _nt.ToNumber_1d, x2: _nt.ToNumber_1d, /) -> _nt.Array2D[Any]: ...
394375
@overload # workaround for microsoft/pyright#10232
395376
def multi_dot(arrays: Iterable[_nt._ToArray_nnd[_nt.co_number]], *, out: None = None) -> Any: ...
396377
@overload
397-
def multi_dot(arrays: Iterable[_nt._ToArray_1ds[_AnyNumberT]], *, out: None = None) -> _AnyNumberT: ...
378+
def multi_dot(arrays: Iterable[_nt._ToArray_1ds[_NumberT]], *, out: None = None) -> _NumberT: ...
398379
@overload
399-
def multi_dot(arrays: Iterable[_nt._ToArray_2nd[_AnyNumberT]], *, out: None = None) -> _nt.Array[_AnyNumberT]: ...
380+
def multi_dot(arrays: Iterable[_nt._ToArray_2nd[_NumberT]], *, out: None = None) -> _nt.Array[_NumberT]: ...
400381
@overload
401382
def multi_dot(arrays: Iterable[Sequence[bool]], *, out: None = None) -> np.bool: ...
402383
@overload
@@ -420,11 +401,7 @@ def multi_dot(
420401
arrays: Iterable[_nt.CoComplex_1nd | _nt.ToTimeDelta_1nd | _nt.ToObject_1nd], *, out: None = None
421402
) -> Any: ...
422403

423-
# pyright false positive in case of typevar constraints
424-
@overload
425-
def cross( # pyright: ignore[reportOverlappingOverload]
426-
x1: _nt._ToArray_1nd[_AnyNumberT], x2: _nt._ToArray_1nd[_AnyNumberT], /, *, axis: int = -1
427-
) -> _nt.Array[_AnyNumberT]: ...
404+
#
428405
@overload
429406
def cross(x1: _nt.ToBool_1nd, x2: _nt.ToBool_1nd, /, *, axis: int = -1) -> _nt.Array[np.bool]: ...
430407
@overload
@@ -440,6 +417,10 @@ def cross(x1: _nt.ToComplex128_1nd, x2: _nt.CoComplex128_1nd, /, *, axis: int =
440417
@overload
441418
def cross(x1: _nt.CoComplex128_1nd, x2: _nt.ToComplex128_1nd, /, *, axis: int = -1) -> _nt.Array[np.complex128]: ...
442419
@overload
420+
def cross(
421+
x1: _nt._ToArray_1nd[_NumberT], x2: _nt._ToArray_1nd[_NumberT], /, *, axis: int = -1
422+
) -> _nt.Array[_NumberT]: ...
423+
@overload
443424
def cross(x1: _nt.ToInteger_1nd, x2: _nt.CoInteger_1nd, /, *, axis: int = -1) -> _nt.Array[np.integer]: ...
444425
@overload
445426
def cross(x1: _nt.CoInteger_1nd, x2: _nt.ToInteger_1nd, /, *, axis: int = -1) -> _nt.Array[np.integer]: ...
@@ -454,16 +435,10 @@ def cross(x1: _nt.CoComplex_1nd, x2: _nt.ToComplex_1nd, /, *, axis: int = -1) ->
454435
@overload
455436
def cross(x1: _nt.CoComplex_1nd, x2: _nt.CoComplex_1nd, /, *, axis: int = -1) -> _nt.Array[Any]: ...
456437

457-
# pyright false positive in case of typevar constraints
438+
#
458439
@overload # workaround for microsoft/pyright#10232
459440
def matmul(x1: _nt._ToArray_nnd[_nt.co_number], x2: _nt._ToArray_nnd[_nt.co_number], /) -> Any: ...
460441
@overload
461-
def matmul(x1: _nt._ToArray_1ds[_AnyNumberT], x2: _nt._ToArray_1ds[_AnyNumberT], /) -> _AnyNumberT: ... # pyright: ignore[reportOverlappingOverload]
462-
@overload
463-
def matmul(x1: _nt._ToArray_2nd[_AnyNumberT], x2: _nt._ToArray_1nd[_AnyNumberT], /) -> _nt.Array[_AnyNumberT]: ... # pyright: ignore[reportOverlappingOverload]
464-
@overload
465-
def matmul(x1: _nt._ToArray_1nd[_AnyNumberT], x2: _nt._ToArray_2nd[_AnyNumberT], /) -> _nt.Array[_AnyNumberT]: ... # pyright: ignore[reportOverlappingOverload]
466-
@overload
467442
def matmul(x1: _nt.ToBool_1ds, x2: _nt.ToBool_1ds, /) -> np.bool: ...
468443
@overload
469444
def matmul(x1: _nt.ToBool_2nd, x2: _nt.ToBool_1nd, /) -> _nt.Array[np.bool]: ...
@@ -494,6 +469,12 @@ def matmul(x1: _nt.ToComplex128_2nd, x2: _nt.CoComplex128_1nd, /) -> _nt.Array[n
494469
@overload
495470
def matmul(x1: _nt.CoComplex128_1nd, x2: _nt.ToComplex128_2nd, /) -> _nt.Array[np.complex128]: ...
496471
@overload
472+
def matmul(x1: _nt._ToArray_1ds[_NumberT], x2: _nt._ToArray_1ds[_NumberT], /) -> _NumberT: ...
473+
@overload
474+
def matmul(x1: _nt._ToArray_2nd[_NumberT], x2: _nt._ToArray_1nd[_NumberT], /) -> _nt.Array[_NumberT]: ...
475+
@overload
476+
def matmul(x1: _nt._ToArray_1nd[_NumberT], x2: _nt._ToArray_2nd[_NumberT], /) -> _nt.Array[_NumberT]: ...
477+
@overload
497478
def matmul(x1: _nt.ToInteger_1ds, x2: _nt.CoInteger_1ds, /) -> np.integer: ...
498479
@overload
499480
def matmul(x1: _nt.CoInteger_1ds, x2: _nt.ToInteger_1ds, /) -> np.integer: ...
@@ -942,56 +923,60 @@ def vector_norm(
942923

943924
#
944925
@overload
945-
def diagonal(x: _nt.ToObject_2nd, /, *, offset: CanIndex = 0) -> _nt.Array[np.object_]: ...
926+
def diagonal(x: _nt.ToObject_2nd, /, *, offset: SupportsIndex = 0) -> _nt.Array[np.object_]: ...
946927
@overload
947-
def diagonal(x: _nt._ToArray_2ds[_NativeScalarT], /, *, offset: CanIndex = 0) -> _nt.Array1D[_NativeScalarT]: ...
928+
def diagonal(x: _nt._ToArray_2ds[_NativeScalarT], /, *, offset: SupportsIndex = 0) -> _nt.Array1D[_NativeScalarT]: ...
948929
@overload
949-
def diagonal(x: _ToArray_2nd_ish[_NativeScalarT], /, *, offset: CanIndex = 0) -> _nt.Array[_NativeScalarT]: ...
930+
def diagonal(x: _ToArray_2nd_ish[_NativeScalarT], /, *, offset: SupportsIndex = 0) -> _nt.Array[_NativeScalarT]: ...
950931
@overload
951-
def diagonal(x: _nt.Sequence2ND[bool], /, *, offset: CanIndex = 0) -> _nt.Array[np.bool]: ...
932+
def diagonal(x: _nt.Sequence2ND[bool], /, *, offset: SupportsIndex = 0) -> _nt.Array[np.bool]: ...
952933
@overload
953-
def diagonal(x: _nt.Sequence2ND[_nt.JustInt], /, *, offset: CanIndex = 0) -> _nt.Array[np.intp]: ...
934+
def diagonal(x: _nt.Sequence2ND[_nt.JustInt], /, *, offset: SupportsIndex = 0) -> _nt.Array[np.intp]: ...
954935
@overload
955-
def diagonal(x: _nt.Sequence2ND[_nt.JustFloat], /, *, offset: CanIndex = 0) -> _nt.Array[np.float64]: ...
936+
def diagonal(x: _nt.Sequence2ND[_nt.JustFloat], /, *, offset: SupportsIndex = 0) -> _nt.Array[np.float64]: ...
956937
@overload
957-
def diagonal(x: _nt.Sequence2ND[_nt.JustComplex], /, *, offset: CanIndex = 0) -> _nt.Array[np.complex128]: ...
938+
def diagonal(x: _nt.Sequence2ND[_nt.JustComplex], /, *, offset: SupportsIndex = 0) -> _nt.Array[np.complex128]: ...
958939
@overload
959-
def diagonal(x: _nt.Sequence2ND[_nt.JustBytes], /, *, offset: CanIndex = 0) -> _nt.Array[np.bytes_]: ...
940+
def diagonal(x: _nt.Sequence2ND[_nt.JustBytes], /, *, offset: SupportsIndex = 0) -> _nt.Array[np.bytes_]: ...
960941
@overload
961-
def diagonal(x: _nt.Sequence2ND[_nt.JustStr], /, *, offset: CanIndex = 0) -> _nt.Array[np.str_]: ...
942+
def diagonal(x: _nt.Sequence2ND[_nt.JustStr], /, *, offset: SupportsIndex = 0) -> _nt.Array[np.str_]: ...
962943
@overload
963-
def diagonal(x: _nt.ToGeneric_1nd, /, *, offset: CanIndex = 0) -> _nt.Array[Any]: ...
944+
def diagonal(x: _nt.ToGeneric_1nd, /, *, offset: SupportsIndex = 0) -> _nt.Array[Any]: ...
964945

965946
#
966947
@overload
967-
def trace(x: _nt._ToArray_2ds[_ScalarT], /, *, offset: CanIndex = 0, dtype: None = None) -> _ScalarT: ...
948+
def trace(x: _nt._ToArray_2ds[_ScalarT], /, *, offset: SupportsIndex = 0, dtype: None = None) -> _ScalarT: ...
968949
@overload
969-
def trace(x: _nt._ToArray_3nd[_ScalarT], /, *, offset: CanIndex = 0, dtype: None = None) -> _nt.Array[_ScalarT]: ...
950+
def trace(
951+
x: _nt._ToArray_3nd[_ScalarT], /, *, offset: SupportsIndex = 0, dtype: None = None
952+
) -> _nt.Array[_ScalarT]: ...
970953
@overload
971-
def trace(x: _nt.Sequence2D[bool], /, *, offset: CanIndex = 0, dtype: None = None) -> np.bool: ...
954+
def trace(x: _nt.Sequence2D[bool], /, *, offset: SupportsIndex = 0, dtype: None = None) -> np.bool: ...
972955
@overload
973-
def trace(x: _nt.Sequence3ND[bool], /, *, offset: CanIndex = 0, dtype: None = None) -> _nt.Array[np.bool]: ...
956+
def trace(x: _nt.Sequence3ND[bool], /, *, offset: SupportsIndex = 0, dtype: None = None) -> _nt.Array[np.bool]: ...
974957
@overload
975-
def trace(x: _nt.Sequence2D[_nt.JustInt], /, *, offset: CanIndex = 0, dtype: None = None) -> np.intp: ...
958+
def trace(x: _nt.Sequence2D[_nt.JustInt], /, *, offset: SupportsIndex = 0, dtype: None = None) -> np.intp: ...
976959
@overload
977-
def trace(x: _nt.Sequence3ND[_nt.JustInt], /, *, offset: CanIndex = 0, dtype: None = None) -> _nt.Array[np.intp]: ...
960+
def trace(
961+
x: _nt.Sequence3ND[_nt.JustInt], /, *, offset: SupportsIndex = 0, dtype: None = None
962+
) -> _nt.Array[np.intp]: ...
978963
@overload
979-
def trace(x: _nt.Sequence2D[_nt.JustFloat], /, *, offset: CanIndex = 0, dtype: None = None) -> np.float64: ...
964+
def trace(x: _nt.Sequence2D[_nt.JustFloat], /, *, offset: SupportsIndex = 0, dtype: None = None) -> np.float64: ...
980965
@overload
981966
def trace(
982-
x: _nt.Sequence3ND[_nt.JustFloat], /, *, offset: CanIndex = 0, dtype: None = None
967+
x: _nt.Sequence3ND[_nt.JustFloat], /, *, offset: SupportsIndex = 0, dtype: None = None
983968
) -> _nt.Array[np.float64]: ...
984969
@overload
985-
def trace(x: _nt.Sequence2D[_nt.JustComplex], /, *, offset: CanIndex = 0, dtype: None = None) -> np.complex128: ...
970+
def trace(x: _nt.Sequence2D[_nt.JustComplex], /, *, offset: SupportsIndex = 0, dtype: None = None) -> np.complex128: ...
986971
@overload
987972
def trace(
988-
x: _nt.Sequence3ND[_nt.JustComplex], /, *, offset: CanIndex = 0, dtype: None = None
973+
x: _nt.Sequence3ND[_nt.JustComplex], /, *, offset: SupportsIndex = 0, dtype: None = None
989974
) -> _nt.Array[np.complex128]: ...
990975
@overload
991-
def trace(x: _nt.CoComplex_2ds, /, *, offset: CanIndex = 0, dtype: _ToDType[_ScalarT]) -> _ScalarT: ...
976+
def trace(x: _nt.CoComplex_2ds, /, *, offset: SupportsIndex = 0, dtype: _ToDType[_ScalarT]) -> _ScalarT: ...
992977
@overload
993-
def trace(x: _nt.CoComplex_3nd, /, *, offset: CanIndex = 0, dtype: _ToDType[_ScalarT]) -> _nt.Array[_ScalarT]: ...
978+
def trace(x: _nt.CoComplex_3nd, /, *, offset: SupportsIndex = 0, dtype: _ToDType[_ScalarT]) -> _nt.Array[_ScalarT]: ...
994979
@overload
995-
def trace(x: _nt.CoComplex_3nd, /, *, offset: CanIndex = 0, dtype: DTypeLike | None = None) -> _nt.Array[Any]: ...
980+
def trace(x: _nt.CoComplex_3nd, /, *, offset: SupportsIndex = 0, dtype: DTypeLike | None = None) -> _nt.Array[Any]: ...
996981
@overload
997-
def trace(x: _nt.CoComplex_1nd, /, *, offset: CanIndex = 0, dtype: DTypeLike | None = None) -> Any: ...
982+
def trace(x: _nt.CoComplex_1nd, /, *, offset: SupportsIndex = 0, dtype: DTypeLike | None = None) -> Any: ...

0 commit comments

Comments
 (0)