Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 40 additions & 27 deletions src/_numtype/_nep50.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Interface to the NEP 50 "safe" promotion rules that are embedded within the numeric
# scalar types as the type-check-only `__promote__` method.
# Interface to the NEP 50 "safe" promotion rules that are associated with the numeric
# scalar types as the type-check-only `__nep50*__` methods.
# https://numpy.org/neps/nep-0050-scalar-promotion.html

from typing import Any, Protocol, TypeAlias, type_check_only
Expand Down Expand Up @@ -33,11 +33,11 @@ _BuitinT = TypeVar("_BuitinT")
_BuitinT_co = TypeVar("_BuitinT_co", covariant=True)

_ScalarIn: TypeAlias = np.generic | str
_ScalarInT = TypeVar("_ScalarInT", bound=_ScalarIn)
_OtherScalarT = TypeVar("_OtherScalarT", bound=_ScalarIn)
_ScalarInT_contra = TypeVar("_ScalarInT_contra", bound=_ScalarIn, contravariant=True)

_ScalarOut: TypeAlias = np.generic
_ScalarOutT = TypeVar("_ScalarOutT", bound=_ScalarOut, default=Any)
_IntoScalarT = TypeVar("_IntoScalarT", bound=_ScalarOut, default=Any)
_ScalarOutT_co = TypeVar("_ScalarOutT_co", bound=_ScalarOut, covariant=True)
_ScalarOutT_contra = TypeVar("_ScalarOutT_contra", bound=_ScalarOut, contravariant=True)

Expand All @@ -48,7 +48,10 @@ _ShapeT_co = TypeVar("_ShapeT_co", bound=_shape.Shape, covariant=True)

@type_check_only
class _CanNEP50(Protocol[_ScalarOutT_contra, _ScalarInT_contra, _ScalarOutT_co]):
def __nep50__(self, below: _ScalarOutT_contra, above: _ScalarInT_contra, /) -> _ScalarOutT_co: ...
def __nep50__(self, into: _ScalarOutT_contra, from_: _ScalarInT_contra, /) -> _ScalarOutT_co: ...

# Due to limitations in mypy/pyright, we cannot combine these individual `__nep50_rule__` into a single
# overloaded method.

@type_check_only
class _CanNEP50Rule0(Protocol[_ScalarInT_contra, _ScalarOutT_co]):
Expand Down Expand Up @@ -131,42 +134,52 @@ class _LikeScalar(Protocol[_LikeT_co]):

_SequenceND: TypeAlias = _LikeT | _NestedSequence[_LikeT]

# Accepts anything with `.shape` and `.dtype` if its `dtype.type` scalar-type can be safe-cast into `_IntoScalarT`.
# E.g. `Casts[np.int16]` will accept arrays or scalars of `np.uint8` and `np.bool`, but not `np.uint16` or `np.float32`.
# An optional second `_ShapeT` type-parameter can be used to further restrict the rank (shape-type).
Casts = TypeAliasType(
"Casts", _SequenceND[_LikeNumeric[_CanNEP50[_ScalarOutT, Any, Any], _ShapeT]], type_params=(_ScalarOutT, _ShapeT)
"Casts", _SequenceND[_LikeNumeric[_CanNEP50[_IntoScalarT, Any, Any], _ShapeT]], type_params=(_IntoScalarT, _ShapeT)
)
# Same as `Casts`, but only for array-like types, rejecting "bare" scalars like `np.float64`.
CastsArray = TypeAliasType(
"CastsArray", _SequenceND[_LikeArray[_CanNEP50[_ScalarOutT, Any, Any], _ShapeT]], type_params=(_ScalarOutT, _ShapeT)
"CastsArray",
_SequenceND[_LikeArray[_CanNEP50[_IntoScalarT, Any, Any], _ShapeT]],
type_params=(_IntoScalarT, _ShapeT),
)
CastsScalar = TypeAliasType("CastsScalar", _LikeScalar[_CanNEP50[_ScalarOutT, Any, Any]], type_params=(_ScalarOutT,))
# Same as `Casts`, but only for scalar-like types, rejecting array-like types, including zero-dimensional arrays.
CastsScalar = TypeAliasType("CastsScalar", _LikeScalar[_CanNEP50[_IntoScalarT, Any, Any]], type_params=(_IntoScalarT,))

