Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "multi device" support #59

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions array_api_strict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,9 @@

__all__ += ["all", "any"]

from ._array_object import Device
__all__ += ["Device"]

# Helper functions that are not part of the standard

from ._flags import (
Expand Down
104 changes: 77 additions & 27 deletions array_api_strict/_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,26 @@

import numpy as np

# Placeholder object to represent the "cpu" device (the only device NumPy
# supports).
class _cpu_device:
class Device:
def __init__(self, device="CPU_DEVICE"):
if device not in ("CPU_DEVICE", "device1", "device2"):
raise ValueError(f"The device '{device}' is not a valid choice.")
self._device = device

def __repr__(self):
return "CPU_DEVICE"
return f"array_api_strict.Device('{self._device}')"

def __eq__(self, other):
if not isinstance(other, Device):
return False
return self._device == other._device

def __hash__(self):
return hash(("Device", self._device))


CPU_DEVICE = _cpu_device()
CPU_DEVICE = Device()
ALL_DEVICES = (CPU_DEVICE, Device("device1"), Device("device2"))

_default = object()

Expand All @@ -73,7 +86,7 @@ class Array:
# Use a custom constructor instead of __init__, as manually initializing
# this class is not supported API.
@classmethod
def _new(cls, x, /):
def _new(cls, x, /, device=None):
"""
This is a private method for initializing the array API Array
object.
Expand All @@ -95,6 +108,9 @@ def _new(cls, x, /):
)
obj._array = x
obj._dtype = _dtype
if device is None:
device = CPU_DEVICE
obj._device = device
return obj

# Prevent Array() from working
Expand All @@ -116,7 +132,11 @@ def __repr__(self: Array, /) -> str:
"""
Performs the operation __repr__.
"""
suffix = f", dtype={self.dtype})"
suffix = f", dtype={self.dtype}"
if self.device != CPU_DEVICE:
suffix += f", device={self.device})"
else:
suffix += ")"
if 0 in self.shape:
prefix = "empty("
mid = str(self.shape)
Expand All @@ -134,6 +154,8 @@ def __array__(self, dtype: None | np.dtype[Any] = None, copy: None | bool = None
will be present in other implementations.

"""
if self._device != CPU_DEVICE:
raise RuntimeError(f"Can not convert array on the '{self._device}' device to a Numpy array.")
# copy keyword is new in 2.0.0; for older versions don't use it
# retry without that keyword.
if np.__version__[0] < '2':
Expand Down Expand Up @@ -193,6 +215,15 @@ def _check_allowed_dtypes(self, other: bool | int | float | Array, dtype_categor

return other

def _check_device(self, other):
"""Check that other is on a device compatible with the current array"""
if isinstance(other, (int, complex, float, bool)):
return other
elif isinstance(other, Array):
if self.device != other.device:
raise RuntimeError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.")
return other

# Helper function to match the type promotion rules in the spec
def _promote_scalar(self, scalar):
"""
Expand Down Expand Up @@ -468,23 +499,25 @@ def __add__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __add__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__add__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__add__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __and__(self: Array, other: Union[int, bool, Array], /) -> Array:
"""
Performs the operation __and__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "integer or boolean", "__and__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__and__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __array_namespace__(
self: Array, /, *, api_version: Optional[str] = None
Expand Down Expand Up @@ -568,14 +601,15 @@ def __eq__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
"""
Performs the operation __eq__.
"""
other = self._check_device(other)
# Even though "all" dtypes are allowed, we still require them to be
# promotable with each other.
other = self._check_allowed_dtypes(other, "all", "__eq__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__eq__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __float__(self: Array, /) -> float:
"""
Expand All @@ -593,23 +627,25 @@ def __floordiv__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __floordiv__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__floordiv__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__floordiv__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __ge__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __ge__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__ge__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__ge__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __getitem__(
self: Array,
Expand All @@ -625,19 +661,21 @@ def __getitem__(
"""
Performs the operation __getitem__.
"""
# XXX Does key have to be on the same device? Is there an exception for CPU_DEVICE?
# Note: Only indices required by the spec are allowed. See the
# docstring of _validate_index
self._validate_index(key)
if isinstance(key, Array):
# Indexing self._array with array_api_strict arrays can be erroneous
key = key._array
res = self._array.__getitem__(key)
return self._new(res)
return self._new(res, device=self.device)

def __gt__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __gt__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__gt__")
if other is NotImplemented:
return other
Expand Down Expand Up @@ -671,7 +709,7 @@ def __invert__(self: Array, /) -> Array:
if self.dtype not in _integer_or_boolean_dtypes:
raise TypeError("Only integer or boolean dtypes are allowed in __invert__")
res = self._array.__invert__()
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __iter__(self: Array, /):
"""
Expand All @@ -686,85 +724,92 @@ def __iter__(self: Array, /):
# define __iter__, but it doesn't disallow it. The default Python
# behavior is to implement iter as a[0], a[1], ... when __getitem__ is
# implemented, which implies iteration on 1-D arrays.
return (Array._new(i) for i in self._array)
return (Array._new(i, device=self.device) for i in self._array)

def __le__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __le__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__le__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__le__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __lshift__(self: Array, other: Union[int, Array], /) -> Array:
"""
Performs the operation __lshift__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "integer", "__lshift__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__lshift__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __lt__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __lt__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__lt__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__lt__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __matmul__(self: Array, other: Array, /) -> Array:
"""
Performs the operation __matmul__.
"""
other = self._check_device(other)
# matmul is not defined for scalars, but without this, we may get
# the wrong error message from asarray.
other = self._check_allowed_dtypes(other, "numeric", "__matmul__")
if other is NotImplemented:
return other
res = self._array.__matmul__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __mod__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __mod__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "real numeric", "__mod__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__mod__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __mul__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __mul__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__mul__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__mul__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __ne__(self: Array, other: Union[int, float, bool, Array], /) -> Array:
"""
Performs the operation __ne__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "all", "__ne__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__ne__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __neg__(self: Array, /) -> Array:
"""
Expand All @@ -773,18 +818,19 @@ def __neg__(self: Array, /) -> Array:
if self.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in __neg__")
res = self._array.__neg__()
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __or__(self: Array, other: Union[int, bool, Array], /) -> Array:
"""
Performs the operation __or__.
"""
other = self._check_device(other)
other = self._check_allowed_dtypes(other, "integer or boolean", "__or__")
if other is NotImplemented:
return other
self, other = self._normalize_two_args(self, other)
res = self._array.__or__(other._array)
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __pos__(self: Array, /) -> Array:
"""
Expand All @@ -793,14 +839,15 @@ def __pos__(self: Array, /) -> Array:
if self.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in __pos__")
res = self._array.__pos__()
return self.__class__._new(res)
return self.__class__._new(res, device=self.device)

def __pow__(self: Array, other: Union[int, float, Array], /) -> Array:
"""
Performs the operation __pow__.
"""
from ._elementwise_functions import pow

other = self._check_device(other)
other = self._check_allowed_dtypes(other, "numeric", "__pow__")
if other is NotImplemented:
return other
Expand Down Expand Up @@ -1154,8 +1201,11 @@ def __rxor__(self: Array, other: Union[int, bool, Array], /) -> Array:
def to_device(self: Array, device: Device, /, stream: None = None) -> Array:
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
if device == CPU_DEVICE:
if device == self._device:
return self
elif isinstance(device, Device):
arr = np.asarray(self._array, copy=True)
return self.__class__._new(arr, device=device)
raise ValueError(f"Unsupported device {device!r}")

@property
Expand All @@ -1169,7 +1219,7 @@ def dtype(self) -> Dtype:

@property
def device(self) -> Device:
return CPU_DEVICE
return self._device

# Note: mT is new in array API spec (see matrix_transpose)
@property
Expand Down
Loading
Loading