Skip to content

Commit

Permalink
feat: add Form.is_equal_to (#2862)
Browse files Browse the repository at this point in the history
* feat: add is_equal_to to Form

* feat: add is_equal_to for content

* fix: don't permute recordtype

* fix: incorrect test!

* chore: apply pre-commit
  • Loading branch information
agoose77 authored Dec 1, 2023
1 parent 7f69758 commit b8dcee2
Show file tree
Hide file tree
Showing 28 changed files with 291 additions and 236 deletions.
15 changes: 10 additions & 5 deletions src/awkward/contents/bitmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,10 +839,15 @@ def _to_backend(self, backend: Backend) -> Self:
parameters=self._parameters,
)

def _is_equal_to(self, other, index_dtype, numpyarray):
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return (
self.valid_when == other.valid_when
and self.lsb_order == other.lsb_order
and self.mask.is_equal_to(other.mask, index_dtype, numpyarray)
and self.content.is_equal_to(other.content, index_dtype, numpyarray)
self._is_equal_to_generic(other, all_parameters)
and self._valid_when == other.valid_when
and self._lsb_order == other.lsb_order
and self._mask.is_equal_to(other.mask, index_dtype, numpyarray)
and self._content._is_equal_to(
other.content, index_dtype, numpyarray, all_parameters
)
)
13 changes: 9 additions & 4 deletions src/awkward/contents/bytemaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,9 +1182,14 @@ def _to_backend(self, backend: Backend) -> Self:
mask, content, valid_when=self._valid_when, parameters=self._parameters
)

def _is_equal_to(self, other, index_dtype, numpyarray):
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return (
self.valid_when == other.valid_when
and self.mask.is_equal_to(other.mask, index_dtype, numpyarray)
and self.content.is_equal_to(other.content, index_dtype, numpyarray)
self._is_equal_to_generic(other, all_parameters)
and self._valid_when == other.valid_when
and self._mask.is_equal_to(other.mask, index_dtype, numpyarray)
and self._content._is_equal_to(
other.content, index_dtype, numpyarray, all_parameters
)
)
21 changes: 17 additions & 4 deletions src/awkward/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from awkward._nplikes.shape import ShapeItem, unknown_length
from awkward._nplikes.typetracer import TypeTracer
from awkward._parameters import (
parameters_are_equal,
type_parameters_equal,
)
from awkward._regularize import is_integer_like, is_sized_iterable
Expand Down Expand Up @@ -1244,20 +1245,32 @@ def __deepcopy__(self, memo):
raise NotImplementedError

def is_equal_to(
self, other: Content, index_dtype: bool = True, numpyarray: bool = True
self,
other: Content,
index_dtype: bool = True,
numpyarray: bool = True,
*,
all_parameters: bool = False,
) -> bool:
return self._is_equal_to(other, index_dtype, numpyarray, all_parameters)

def _is_equal_to_generic(self, other: Content, all_parameters: bool) -> bool:
compare_parameters = (
parameters_are_equal if all_parameters else type_parameters_equal
)
return (
self.__class__ is other.__class__
and (
self.length is unknown_length
or other.length is unknown_length
or self.length == other.length
)
and type_parameters_equal(self._parameters, other._parameters)
and self._is_equal_to(other, index_dtype, numpyarray)
and compare_parameters(self._parameters, other._parameters)
)

def _is_equal_to(self, other: Self, index_dtype: bool, numpyarray: bool) -> bool:
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
raise NotImplementedError

def _repr(self, indent: str, pre: str, post: str) -> str:
Expand Down
6 changes: 4 additions & 2 deletions src/awkward/contents/emptyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,5 +448,7 @@ def _to_list(self, behavior, json_conversions):
def _to_backend(self, backend: Backend) -> Self:
return EmptyArray(backend=backend)

def _is_equal_to(self, other, index_dtype, numpyarray):
return True
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return self._is_equal_to_generic(other, all_parameters)
16 changes: 11 additions & 5 deletions src/awkward/contents/indexedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,11 +1155,6 @@ def _to_backend(self, backend: Backend) -> Self:
index = self._index.to_nplike(backend.index_nplike)
return IndexedArray(index, content, parameters=self._parameters)

def _is_equal_to(self, other, index_dtype, numpyarray):
return self.index.is_equal_to(
other.index, index_dtype, numpyarray
) and self.content.is_equal_to(other.content, index_dtype, numpyarray)

def _push_inside_record_or_project(self) -> Self | ak.contents.RecordArray:
if self.content.is_record:
return ak.contents.RecordArray(
Expand All @@ -1174,3 +1169,14 @@ def _push_inside_record_or_project(self) -> Self | ak.contents.RecordArray:
)
else:
return self.project()