#
_CastWith: TypeAlias = (
_CanNEP50[Any, _ScalarInT, _ScalarOutT]
| _CanNEP50Rule0[_ScalarInT, _ScalarOutT]
| _CanNEP50Rule1[_ScalarInT, _ScalarOutT]
| _CanNEP50Rule2[_ScalarInT, _ScalarOutT]
| _CanNEP50Rule3[_ScalarInT, _ScalarOutT]
| _CanNEP50Rule4[_ScalarInT, _ScalarOutT]
| _CanNEP50Rule5[_ScalarInT, _ScalarOutT]
| _CanNEP50Rule6[_ScalarInT, _ScalarOutT]
_CanNEP50[Any, _OtherScalarT, _IntoScalarT]
| _CanNEP50Rule0[_OtherScalarT, _IntoScalarT]
| _CanNEP50Rule1[_OtherScalarT, _IntoScalarT]
| _CanNEP50Rule2[_OtherScalarT, _IntoScalarT]
| _CanNEP50Rule3[_OtherScalarT, _IntoScalarT]
| _CanNEP50Rule4[_OtherScalarT, _IntoScalarT]
| _CanNEP50Rule5[_OtherScalarT, _IntoScalarT]
| _CanNEP50Rule6[_OtherScalarT, _IntoScalarT]
)
# Accepts anything with `.shape` and `.dtype` if its `dtype.type` scalar-type can be safe-cast into `_IntoScalarT`
# together with `_OtherScalarT` according to any of the NEP 50 rules. Accepts both array-likes and scalars.
CastsWith = TypeAliasType(
"CastsWith",
_SequenceND[_LikeNumeric[_CastWith[_ScalarInT, _ScalarOutT], _ShapeT]],
type_params=(_ScalarInT, _ScalarOutT, _ShapeT),
_SequenceND[_LikeNumeric[_CastWith[_OtherScalarT, _IntoScalarT], _ShapeT]],
type_params=(_OtherScalarT, _IntoScalarT, _ShapeT),
)
# Same as `CastsWith`, but only for array-like types, rejecting "bare" scalars like `np.float64`.
CastsWithArray = TypeAliasType(
"CastsWithArray",
_SequenceND[_LikeArray[_CastWith[_ScalarInT, _ScalarOutT], _ShapeT]],
type_params=(_ScalarInT, _ScalarOutT, _ShapeT),
_SequenceND[_LikeArray[_CastWith[_OtherScalarT, _IntoScalarT], _ShapeT]],
type_params=(_OtherScalarT, _IntoScalarT, _ShapeT),
)
# Same as `CastsWith`, but only for scalar-like types, rejecting array-like types, including zero-dimensional arrays.
CastsWithScalar = TypeAliasType(
"CastsWithScalar", _LikeScalar[_CastWith[_ScalarInT, _ScalarOutT]], type_params=(_ScalarInT, _ScalarOutT)
"CastsWithScalar", _LikeScalar[_CastWith[_OtherScalarT, _IntoScalarT]], type_params=(_OtherScalarT, _IntoScalarT)
)

#
CastsWithBuiltin: TypeAlias = _LikeNumeric[_CanNEP50Builtin[_BuitinT, _ScalarOutT], _ShapeT]
CastsWithBool: TypeAlias = _LikeNumeric[_CanNEP50Bool[_ScalarOutT], _ShapeT]
CastsWithInt: TypeAlias = _LikeNumeric[_CanNEP50Int[_ScalarOutT], _ShapeT]
CastsWithFloat: TypeAlias = _LikeNumeric[_CanNEP50Float[_ScalarOutT], _ShapeT]
CastsWithComplex: TypeAlias = _LikeNumeric[_CanNEP50Complex[_ScalarOutT], _ShapeT]
CastsWithBuiltin: TypeAlias = _LikeNumeric[_CanNEP50Builtin[_BuitinT, _IntoScalarT], _ShapeT]
CastsWithBool: TypeAlias = _LikeNumeric[_CanNEP50Bool[_IntoScalarT], _ShapeT]
CastsWithInt: TypeAlias = _LikeNumeric[_CanNEP50Int[_IntoScalarT], _ShapeT]
CastsWithFloat: TypeAlias = _LikeNumeric[_CanNEP50Float[_IntoScalarT], _ShapeT]
CastsWithComplex: TypeAlias = _LikeNumeric[_CanNEP50Complex[_IntoScalarT], _ShapeT]
62 changes: 30 additions & 32 deletions src/numpy-stubs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3532,7 +3532,7 @@ class bool_(generic[_BoolItemT_co], Generic[_BoolItemT_co]):

