Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Form.is_equal_to #2862

Merged
merged 5 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions src/awkward/contents/bitmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,10 +839,15 @@ def _to_backend(self, backend: Backend) -> Self:
parameters=self._parameters,
)

def _is_equal_to(self, other, index_dtype, numpyarray):
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return (
self.valid_when == other.valid_when
and self.lsb_order == other.lsb_order
and self.mask.is_equal_to(other.mask, index_dtype, numpyarray)
and self.content.is_equal_to(other.content, index_dtype, numpyarray)
self._is_equal_to_generic(other, all_parameters)
and self._valid_when == other.valid_when
and self._lsb_order == other.lsb_order
and self._mask.is_equal_to(other.mask, index_dtype, numpyarray)
and self._content._is_equal_to(
other.content, index_dtype, numpyarray, all_parameters
)
)
13 changes: 9 additions & 4 deletions src/awkward/contents/bytemaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,9 +1182,14 @@ def _to_backend(self, backend: Backend) -> Self:
mask, content, valid_when=self._valid_when, parameters=self._parameters
)

def _is_equal_to(self, other, index_dtype, numpyarray):
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return (
self.valid_when == other.valid_when
and self.mask.is_equal_to(other.mask, index_dtype, numpyarray)
and self.content.is_equal_to(other.content, index_dtype, numpyarray)
self._is_equal_to_generic(other, all_parameters)
and self._valid_when == other.valid_when
and self._mask.is_equal_to(other.mask, index_dtype, numpyarray)
and self._content._is_equal_to(
other.content, index_dtype, numpyarray, all_parameters
)
)
21 changes: 17 additions & 4 deletions src/awkward/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from awkward._nplikes.shape import ShapeItem, unknown_length
from awkward._nplikes.typetracer import TypeTracer
from awkward._parameters import (
parameters_are_equal,
type_parameters_equal,
)
from awkward._regularize import is_integer_like, is_sized_iterable
Expand Down Expand Up @@ -1244,20 +1245,32 @@ def __deepcopy__(self, memo):
raise NotImplementedError

def is_equal_to(
self, other: Content, index_dtype: bool = True, numpyarray: bool = True
self,
other: Content,
index_dtype: bool = True,
numpyarray: bool = True,
*,
all_parameters: bool = False,
) -> bool:
return self._is_equal_to(other, index_dtype, numpyarray, all_parameters)

def _is_equal_to_generic(self, other: Content, all_parameters: bool) -> bool:
compare_parameters = (
parameters_are_equal if all_parameters else type_parameters_equal
)
return (
self.__class__ is other.__class__
and (
self.length is unknown_length
or other.length is unknown_length
or self.length == other.length
)
and type_parameters_equal(self._parameters, other._parameters)
and self._is_equal_to(other, index_dtype, numpyarray)
and compare_parameters(self._parameters, other._parameters)
)

def _is_equal_to(self, other: Self, index_dtype: bool, numpyarray: bool) -> bool:
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
raise NotImplementedError

def _repr(self, indent: str, pre: str, post: str) -> str:
Expand Down
6 changes: 4 additions & 2 deletions src/awkward/contents/emptyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,5 +448,7 @@ def _to_list(self, behavior, json_conversions):
def _to_backend(self, backend: Backend) -> Self:
return EmptyArray(backend=backend)

def _is_equal_to(self, other, index_dtype, numpyarray):
return True
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return self._is_equal_to_generic(other, all_parameters)
16 changes: 11 additions & 5 deletions src/awkward/contents/indexedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,11 +1155,6 @@ def _to_backend(self, backend: Backend) -> Self:
index = self._index.to_nplike(backend.index_nplike)
return IndexedArray(index, content, parameters=self._parameters)

def _is_equal_to(self, other, index_dtype, numpyarray):
return self.index.is_equal_to(
other.index, index_dtype, numpyarray
) and self.content.is_equal_to(other.content, index_dtype, numpyarray)

def _push_inside_record_or_project(self) -> Self | ak.contents.RecordArray:
if self.content.is_record:
return ak.contents.RecordArray(
Expand All @@ -1174,3 +1169,14 @@ def _push_inside_record_or_project(self) -> Self | ak.contents.RecordArray:
)
else:
return self.project()

def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return (
self._is_equal_to_generic(other, all_parameters)
and self._index.is_equal_to(other.index, index_dtype, numpyarray)
and self._content._is_equal_to(
other.content, index_dtype, numpyarray, all_parameters
)
)
14 changes: 10 additions & 4 deletions src/awkward/contents/indexedoptionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1755,7 +1755,13 @@ def _to_backend(self, backend: Backend) -> Self:
index = self._index.to_nplike(backend.index_nplike)
return IndexedOptionArray(index, content, parameters=self._parameters)

