Skip to content

Commit

Permalink
Feat: add missing _like methods to TypeTracer (#1505)
Browse files Browse the repository at this point in the history
* Fix: use `None` instead of `unset`

We should accept `None`

* Feat: add simple _like typetracer methods

* Fix: support list-of-primitive in `TypeTracerArray.from_array`

* Test: test new `_like` functions

* Refactor: simplify `from_array`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Refactor: be noisy about invalid types

* Fix: test equality for object dtype

* Refactor: cleanup test names

* Test: ensure object dtypes are caught

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
agoose77 and pre-commit-ci[bot] authored Jun 16, 2022
1 parent 3006f01 commit 0c6ae03
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 21 deletions.
56 changes: 35 additions & 21 deletions src/awkward/_v2/_typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,22 @@ class TypeTracerArray:
def from_array(cls, array, dtype=None):
if isinstance(array, ak._v2.index.Index):
array = array.data

# not array-like
if not hasattr(array, "shape"):
sequence = list(array)
array = numpy.array(sequence)
if array.dtype == np.dtype("O"):
raise ak._v2._util.error(
ValueError(
f"bug in Awkward Array: attempt to construct `TypeTracerArray` "
f"from a sequence of non-primitive types: {sequence}"
)
)

if dtype is None:
dtype = array.dtype

return cls(dtype, shape=array.shape)

def __init__(self, dtype, shape=None):
Expand Down Expand Up @@ -432,9 +446,6 @@ def copy(self):
return self


unset = object()


class TypeTracer(ak.nplike.NumpyLike):
known_data = False
known_shape = False
Expand Down Expand Up @@ -483,21 +494,21 @@ def raw(self, array, nplike):

############################ array creation

def array(self, data, dtype=unset, **kwargs):
def array(self, data, dtype=None, **kwargs):
# data[, dtype=[, copy=]]
if dtype is unset:
if dtype is None:
dtype = data.dtype
return TypeTracerArray.from_array(data, dtype=dtype)

def asarray(self, array, dtype=unset, **kwargs):
def asarray(self, array, dtype=None, **kwargs):
# array[, dtype=][, order=]
if dtype is unset:
if dtype is None:
dtype = array.dtype
return TypeTracerArray.from_array(array, dtype=dtype)

def ascontiguousarray(self, array, dtype=unset, **kwargs):
def ascontiguousarray(self, array, dtype=None, **kwargs):
# array[, dtype=]
if dtype is unset:
if dtype is None:
dtype = array.dtype
return TypeTracerArray.from_array(array, dtype=dtype)

Expand All @@ -520,23 +531,26 @@ def empty(self, shape, dtype=np.float64, **kwargs):
# shape/len[, dtype=]
return TypeTracerArray(dtype, shape)

def full(self, shape, value, dtype=unset, **kwargs):
def full(self, shape, value, dtype=None, **kwargs):
# shape/len, value[, dtype=]
if dtype is unset:
if dtype is None:
dtype = numpy.array(value).dtype
return TypeTracerArray(dtype, shape)

def zeros_like(self, *args, **kwargs):
# array
raise ak._v2._util.error(NotImplementedError)
def zeros_like(self, a, dtype=None, **kwargs):
if dtype is None:
dtype = a.dtype

def ones_like(self, *args, **kwargs):
# array
raise ak._v2._util.error(NotImplementedError)
if isinstance(a, UnknownScalar):
return UnknownScalar(dtype)

def full_like(self, *args, **kwargs):
# array, fill_value
raise ak._v2._util.error(NotImplementedError)
return TypeTracerArray(dtype, a.shape)

def ones_like(self, a, dtype=None, **kwargs):
return self.zeros_like(a, dtype)

def full_like(self, a, fill_value, dtype=None, **kwargs):
return self.zeros_like(a, dtype)

def arange(self, *args, **kwargs):
# stop[, dtype=]
Expand Down Expand Up @@ -854,7 +868,7 @@ def argmax(self, *args, **kwargs):
raise ak._v2._util.error(NotImplementedError)

def array_str(
self, array, max_line_width=unset, precision=unset, suppress_small=unset
self, array, max_line_width=None, precision=None, suppress_small=None
):
# array, max_line_width, precision=None, suppress_small=None
return "[?? ... ??]"
Expand Down
48 changes: 48 additions & 0 deletions tests/v2/test_1504-typetracer-like.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import awkward as ak
import numpy as np
import pytest


typetracer = ak._v2._typetracer.TypeTracer.instance()


@pytest.mark.parametrize("dtype", [np.float64, np.int64, np.uint8, None])
@pytest.mark.parametrize("like_dtype", [np.float64, np.int64, np.uint8, None])
def test_ones_like(dtype, like_dtype):
array = ak._v2.contents.numpyarray.NumpyArray(
np.array([99, 88, 77, 66, 66], dtype=dtype)
)
ones = ak._v2.ones_like(array.typetracer, dtype=like_dtype, highlevel=False)
assert ones.typetracer.shape == array.shape
assert ones.typetracer.dtype == like_dtype or array.dtype


@pytest.mark.parametrize("dtype", [np.float64, np.int64, np.uint8, None])
@pytest.mark.parametrize("like_dtype", [np.float64, np.int64, np.uint8, None])
def test_zeros_like(dtype, like_dtype):
array = ak._v2.contents.numpyarray.NumpyArray(
np.array([99, 88, 77, 66, 66], dtype=dtype)
)

full = ak._v2.zeros_like(array.typetracer, dtype=like_dtype, highlevel=False)
assert full.typetracer.shape == array.shape
assert full.typetracer.dtype == like_dtype or array.dtype


@pytest.mark.parametrize("dtype", [np.float64, np.int64, np.uint8, None])
@pytest.mark.parametrize("like_dtype", [np.float64, np.int64, np.uint8, None])
@pytest.mark.parametrize("value", [1.0, -20, np.iinfo(np.uint64).max])
def test_full_like(dtype, like_dtype, value):
array = ak._v2.contents.numpyarray.NumpyArray(
np.array([99, 88, 77, 66, 66], dtype=dtype)
)
full = ak._v2.full_like(array.typetracer, value, dtype=like_dtype, highlevel=False)
assert full.typetracer.shape == array.shape
assert full.typetracer.dtype == like_dtype or array.dtype


def test_full_like_cast():
with pytest.raises(ValueError):
ak._v2._typetracer.TypeTracerArray.from_array(
[1, ak._v2._typetracer.UnknownScalar(np.uint8)]
)

0 comments on commit 0c6ae03

Please sign in to comment.