#
@type_check_only
def __nep50__(self, below: _nt.co_number | timedelta64, above: Never, /) -> bool_: ...
def __nep50__(self, into: _nt.co_number | timedelta64, from_: Never, /) -> bool_: ...
@type_check_only
def __nep50_builtin__(self, /) -> tuple[py_bool, bool_]: ...
@type_check_only
Expand Down Expand Up @@ -4105,7 +4105,7 @@ class number(_CmpOpMixin[_nt.CoComplex_0d, _nt.CoComplex_1nd], generic[_NumberIt
class integer(_IntegralMixin, _RoundMixin, number[int]):
@type_check_only
def __nep50__(
self, below: timedelta64 | _Inexact64_min | _JustFloating | _JustInexact, above: bool_, /
self, into: timedelta64 | _Inexact64_min | _JustFloating | _JustInexact, from_: bool_, /
) -> integer: ...
@final
@override
Expand Down Expand Up @@ -4255,7 +4255,7 @@ class signedinteger(integer):
@type_check_only
@override
def __nep50__(
self, below: int64 | timedelta64 | _Inexact64_min | _JustFloating | _JustInexact, above: bool_, /
self, into: int64 | timedelta64 | _Inexact64_min | _JustFloating | _JustInexact, from_: bool_, /
) -> signedinteger: ...
@type_check_only
def __nep50_rule0__(self, other: uint32, /) -> int64: ...
Expand All @@ -4279,9 +4279,7 @@ class int8(_IntMixin[L[1]], signedinteger):
#
@override
@type_check_only
def __nep50__(
self, below: signedinteger | timedelta64 | inexact | _JustFloating | _JustInexact, above: bool_, /
) -> int8: ...
def __nep50__(self, into: signedinteger | timedelta64 | inexact | _JustFloating, from_: bool_, /) -> int8: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
@type_check_only
def __nep50_rule2__(self, other: uint8, /) -> int16: ...
@type_check_only
Expand All @@ -4307,8 +4305,8 @@ class int16(_IntMixin[L[2]], signedinteger):
@type_check_only
def __nep50__(
self,
below: _I16_min | timedelta64 | _F32_min | _JustFloating | complexfloating | _JustInexact,
above: _nt.co_integer8,
into: _I16_min | timedelta64 | _F32_min | _JustFloating | complexfloating | _JustInexact,
from_: _nt.co_integer8,
/,
) -> int16: ...
@type_check_only
Expand All @@ -4335,7 +4333,7 @@ class int32(_IntMixin[L[4]], signedinteger):
@override
@type_check_only
def __nep50__(
self, below: _I32_min | timedelta64 | _Inexact64_min | _JustFloating | _JustInexact, above: _nt.co_integer16, /
self, into: _I32_min | timedelta64 | _Inexact64_min | _JustFloating | _JustInexact, from_: _nt.co_integer16, /
) -> int32: ...
@override
@type_check_only
Expand Down Expand Up @@ -4364,7 +4362,7 @@ class int64(_IntMixin[L[8]], signedinteger):
@override
@type_check_only
def __nep50__(
self, below: int64 | timedelta64 | _Inexact64_min | _JustFloating | _JustInexact, above: _nt.co_integer32, /
self, into: int64 | timedelta64 | _Inexact64_min | _JustFloating | _JustInexact, from_: _nt.co_integer32, /
) -> int64: ...
@override
@type_check_only
Expand Down Expand Up @@ -4399,7 +4397,7 @@ class unsignedinteger(integer):
@type_check_only
@override
def __nep50__(
self, below: uint64 | timedelta64 | _Inexact64_min | _JustFloating | _JustInexact, above: bool_, /
self, into: uint64 | timedelta64 | _Inexact64_min | _JustFloating | _JustInexact, from_: bool_, /
) -> unsignedinteger: ...
@type_check_only
def __nep50_rule3__(self, other: _JustUnsignedInteger, /) -> unsignedinteger: ...
Expand All @@ -4415,8 +4413,8 @@ class uint8(_IntMixin[L[1]], unsignedinteger):
#
@override
@type_check_only
def __nep50__(
self, below: _I16_min | unsignedinteger | timedelta64 | _JustFloating | inexact | _JustInexact, above: bool_, /
def __nep50__( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
self, into: _I16_min | unsignedinteger | timedelta64 | _JustFloating | inexact, from_: bool_, /
) -> uint8: ...
@type_check_only
def __nep50_rule0__(self, other: int8, /) -> int16: ...
Expand Down Expand Up @@ -4447,8 +4445,8 @@ class uint16(_IntMixin[L[2]], unsignedinteger):
@type_check_only
def __nep50__(
self,
below: uint16 | _Integer32_min | timedelta64 | _F32_min | _JustFloating | complexfloating | _JustInexact,
above: _nt.co_uint8,
into: uint16 | _Integer32_min | timedelta64 | _F32_min | _JustFloating | complexfloating | _JustInexact,
from_: _nt.co_uint8,
/,
) -> uint16: ...
@type_check_only
Expand Down Expand Up @@ -4481,7 +4479,7 @@ class uint32(_IntMixin[L[4]], unsignedinteger):
@override
@type_check_only
def __nep50__(
self, below: uint32 | _nt.integer64 | timedelta64 | _Inexact64_min | _AbstractInexact, above: _nt.co_uint16, /
self, into: uint32 | _nt.integer64 | timedelta64 | _Inexact64_min | _AbstractInexact, from_: _nt.co_uint16, /
) -> uint32: ...
@type_check_only
def __nep50_rule1__(self, other: float16 | float32, /) -> float64: ...
Expand Down Expand Up @@ -4513,7 +4511,7 @@ class uint64(_IntMixin[L[8]], unsignedinteger):
@override
@type_check_only
def __nep50__(
self, below: uint64 | timedelta64 | _Inexact64_min | _AbstractInexact, above: _nt.co_uint32, /
self, into: uint64 | timedelta64 | _Inexact64_min | _AbstractInexact, from_: _nt.co_uint32, /
) -> uint64: ...
@type_check_only
def __nep50_rule2__(self, other: complex64, /) -> complex128: ...
Expand Down Expand Up @@ -4544,7 +4542,7 @@ uint = uintp

class inexact(number[_InexactItemT_co], Generic[_InexactItemT_co]):
@type_check_only
def __nep50__(self, below: clongdouble, above: _nt.co_integer8, /) -> inexact: ...
def __nep50__(self, into: clongdouble, from_: _nt.co_integer8, /) -> inexact: ...
@final
@override
@type_check_only
Expand Down Expand Up @@ -4582,7 +4580,7 @@ class inexact(number[_InexactItemT_co], Generic[_InexactItemT_co]):
class floating(_RealMixin, _RoundMixin, inexact[float]):
@override
@type_check_only
def __nep50__(self, below: _nt.inexact64l, above: _nt.co_integer8, /) -> floating: ...
def __nep50__(self, into: _nt.inexact64l, from_: _nt.co_integer8, /) -> floating: ...
@override
@type_check_only
def __nep50_rule3__(self, other: _JustFloating, /) -> floating: ...
Expand Down Expand Up @@ -4633,7 +4631,7 @@ class float16(_FloatMixin[L[2]], floating):
#
@override
@type_check_only
def __nep50__(self, below: inexact, above: _nt.co_integer8, /) -> float16: ...
def __nep50__(self, into: inexact, from_: _nt.co_integer8, /) -> float16: ...
@override
@type_check_only
def __nep50_complex__(self, /) -> complex64: ...
Expand Down Expand Up @@ -4662,7 +4660,7 @@ class float32(_FloatMixin[L[4]], floating):
#
@override
@type_check_only
def __nep50__(self, below: _F32_min | complexfloating, above: float16 | _nt.co_integer16, /) -> float32: ...
def __nep50__(self, into: _F32_min | complexfloating, from_: float16 | _nt.co_integer16, /) -> float32: ...
@override
@type_check_only
def __nep50_complex__(self, /) -> complex64: ...
Expand All @@ -4689,7 +4687,7 @@ class float64(_FloatMixin[L[8]], floating, float): # type: ignore[misc]
#
@override
@type_check_only
def __nep50__(self, below: _Inexact64_min, above: _F32_max | _nt.co_integer, /) -> float64: ...
def __nep50__(self, into: _Inexact64_min, from_: _F32_max | _nt.co_integer, /) -> float64: ...
@override
@type_check_only
def __nep50_complex__(self, /) -> complex128: ...
Expand Down Expand Up @@ -4730,7 +4728,7 @@ class longdouble(_FloatMixin[L[12, 16]], floating):
#
@override
@type_check_only
def __nep50__(self, below: longdouble | clongdouble, above: _nt.co_float64, /) -> longdouble: ...
def __nep50__(self, into: longdouble | clongdouble, from_: _nt.co_float64, /) -> longdouble: ...
@override
@type_check_only
def __nep50_complex__(self, /) -> clongdouble: ...
Expand Down Expand Up @@ -4766,7 +4764,7 @@ float128 = longdouble
class complexfloating(inexact[complex]):
@override
@type_check_only
def __nep50__(self, below: clongdouble, above: _F32_max | _nt.co_integer16, /) -> complexfloating: ...
def __nep50__(self, into: clongdouble, from_: _F32_max | _nt.co_integer16, /) -> complexfloating: ...
@final
@override
@type_check_only
Expand Down Expand Up @@ -4804,7 +4802,7 @@ class complex64(complexfloating):
#
@override
@type_check_only
def __nep50__(self, below: complexfloating, above: _F32_max | _nt.co_integer16, /) -> complex64: ...
def __nep50__(self, into: complexfloating, from_: _F32_max | _nt.co_integer16, /) -> complex64: ...
@type_check_only
def __nep50_rule0__(self, other: _nt.integer32 | _nt.integer64 | float64, /) -> complex128: ...
@type_check_only
Expand Down Expand Up @@ -4847,7 +4845,7 @@ class complex128(complexfloating, complex):
#
@override
@type_check_only
def __nep50__(self, below: _C128_min, above: complex64 | _F64_max | _nt.co_integer, /) -> complex128: ...
def __nep50__(self, into: _C128_min, from_: complex64 | _F64_max | _nt.co_integer, /) -> complex128: ...
@type_check_only
def __nep50_rule1__(self, other: longdouble, /) -> clongdouble: ...
@override
Expand Down Expand Up @@ -4893,7 +4891,7 @@ class clongdouble(complexfloating):
#
@override
@type_check_only
def __nep50__(self, below: clongdouble, above: _nt.co_number, /) -> clongdouble: ...
def __nep50__(self, into: clongdouble, from_: _nt.co_number, /) -> clongdouble: ...
@override
@type_check_only
def __nep50_rule2__(self, other: _AbstractInteger, /) -> clongdouble: ...
Expand Down Expand Up @@ -4968,7 +4966,7 @@ class object_(_RealMixin, generic[Any]):

#
@type_check_only
def __nep50__(self, below: object_, above: _nt.co_number | character, /) -> object_: ...
def __nep50__(self, into: object_, from_: _nt.co_number | character, /) -> object_: ...
@type_check_only
def __nep50_builtin__(self, /) -> tuple[_JustBuiltinScalar, object_]: ...

Expand Down Expand Up @@ -4998,7 +4996,7 @@ class bytes_(character[bytes], bytes): # type: ignore[misc]

#
@type_check_only
def __nep50__(self, below: bytes_ | object_, above: Never, /) -> bytes_: ...
def __nep50__(self, into: bytes_ | object_, from_: Never, /) -> bytes_: ...
@type_check_only
def __nep50_builtin__(self, /) -> tuple[_nt.JustBytes, bytes_]: ...

Expand All @@ -5015,7 +5013,7 @@ class str_(character[str], str): # type: ignore[misc]

#
@type_check_only
def __nep50__(self, below: str_ | object_, above: Never, /) -> str_: ...
def __nep50__(self, into: str_ | object_, from_: Never, /) -> str_: ...
@type_check_only
def __nep50_builtin__(self, /) -> tuple[_nt.JustStr, str_]: ...

Expand All @@ -5032,7 +5030,7 @@ class void(flexible[bytes | tuple[Any, ...]]): # type: ignore[misc] # pyright:

#
@type_check_only
def __nep50__(self, below: object_, from_: Never, /) -> void: ...
def __nep50__(self, into: object_, from_: Never, /) -> void: ...

#
@overload
Expand Down Expand Up @@ -5212,7 +5210,7 @@ class timedelta64(

#
@type_check_only
def __nep50__(self, below: timedelta64, above: _nt.co_integer, /) -> timedelta64: ...
def __nep50__(self, into: timedelta64, from_: _nt.co_integer, /) -> timedelta64: ...
@type_check_only
def __nep50_builtin__(self, /) -> tuple[int, timedelta64]: ...

Expand Down