def _is_equal_to(self, other, index_dtype, numpyarray):
return self.index.is_equal_to(
other.index, index_dtype, numpyarray
) and self.content.is_equal_to(other.content, index_dtype, numpyarray)
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return (
self._is_equal_to_generic(other, all_parameters)
and self._index.is_equal_to(other.index, index_dtype, numpyarray)
and self._content._is_equal_to(
other.content, index_dtype, numpyarray, all_parameters
)
)
13 changes: 9 additions & 4 deletions src/awkward/contents/listarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1607,9 +1607,14 @@ def _to_backend(self, backend: Backend) -> Self:
stops = self._stops.to_nplike(backend.index_nplike)
return ListArray(starts, stops, content, parameters=self._parameters)

def _is_equal_to(self, other, index_dtype, numpyarray):
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return (
self.starts.is_equal_to(other.starts, index_dtype, numpyarray)
and self.stops.is_equal_to(other.stops, index_dtype, numpyarray)
and self.content.is_equal_to(other.content, index_dtype, numpyarray)
self._is_equal_to_generic(other, all_parameters)
and self._starts.is_equal_to(other.starts, index_dtype, numpyarray)
and self._stops.is_equal_to(other.stops, index_dtype, numpyarray)
and self._content._is_equal_to(
other.content, index_dtype, numpyarray, all_parameters
)
)
14 changes: 10 additions & 4 deletions src/awkward/contents/listoffsetarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2318,7 +2318,13 @@ def _awkward_strings_to_nonfinite(self, nonfinit_dict):

return content

def _is_equal_to(self, other, index_dtype, numpyarray):
return self.offsets.is_equal_to(
other.offsets, index_dtype, numpyarray
) and self.content.is_equal_to(other.content, index_dtype, numpyarray)
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return (
self._is_equal_to_generic(other, all_parameters)
and self._offsets.is_equal_to(other.offsets, index_dtype, numpyarray)
and self._content._is_equal_to(
other.content, index_dtype, numpyarray, all_parameters
)
)
30 changes: 18 additions & 12 deletions src/awkward/contents/numpyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from awkward._nplikes.jax import Jax
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpy_like import IndexType, NumpyMetadata
from awkward._nplikes.shape import ShapeItem
from awkward._nplikes.shape import ShapeItem, unknown_length
from awkward._nplikes.typetracer import TypeTracerArray
from awkward._parameters import (
parameters_intersect,
Expand Down Expand Up @@ -1375,18 +1375,24 @@ def _to_backend(self, backend: Backend) -> Self:
backend=backend,
)

def _is_equal_to(self, other, index_dtype, numpyarray):
if numpyarray:
return (
(
not self._backend.nplike.known_data
or self._backend.nplike.array_equal(self.data, other.data)
)
and self.dtype == other.dtype
and self.shape == other.shape
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return self._is_equal_to_generic(other, all_parameters) and (
not numpyarray
# dtypes agree
or self.dtype == other.dtype
# Contents agree
and (
not self._backend.nplike.known_data
or self._backend.nplike.array_equal(self.data, other.data)
)
# Shapes agree
and all(
x is unknown_length or y is unknown_length or x == y
for x, y in zip(self.shape, other.shape)
)
else:
return True
)

def _to_regular_primitive(self) -> ak.contents.RegularArray:
# A length-1 slice in each dimension
Expand Down
15 changes: 10 additions & 5 deletions src/awkward/contents/recordarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,12 +1291,17 @@ def _to_backend(self, backend: Backend) -> Self:
backend=backend,
)

def _is_equal_to(self, other, index_dtype, numpyarray):
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return (
self.fields == other.fields
and len(self.contents) == len(other.contents)
self._is_equal_to_generic(other, all_parameters)
and self.is_tuple == other.is_tuple
and set(self.fields) == set(other.fields)
and all(
self.contents[i].is_equal_to(other.contents[i], index_dtype, numpyarray)
for i in range(len(self.contents))
content._is_equal_to(
other.content(field), index_dtype, numpyarray, all_parameters
)
for field, content in zip(self.fields, self._contents)
)
)
16 changes: 13 additions & 3 deletions src/awkward/contents/regulararray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,7 +1527,17 @@ def _to_backend(self, backend: Backend) -> Self:
content, self._size, zeros_length=self._length, parameters=self._parameters
)

