Skip to content

Commit

Permalink
ENH: allow python scalars as inputs to result_type
Browse files Browse the repository at this point in the history
  • Loading branch information
ev-br committed Jan 7, 2025
1 parent 61bf3c1 commit 05b46ac
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 5 deletions.
27 changes: 24 additions & 3 deletions array_api_strict/_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def isdtype(
else:
raise TypeError(f"'kind' must be a dtype, str, or tuple of dtypes and strs, not {type(kind).__name__}")

def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
def result_type(*arrays_and_dtypes: Union[Array, Dtype, int, float, complex, bool]) -> Dtype:
"""
Array API compatible wrapper for :py:func:`np.result_type <numpy.result_type>`.
Expand All @@ -208,19 +208,40 @@ def result_type(*arrays_and_dtypes: Union[Array, Dtype]) -> Dtype:
# too many extra type promotions like int64 + uint64 -> float64, and does
# value-based casting on scalar arrays.
A = []
scalars = []
for a in arrays_and_dtypes:
if isinstance(a, Array):
a = a.dtype
elif isinstance(a, (bool, int, float, complex)):
scalars.append(a)
elif isinstance(a, np.ndarray) or a not in _all_dtypes:
raise TypeError("result_type() inputs must be array_api arrays or dtypes")
A.append(a)

# remove python scalars
A = [a for a in A if not isinstance(a, (bool, int, float, complex))]

if len(A) == 0:
raise ValueError("at least one array or dtype is required")
elif len(A) == 1:
return A[0]
result = A[0]
else:
t = A[0]
for t2 in A[1:]:
t = _result_type(t, t2)
return t
result = t

if len(scalars) == 0:
return result

if get_array_api_strict_flags()['api_version'] <= '2023.12':
raise TypeError("result_type() inputs must be array_api arrays or dtypes")

# promote python scalars given the result_type for all arrays/dtypes
from ._creation_functions import empty
arr = empty(1, dtype=result)
for s in scalars:
x = arr._promote_scalar(s)
result = _result_type(x.dtype, result)

return result
23 changes: 21 additions & 2 deletions array_api_strict/tests/test_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
import numpy as np

from .._creation_functions import asarray
from .._data_type_functions import astype, can_cast, isdtype
from .._data_type_functions import astype, can_cast, isdtype, result_type
from .._dtypes import (
bool, int8, int16, uint8, float64,
bool, int8, int16, uint8, float64, int64
)
from .._flags import set_array_api_strict_flags

Expand Down Expand Up @@ -70,3 +70,22 @@ def astype_device(api_version):
else:
pytest.raises(TypeError, lambda: astype(a, int8, device=None))
pytest.raises(TypeError, lambda: astype(a, int8, device=a.device))


@pytest.mark.parametrize("api_version", ['2023.12', '2024.12'])
def test_result_type_py_scalars(api_version):
if api_version <= '2023.12':
set_array_api_strict_flags(api_version=api_version)

with pytest.raises(TypeError):
result_type(int16, 3)
else:
with pytest.warns(UserWarning):
set_array_api_strict_flags(api_version=api_version)

assert result_type(int8, 3) == int8
assert result_type(uint8, 3) == uint8
assert result_type(float64, 3) == float64

with pytest.raises(TypeError):
result_type(int64, True)

0 comments on commit 05b46ac

Please sign in to comment.