def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return (
self._is_equal_to_generic(other, all_parameters)
and self._index.is_equal_to(other.index, index_dtype, numpyarray)
and self._content._is_equal_to(
other.content, index_dtype, numpyarray, all_parameters
)
)
14 changes: 10 additions & 4 deletions src/awkward/contents/indexedoptionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1755,7 +1755,13 @@ def _to_backend(self, backend: Backend) -> Self:
index = self._index.to_nplike(backend.index_nplike)
return IndexedOptionArray(index, content, parameters=self._parameters)

def _is_equal_to(self, other, index_dtype, numpyarray):
return self.index.is_equal_to(
other.index, index_dtype, numpyarray
) and self.content.is_equal_to(other.content, index_dtype, numpyarray)
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return (
self._is_equal_to_generic(other, all_parameters)
and self._index.is_equal_to(other.index, index_dtype, numpyarray)
and self._content._is_equal_to(
other.content, index_dtype, numpyarray, all_parameters
)
)
13 changes: 9 additions & 4 deletions src/awkward/contents/listarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1607,9 +1607,14 @@ def _to_backend(self, backend: Backend) -> Self:
stops = self._stops.to_nplike(backend.index_nplike)
return ListArray(starts, stops, content, parameters=self._parameters)

def _is_equal_to(self, other, index_dtype, numpyarray):
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return (
self.starts.is_equal_to(other.starts, index_dtype, numpyarray)
and self.stops.is_equal_to(other.stops, index_dtype, numpyarray)
and self.content.is_equal_to(other.content, index_dtype, numpyarray)
self._is_equal_to_generic(other, all_parameters)
and self._starts.is_equal_to(other.starts, index_dtype, numpyarray)
and self._stops.is_equal_to(other.stops, index_dtype, numpyarray)
and self._content._is_equal_to(
other.content, index_dtype, numpyarray, all_parameters
)
)
14 changes: 10 additions & 4 deletions src/awkward/contents/listoffsetarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2318,7 +2318,13 @@ def _awkward_strings_to_nonfinite(self, nonfinit_dict):

return content

def _is_equal_to(self, other, index_dtype, numpyarray):
return self.offsets.is_equal_to(
other.offsets, index_dtype, numpyarray
) and self.content.is_equal_to(other.content, index_dtype, numpyarray)
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return (
self._is_equal_to_generic(other, all_parameters)
and self._offsets.is_equal_to(other.offsets, index_dtype, numpyarray)
and self._content._is_equal_to(
other.content, index_dtype, numpyarray, all_parameters
)
)
30 changes: 18 additions & 12 deletions src/awkward/contents/numpyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from awkward._nplikes.jax import Jax
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpy_like import IndexType, NumpyMetadata
from awkward._nplikes.shape import ShapeItem
from awkward._nplikes.shape import ShapeItem, unknown_length
from awkward._nplikes.typetracer import TypeTracerArray
from awkward._parameters import (
parameters_intersect,
Expand Down Expand Up @@ -1375,18 +1375,24 @@ def _to_backend(self, backend: Backend) -> Self:
backend=backend,
)

def _is_equal_to(self, other, index_dtype, numpyarray):
if numpyarray:
return (
(
not self._backend.nplike.known_data
or self._backend.nplike.array_equal(self.data, other.data)
)
and self.dtype == other.dtype
and self.shape == other.shape
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return self._is_equal_to_generic(other, all_parameters) and (
not numpyarray
# dtypes agree
or self.dtype == other.dtype
# Contents agree
and (
not self._backend.nplike.known_data
or self._backend.nplike.array_equal(self.data, other.data)
)
# Shapes agree
and all(
x is unknown_length or y is unknown_length or x == y
for x, y in zip(self.shape, other.shape)
)
else:
return True
)

def _to_regular_primitive(self) -> ak.contents.RegularArray:
# A length-1 slice in each dimension
Expand Down
15 changes: 10 additions & 5 deletions src/awkward/contents/recordarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,12 +1291,17 @@ def _to_backend(self, backend: Backend) -> Self:
backend=backend,
)

def _is_equal_to(self, other, index_dtype, numpyarray):
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return (
self.fields == other.fields
and len(self.contents) == len(other.contents)
self._is_equal_to_generic(other, all_parameters)
and self.is_tuple == other.is_tuple
and set(self.fields) == set(other.fields)
and all(
self.contents[i].is_equal_to(other.contents[i], index_dtype, numpyarray)
for i in range(len(self.contents))
content._is_equal_to(
other.content(field), index_dtype, numpyarray, all_parameters
)
for field, content in zip(self.fields, self._contents)
)
)
16 changes: 13 additions & 3 deletions src/awkward/contents/regulararray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,7 +1527,17 @@ def _to_backend(self, backend: Backend) -> Self:
content, self._size, zeros_length=self._length, parameters=self._parameters
)