def _is_equal_to(self, other, index_dtype, numpyarray):
return self._size == other.size and self._content.is_equal_to(
other._content, index_dtype, numpyarray
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return (
self._is_equal_to_generic(other, all_parameters)
and (
self._size is unknown_length
or other.size is unknown_length
or self._size == other.size
)
and self._content._is_equal_to(
other._content, index_dtype, numpyarray, all_parameters
)
)
15 changes: 9 additions & 6 deletions src/awkward/contents/unionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,13 +1655,16 @@ def _to_backend(self, backend: Backend) -> Self:
parameters=self._parameters,
)

def _is_equal_to(self, other, index_dtype, numpyarray):
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return (
self.tags.is_equal_to(other.tags)
and self.index.is_equal_to(other.index, index_dtype, numpyarray)
and len(self.contents) == len(other.contents)
self._is_equal_to_generic(other, all_parameters)
and self._tags.is_equal_to(other.tags)
and self._index.is_equal_to(other.index, index_dtype, numpyarray)
and len(self._contents) == len(other.contents)
and all(
self.contents[i].is_equal_to(other.contents[i], index_dtype, numpyarray)
for i in range(len(self.contents))
content.is_equal_to(other.contents[i], index_dtype, numpyarray)
for i, content in enumerate(self.contents)
)
)
8 changes: 6 additions & 2 deletions src/awkward/contents/unmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,5 +584,9 @@ def _to_backend(self, backend: Backend) -> Self:
content = self._content.to_backend(backend)
return UnmaskedArray(content, parameters=self._parameters)

def _is_equal_to(self, other, index_dtype, numpyarray):
return self.content.is_equal_to(other.content, index_dtype, numpyarray)
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return self._content._is_equal_to(
other.content, index_dtype, numpyarray, all_parameters
)
25 changes: 10 additions & 15 deletions src/awkward/forms/bitmaskedform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import awkward as ak
from awkward._meta.bitmaskedmeta import BitMaskedMeta
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._parameters import type_parameters_equal
from awkward._typing import DType, Iterator, Self, final
from awkward._typing import Any, DType, Iterator, Self, final
from awkward._util import UNSET
from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype

Expand Down Expand Up @@ -150,19 +149,6 @@ def type(self):
self._content.type, parameters=self._parameters
).simplify_option_union()

def __eq__(self, other):
if isinstance(other, BitMaskedForm):
return (
self._form_key == other._form_key
and self._mask == other._mask
and self._valid_when == other._valid_when
and self._lsb_order == other._lsb_order
and type_parameters_equal(self._parameters, other._parameters)
and self._content == other._content
)
else:
return False

def _columns(self, path, output, list_indicator):
self._content._columns(path, output, list_indicator)

Expand Down Expand Up @@ -215,3 +201,12 @@ def _expected_from_buffers(
yield (getkey(self, "mask"), index_to_dtype[self._mask])
if recursive:
yield from self._content._expected_from_buffers(getkey, recursive)

def _is_equal_to(self, other: Any, all_parameters: bool, form_key: bool) -> bool:
return (
self._is_equal_to_generic(other, all_parameters, form_key)
and self._mask == other._mask
and self._valid_when == other._valid_when
and self._lsb_order == other._lsb_order
and self._content._is_equal_to(other._content, all_parameters, form_key)
)
23 changes: 9 additions & 14 deletions src/awkward/forms/bytemaskedform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import awkward as ak
from awkward._meta.bytemaskedmeta import ByteMaskedMeta
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._parameters import type_parameters_equal
from awkward._typing import DType, Iterator, Self, final
from awkward._typing import Any, DType, Iterator, Self, final
from awkward._util import UNSET
from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype

Expand Down Expand Up @@ -134,18 +133,6 @@ def type(self):
self._content.type, parameters=self._parameters
).simplify_option_union()

def __eq__(self, other):
if isinstance(other, ByteMaskedForm):
return (
self._form_key == other._form_key
and self._mask == other._mask
and self._valid_when == other._valid_when
and type_parameters_equal(self._parameters, other._parameters)
and self._content == other._content
)
else:
return False

def _columns(self, path, output, list_indicator):
self._content._columns(path, output, list_indicator)

Expand Down Expand Up @@ -185,3 +172,11 @@ def _expected_from_buffers(
yield (getkey(self, "mask"), index_to_dtype[self._mask])
if recursive:
yield from self._content._expected_from_buffers(getkey, recursive)

def _is_equal_to(self, other: Any, all_parameters: bool, form_key: bool) -> bool:
return (
self._is_equal_to_generic(other, all_parameters, form_key)
and self._mask == other._mask
and self._valid_when == other._valid_when
and self._content._is_equal_to(other._content, all_parameters, form_key)
)
Loading