Skip to content

Commit

Permalink
Merge pull request #24707 from charris/backport-24705
Browse files Browse the repository at this point in the history
TYP: Add annotations for the py3.12 buffer protocol
  • Loading branch information
charris authored Sep 14, 2023
2 parents 2e84b15 + 92aab8c commit ca97802
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 23 deletions.
33 changes: 22 additions & 11 deletions numpy/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import builtins
import sys
import os
import mmap
import ctypes as ct
Expand Down Expand Up @@ -1440,17 +1441,18 @@ _ShapeType = TypeVar("_ShapeType", bound=Any)
_ShapeType2 = TypeVar("_ShapeType2", bound=Any)
_NumberType = TypeVar("_NumberType", bound=number[Any])

# There is currently no exhaustive way to type the buffer protocol,
# as it is implemented exclusively in the C API (python/typing#593)
_SupportsBuffer = Union[
bytes,
bytearray,
memoryview,
_array.array[Any],
mmap.mmap,
NDArray[Any],
generic,
]
if sys.version_info >= (3, 12):
from collections.abc import Buffer as _SupportsBuffer
else:
_SupportsBuffer = (
bytes
| bytearray
| memoryview
| _array.array[Any]
| mmap.mmap
| NDArray[Any]
| generic
)

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)
Expand Down Expand Up @@ -1513,6 +1515,9 @@ class ndarray(_ArrayOrScalarCommon, Generic[_ShapeType, _DType_co]):
order: _OrderKACF = ...,
) -> _ArraySelf: ...

if sys.version_info >= (3, 12):
def __buffer__(self, flags: int, /) -> memoryview: ...

def __class_getitem__(self, item: Any) -> GenericAlias: ...

@overload
Expand Down Expand Up @@ -2570,6 +2575,9 @@ class generic(_ArrayOrScalarCommon):
@property
def flat(self: _ScalarType) -> flatiter[ndarray[Any, _dtype[_ScalarType]]]: ...

if sys.version_info >= (3, 12):
def __buffer__(self, flags: int, /) -> memoryview: ...

@overload
def astype(
self,
Expand Down Expand Up @@ -2772,6 +2780,9 @@ class object_(generic):
def __float__(self) -> float: ...
def __complex__(self) -> complex: ...

if sys.version_info >= (3, 12):
def __release_buffer__(self, buffer: memoryview, /) -> None: ...

# The `datetime64` constructors requires an object with the three attributes below,
# and thus supports datetime duck typing
class _DatetimeScalar(Protocol):
Expand Down
25 changes: 14 additions & 11 deletions numpy/_typing/_array_like.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import sys
from collections.abc import Collection, Callable, Sequence
from typing import Any, Protocol, Union, TypeVar, runtime_checkable

from numpy import (
ndarray,
dtype,
Expand Down Expand Up @@ -76,17 +78,18 @@ def __array_function__(
_NestedSequence[_T],
]

# TODO: support buffer protocols once
#
# https://bugs.python.org/issue27501
#
# is resolved. See also the mypy issue:
#
# https://github.com/python/typing/issues/593
ArrayLike = _DualArrayLike[
dtype[Any],
Union[bool, int, float, complex, str, bytes],
]
if sys.version_info >= (3, 12):
from collections.abc import Buffer

ArrayLike = Buffer | _DualArrayLike[
dtype[Any],
Union[bool, int, float, complex, str, bytes],
]
else:
ArrayLike = _DualArrayLike[
dtype[Any],
Union[bool, int, float, complex, str, bytes],
]

# `ArrayLike<X>_co`: array-like objects that can be coerced into `X`
# given the casting rules `same_kind`
Expand Down
7 changes: 6 additions & 1 deletion numpy/array_api/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
"PyCapsule",
]

import sys

from typing import (
Any,
Literal,
Expand Down Expand Up @@ -63,8 +65,11 @@ def __len__(self, /) -> int: ...
float64,
]]

if sys.version_info >= (3, 12):
from collections.abc import Buffer as SupportsBufferProtocol
else:
SupportsBufferProtocol = Any

SupportsBufferProtocol = Any
PyCapsule = Any

class SupportsDLPack(Protocol):
Expand Down
8 changes: 8 additions & 0 deletions numpy/typing/tests/data/reveal/array_constructors.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,11 @@ assert_type(np.stack([A, A], out=B), SubClass[np.float64])

assert_type(np.block([[A, A], [A, A]]), npt.NDArray[Any])
assert_type(np.block(C), npt.NDArray[Any])

if sys.version_info >= (3, 12):
from collections.abc import Buffer

def create_array(obj: npt.ArrayLike) -> npt.NDArray[Any]: ...

buffer: Buffer
assert_type(create_array(buffer), npt.NDArray[Any])

0 comments on commit ca97802

Please sign in to comment.