def _is_equal_to(self, other, index_dtype, numpyarray):
return self._size == other.size and self._content.is_equal_to(
other._content, index_dtype, numpyarray
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return (
self._is_equal_to_generic(other, all_parameters)
and (
self._size is unknown_length
or other.size is unknown_length
or self._size == other.size
)
and self._content._is_equal_to(
other._content, index_dtype, numpyarray, all_parameters
)
)
15 changes: 9 additions & 6 deletions src/awkward/contents/unionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,13 +1655,16 @@ def _to_backend(self, backend: Backend) -> Self:
parameters=self._parameters,
)

def _is_equal_to(self, other, index_dtype, numpyarray):
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return (
self.tags.is_equal_to(other.tags)
and self.index.is_equal_to(other.index, index_dtype, numpyarray)
and len(self.contents) == len(other.contents)
self._is_equal_to_generic(other, all_parameters)
and self._tags.is_equal_to(other.tags)
and self._index.is_equal_to(other.index, index_dtype, numpyarray)
and len(self._contents) == len(other.contents)
and all(
self.contents[i].is_equal_to(other.contents[i], index_dtype, numpyarray)
for i in range(len(self.contents))
content.is_equal_to(other.contents[i], index_dtype, numpyarray)
for i, content in enumerate(self.contents)
)
)
8 changes: 6 additions & 2 deletions src/awkward/contents/unmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,5 +584,9 @@ def _to_backend(self, backend: Backend) -> Self:
content = self._content.to_backend(backend)
return UnmaskedArray(content, parameters=self._parameters)

def _is_equal_to(self, other, index_dtype, numpyarray):
return self.content.is_equal_to(other.content, index_dtype, numpyarray)
def _is_equal_to(
self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool
) -> bool:
return self._content._is_equal_to(
other.content, index_dtype, numpyarray, all_parameters
)
25 changes: 10 additions & 15 deletions src/awkward/forms/bitmaskedform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import awkward as ak
from awkward._meta.bitmaskedmeta import BitMaskedMeta
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._parameters import type_parameters_equal
from awkward._typing import DType, Iterator, Self, final
from awkward._typing import Any, DType, Iterator, Self, final
from awkward._util import UNSET
from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype

Expand Down Expand Up @@ -150,19 +149,6 @@ def type(self):
self._content.type, parameters=self._parameters
).simplify_option_union()

def __eq__(self, other):
if isinstance(other, BitMaskedForm):
return (
self._form_key == other._form_key
and self._mask == other._mask
and self._valid_when == other._valid_when
and self._lsb_order == other._lsb_order
and type_parameters_equal(self._parameters, other._parameters)
and self._content == other._content
)
else:
return False

def _columns(self, path, output, list_indicator):
self._content._columns(path, output, list_indicator)

Expand Down Expand Up @@ -215,3 +201,12 @@ def _expected_from_buffers(
yield (getkey(self, "mask"), index_to_dtype[self._mask])
if recursive:
yield from self._content._expected_from_buffers(getkey, recursive)

def _is_equal_to(self, other: Any, all_parameters: bool, form_key: bool) -> bool:
return (
self._is_equal_to_generic(other, all_parameters, form_key)
and self._mask == other._mask
and self._valid_when == other._valid_when
and self._lsb_order == other._lsb_order
and self._content._is_equal_to(other._content, all_parameters, form_key)
)
23 changes: 9 additions & 14 deletions src/awkward/forms/bytemaskedform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import awkward as ak
from awkward._meta.bytemaskedmeta import ByteMaskedMeta
from awkward._nplikes.numpy_like import NumpyMetadata
from awkward._parameters import type_parameters_equal
from awkward._typing import DType, Iterator, Self, final
from awkward._typing import Any, DType, Iterator, Self, final
from awkward._util import UNSET
from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype

Expand Down Expand Up @@ -134,18 +133,6 @@ def type(self):
self._content.type, parameters=self._parameters
).simplify_option_union()

def __eq__(self, other):
if isinstance(other, ByteMaskedForm):
return (
self._form_key == other._form_key
and self._mask == other._mask
and self._valid_when == other._valid_when
and type_parameters_equal(self._parameters, other._parameters)
and self._content == other._content
)
else:
return False

def _columns(self, path, output, list_indicator):
self._content._columns(path, output, list_indicator)

Expand Down Expand Up @@ -185,3 +172,11 @@ def _expected_from_buffers(
yield (getkey(self, "mask"), index_to_dtype[self._mask])
if recursive:
yield from self._content._expected_from_buffers(getkey, recursive)

def _is_equal_to(self, other: Any, all_parameters: bool, form_key: bool) -> bool:
return (
self._is_equal_to_generic(other, all_parameters, form_key)
and self._mask == other._mask
and self._valid_when == other._valid_when
and self._content._is_equal_to(other._content, all_parameters, form_key)
)
Loading

0 comments on commit b8dcee2

Please sign in to comment.