diff --git a/src/awkward/contents/bitmaskedarray.py b/src/awkward/contents/bitmaskedarray.py index 19c677ad7d..6790ac7bc0 100644 --- a/src/awkward/contents/bitmaskedarray.py +++ b/src/awkward/contents/bitmaskedarray.py @@ -839,10 +839,15 @@ def _to_backend(self, backend: Backend) -> Self: parameters=self._parameters, ) - def _is_equal_to(self, other, index_dtype, numpyarray): + def _is_equal_to( + self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool + ) -> bool: return ( - self.valid_when == other.valid_when - and self.lsb_order == other.lsb_order - and self.mask.is_equal_to(other.mask, index_dtype, numpyarray) - and self.content.is_equal_to(other.content, index_dtype, numpyarray) + self._is_equal_to_generic(other, all_parameters) + and self._valid_when == other.valid_when + and self._lsb_order == other.lsb_order + and self._mask.is_equal_to(other.mask, index_dtype, numpyarray) + and self._content._is_equal_to( + other.content, index_dtype, numpyarray, all_parameters + ) ) diff --git a/src/awkward/contents/bytemaskedarray.py b/src/awkward/contents/bytemaskedarray.py index f445a5f242..5738ae3803 100644 --- a/src/awkward/contents/bytemaskedarray.py +++ b/src/awkward/contents/bytemaskedarray.py @@ -1182,9 +1182,14 @@ def _to_backend(self, backend: Backend) -> Self: mask, content, valid_when=self._valid_when, parameters=self._parameters ) - def _is_equal_to(self, other, index_dtype, numpyarray): + def _is_equal_to( + self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool + ) -> bool: return ( - self.valid_when == other.valid_when - and self.mask.is_equal_to(other.mask, index_dtype, numpyarray) - and self.content.is_equal_to(other.content, index_dtype, numpyarray) + self._is_equal_to_generic(other, all_parameters) + and self._valid_when == other.valid_when + and self._mask.is_equal_to(other.mask, index_dtype, numpyarray) + and self._content._is_equal_to( + other.content, index_dtype, numpyarray, all_parameters + ) ) diff --git a/src/awkward/contents/content.py b/src/awkward/contents/content.py index c411e0a519..3812fec414 100644 --- a/src/awkward/contents/content.py +++ b/src/awkward/contents/content.py @@ -25,6 +25,7 @@ from awkward._nplikes.shape import ShapeItem, unknown_length from awkward._nplikes.typetracer import TypeTracer from awkward._parameters import ( + parameters_are_equal, type_parameters_equal, ) from awkward._regularize import is_integer_like, is_sized_iterable @@ -1244,8 +1245,19 @@ def __deepcopy__(self, memo): raise NotImplementedError def is_equal_to( - self, other: Content, index_dtype: bool = True, numpyarray: bool = True + self, + other: Content, + index_dtype: bool = True, + numpyarray: bool = True, + *, + all_parameters: bool = False, ) -> bool: + return self._is_equal_to(other, index_dtype, numpyarray, all_parameters) + + def _is_equal_to_generic(self, other: Content, all_parameters: bool) -> bool: + compare_parameters = ( + parameters_are_equal if all_parameters else type_parameters_equal + ) return ( self.__class__ is other.__class__ and ( @@ -1253,11 +1265,12 @@ def is_equal_to( or other.length is unknown_length or self.length == other.length ) - and type_parameters_equal(self._parameters, other._parameters) - and self._is_equal_to(other, index_dtype, numpyarray) + and compare_parameters(self._parameters, other._parameters) ) - def _is_equal_to(self, other: Self, index_dtype: bool, numpyarray: bool) -> bool: + def _is_equal_to( + self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool + ) -> bool: raise NotImplementedError def _repr(self, indent: str, pre: str, post: str) -> str: diff --git a/src/awkward/contents/emptyarray.py b/src/awkward/contents/emptyarray.py index 8103661156..831d8c3810 100644 --- a/src/awkward/contents/emptyarray.py +++ b/src/awkward/contents/emptyarray.py @@ -448,5 +448,7 @@ def _to_list(self, behavior, json_conversions): def _to_backend(self, backend: Backend) -> Self: return EmptyArray(backend=backend) - def _is_equal_to(self, other, index_dtype, numpyarray): - return True + def _is_equal_to( + self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool + ) -> bool: + return self._is_equal_to_generic(other, all_parameters) diff --git a/src/awkward/contents/indexedarray.py b/src/awkward/contents/indexedarray.py index f6d2210fd2..ee5cb28c10 100644 --- a/src/awkward/contents/indexedarray.py +++ b/src/awkward/contents/indexedarray.py @@ -1155,11 +1155,6 @@ def _to_backend(self, backend: Backend) -> Self: index = self._index.to_nplike(backend.index_nplike) return IndexedArray(index, content, parameters=self._parameters) - def _is_equal_to(self, other, index_dtype, numpyarray): - return self.index.is_equal_to( - other.index, index_dtype, numpyarray - ) and self.content.is_equal_to(other.content, index_dtype, numpyarray) - def _push_inside_record_or_project(self) -> Self | ak.contents.RecordArray: if self.content.is_record: return ak.contents.RecordArray( @@ -1174,3 +1169,14 @@ def _push_inside_record_or_project(self) -> Self | ak.contents.RecordArray: ) else: return self.project() + + def _is_equal_to( + self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool + ) -> bool: + return ( + self._is_equal_to_generic(other, all_parameters) + and self._index.is_equal_to(other.index, index_dtype, numpyarray) + and self._content._is_equal_to( + other.content, index_dtype, numpyarray, all_parameters + ) + ) diff --git a/src/awkward/contents/indexedoptionarray.py b/src/awkward/contents/indexedoptionarray.py index 39c3279cb6..396aece3a1 100644 --- a/src/awkward/contents/indexedoptionarray.py +++ b/src/awkward/contents/indexedoptionarray.py @@ -1755,7 +1755,13 @@ def _to_backend(self, backend: Backend) -> Self: index = self._index.to_nplike(backend.index_nplike) return IndexedOptionArray(index, content, parameters=self._parameters) - def _is_equal_to(self, other, index_dtype, numpyarray): - return self.index.is_equal_to( - other.index, index_dtype, numpyarray - ) and self.content.is_equal_to(other.content, index_dtype, numpyarray) + def _is_equal_to( + self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool + ) -> bool: + return ( + self._is_equal_to_generic(other, all_parameters) + and self._index.is_equal_to(other.index, index_dtype, numpyarray) + and self._content._is_equal_to( + other.content, index_dtype, numpyarray, all_parameters + ) + ) diff --git a/src/awkward/contents/listarray.py b/src/awkward/contents/listarray.py index 8d917b8d0c..d4751cef9d 100644 --- a/src/awkward/contents/listarray.py +++ b/src/awkward/contents/listarray.py @@ -1607,9 +1607,14 @@ def _to_backend(self, backend: Backend) -> Self: stops = self._stops.to_nplike(backend.index_nplike) return ListArray(starts, stops, content, parameters=self._parameters) - def _is_equal_to(self, other, index_dtype, numpyarray): + def _is_equal_to( + self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool + ) -> bool: return ( - self.starts.is_equal_to(other.starts, index_dtype, numpyarray) - and self.stops.is_equal_to(other.stops, index_dtype, numpyarray) - and self.content.is_equal_to(other.content, index_dtype, numpyarray) + self._is_equal_to_generic(other, all_parameters) + and self._starts.is_equal_to(other.starts, index_dtype, numpyarray) + and self._stops.is_equal_to(other.stops, index_dtype, numpyarray) + and self._content._is_equal_to( + other.content, index_dtype, numpyarray, all_parameters + ) ) diff --git a/src/awkward/contents/listoffsetarray.py b/src/awkward/contents/listoffsetarray.py index 37473a5c21..1b54e108a9 100644 --- a/src/awkward/contents/listoffsetarray.py +++ b/src/awkward/contents/listoffsetarray.py @@ -2318,7 +2318,13 @@ def _awkward_strings_to_nonfinite(self, nonfinit_dict): return content - def _is_equal_to(self, other, index_dtype, numpyarray): - return self.offsets.is_equal_to( - other.offsets, index_dtype, numpyarray - ) and self.content.is_equal_to(other.content, index_dtype, numpyarray) + def _is_equal_to( + self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool + ) -> bool: + return ( + self._is_equal_to_generic(other, all_parameters) + and self._offsets.is_equal_to(other.offsets, index_dtype, numpyarray) + and self._content._is_equal_to( + other.content, index_dtype, numpyarray, all_parameters + ) + ) diff --git a/src/awkward/contents/numpyarray.py b/src/awkward/contents/numpyarray.py index 5d2bb86506..5bf111d2db 100644 --- a/src/awkward/contents/numpyarray.py +++ b/src/awkward/contents/numpyarray.py @@ -18,7 +18,7 @@ from awkward._nplikes.jax import Jax from awkward._nplikes.numpy import Numpy from awkward._nplikes.numpy_like import IndexType, NumpyMetadata -from awkward._nplikes.shape import ShapeItem +from awkward._nplikes.shape import ShapeItem, unknown_length from awkward._nplikes.typetracer import TypeTracerArray from awkward._parameters import ( parameters_intersect, @@ -1375,18 +1375,24 @@ def _to_backend(self, backend: Backend) -> Self: backend=backend, ) - def _is_equal_to(self, other, index_dtype, numpyarray): - if numpyarray: - return ( - ( - not self._backend.nplike.known_data - or self._backend.nplike.array_equal(self.data, other.data) - ) - and self.dtype == other.dtype - and self.shape == other.shape + def _is_equal_to( + self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool + ) -> bool: + return self._is_equal_to_generic(other, all_parameters) and ( + not numpyarray + # dtypes agree + or self.dtype == other.dtype + # Contents agree + and ( + not self._backend.nplike.known_data + or self._backend.nplike.array_equal(self.data, other.data) + ) + # Shapes agree + and all( + x is unknown_length or y is unknown_length or x == y + for x, y in zip(self.shape, other.shape) ) - else: - return True + ) def _to_regular_primitive(self) -> ak.contents.RegularArray: # A length-1 slice in each dimension diff --git a/src/awkward/contents/recordarray.py b/src/awkward/contents/recordarray.py index 17e914834c..30697f6f8f 100644 --- a/src/awkward/contents/recordarray.py +++ b/src/awkward/contents/recordarray.py @@ -1291,12 +1291,17 @@ def _to_backend(self, backend: Backend) -> Self: backend=backend, ) - def _is_equal_to(self, other, index_dtype, numpyarray): + def _is_equal_to( + self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool + ) -> bool: return ( - self.fields == other.fields - and len(self.contents) == len(other.contents) + self._is_equal_to_generic(other, all_parameters) + and self.is_tuple == other.is_tuple + and set(self.fields) == set(other.fields) and all( - self.contents[i].is_equal_to(other.contents[i], index_dtype, numpyarray) - for i in range(len(self.contents)) + content._is_equal_to( + other.content(field), index_dtype, numpyarray, all_parameters + ) + for field, content in zip(self.fields, self._contents) ) ) diff --git a/src/awkward/contents/regulararray.py b/src/awkward/contents/regulararray.py index 33c9b43d32..eed2ae428e 100644 --- a/src/awkward/contents/regulararray.py +++ b/src/awkward/contents/regulararray.py @@ -1527,7 +1527,17 @@ def _to_backend(self, backend: Backend) -> Self: content, self._size, zeros_length=self._length, parameters=self._parameters ) - def _is_equal_to(self, other, index_dtype, numpyarray): - return self._size == other.size and self._content.is_equal_to( - other._content, index_dtype, numpyarray + def _is_equal_to( + self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool + ) -> bool: + return ( + self._is_equal_to_generic(other, all_parameters) + and ( + self._size is unknown_length + or other.size is unknown_length + or self._size == other.size + ) + and self._content._is_equal_to( + other._content, index_dtype, numpyarray, all_parameters + ) ) diff --git a/src/awkward/contents/unionarray.py b/src/awkward/contents/unionarray.py index 088b52050d..aa81b495da 100644 --- a/src/awkward/contents/unionarray.py +++ b/src/awkward/contents/unionarray.py @@ -1655,13 +1655,16 @@ def _to_backend(self, backend: Backend) -> Self: parameters=self._parameters, ) - def _is_equal_to(self, other, index_dtype, numpyarray): + def _is_equal_to( + self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool + ) -> bool: return ( - self.tags.is_equal_to(other.tags) - and self.index.is_equal_to(other.index, index_dtype, numpyarray) - and len(self.contents) == len(other.contents) + self._is_equal_to_generic(other, all_parameters) + and self._tags.is_equal_to(other.tags) + and self._index.is_equal_to(other.index, index_dtype, numpyarray) + and len(self._contents) == len(other.contents) and all( - self.contents[i].is_equal_to(other.contents[i], index_dtype, numpyarray) - for i in range(len(self.contents)) + content.is_equal_to(other.contents[i], index_dtype, numpyarray) + for i, content in enumerate(self.contents) ) ) diff --git a/src/awkward/contents/unmaskedarray.py b/src/awkward/contents/unmaskedarray.py index fbb6eb35f2..2f560827ad 100644 --- a/src/awkward/contents/unmaskedarray.py +++ b/src/awkward/contents/unmaskedarray.py @@ -584,5 +584,9 @@ def _to_backend(self, backend: Backend) -> Self: content = self._content.to_backend(backend) return UnmaskedArray(content, parameters=self._parameters) - def _is_equal_to(self, other, index_dtype, numpyarray): - return self.content.is_equal_to(other.content, index_dtype, numpyarray) + def _is_equal_to( + self, other: Self, index_dtype: bool, numpyarray: bool, all_parameters: bool + ) -> bool: + return self._content._is_equal_to( + other.content, index_dtype, numpyarray, all_parameters + ) diff --git a/src/awkward/forms/bitmaskedform.py b/src/awkward/forms/bitmaskedform.py index 089f73c587..b875ee16d5 100644 --- a/src/awkward/forms/bitmaskedform.py +++ b/src/awkward/forms/bitmaskedform.py @@ -7,8 +7,7 @@ import awkward as ak from awkward._meta.bitmaskedmeta import BitMaskedMeta from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._parameters import type_parameters_equal -from awkward._typing import DType, Iterator, Self, final +from awkward._typing import Any, DType, Iterator, Self, final from awkward._util import UNSET from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype @@ -150,19 +149,6 @@ def type(self): self._content.type, parameters=self._parameters ).simplify_option_union() - def __eq__(self, other): - if isinstance(other, BitMaskedForm): - return ( - self._form_key == other._form_key - and self._mask == other._mask - and self._valid_when == other._valid_when - and self._lsb_order == other._lsb_order - and type_parameters_equal(self._parameters, other._parameters) - and self._content == other._content - ) - else: - return False - def _columns(self, path, output, list_indicator): self._content._columns(path, output, list_indicator) @@ -215,3 +201,12 @@ def _expected_from_buffers( yield (getkey(self, "mask"), index_to_dtype[self._mask]) if recursive: yield from self._content._expected_from_buffers(getkey, recursive) + + def _is_equal_to(self, other: Any, all_parameters: bool, form_key: bool) -> bool: + return ( + self._is_equal_to_generic(other, all_parameters, form_key) + and self._mask == other._mask + and self._valid_when == other._valid_when + and self._lsb_order == other._lsb_order + and self._content._is_equal_to(other._content, all_parameters, form_key) + ) diff --git a/src/awkward/forms/bytemaskedform.py b/src/awkward/forms/bytemaskedform.py index b9ec5e4705..1b40caf085 100644 --- a/src/awkward/forms/bytemaskedform.py +++ b/src/awkward/forms/bytemaskedform.py @@ -7,8 +7,7 @@ import awkward as ak from awkward._meta.bytemaskedmeta import ByteMaskedMeta from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._parameters import type_parameters_equal -from awkward._typing import DType, Iterator, Self, final +from awkward._typing import Any, DType, Iterator, Self, final from awkward._util import UNSET from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype @@ -134,18 +133,6 @@ def type(self): self._content.type, parameters=self._parameters ).simplify_option_union() - def __eq__(self, other): - if isinstance(other, ByteMaskedForm): - return ( - self._form_key == other._form_key - and self._mask == other._mask - and self._valid_when == other._valid_when - and type_parameters_equal(self._parameters, other._parameters) - and self._content == other._content - ) - else: - return False - def _columns(self, path, output, list_indicator): self._content._columns(path, output, list_indicator) @@ -185,3 +172,11 @@ def _expected_from_buffers( yield (getkey(self, "mask"), index_to_dtype[self._mask]) if recursive: yield from self._content._expected_from_buffers(getkey, recursive) + + def _is_equal_to(self, other: Any, all_parameters: bool, form_key: bool) -> bool: + return ( + self._is_equal_to_generic(other, all_parameters, form_key) + and self._mask == other._mask + and self._valid_when == other._valid_when + and self._content._is_equal_to(other._content, all_parameters, form_key) + ) diff --git a/src/awkward/forms/emptyform.py b/src/awkward/forms/emptyform.py index 026b531d92..2371de4fb8 100644 --- a/src/awkward/forms/emptyform.py +++ b/src/awkward/forms/emptyform.py @@ -12,7 +12,7 @@ from awkward._nplikes.shape import ShapeItem from awkward._typing import DType, Iterator, Self, final from awkward._util import UNSET, Sentinel -from awkward.forms.form import Form, JSONMapping, _SpecifierMatcher +from awkward.forms.form import Any, Form, JSONMapping, _SpecifierMatcher __all__ = ("EmptyForm",) @@ -59,9 +59,6 @@ def _to_dict_part(self, verbose, toplevel): def type(self): return ak.types.UnknownType() - def __eq__(self, other) -> bool: - return isinstance(other, EmptyForm) and self._form_key == other._form_key - def to_NumpyForm(self, *args, **kwargs): def legacy_impl(dtype): deprecate( @@ -125,3 +122,8 @@ def _expected_from_buffers( self, getkey: Callable[[Form, str], str], recursive: bool ) -> Iterator[tuple[str, DType]]: yield from () + + def _is_equal_to(self, other: Any, all_parameters: bool, form_key: bool) -> bool: + return isinstance(other, type(self)) and not ( + form_key and self._form_key != other._form_key + ) diff --git a/src/awkward/forms/form.py b/src/awkward/forms/form.py index 7863a31f52..18f16a0198 100644 --- a/src/awkward/forms/form.py +++ b/src/awkward/forms/form.py @@ -16,8 +16,13 @@ from awkward._meta.meta import Meta from awkward._nplikes.numpy_like import NumpyMetadata from awkward._nplikes.shape import ShapeItem, unknown_length -from awkward._parameters import parameters_union +from awkward._parameters import ( + parameters_are_equal, + parameters_union, + type_parameters_equal, +) from awkward._typing import ( + Any, DType, Final, Iterator, @@ -640,3 +645,26 @@ def expected_from_buffers( """ getkey = regularize_buffer_key(buffer_key) return dict(self._expected_from_buffers(getkey, recursive)) + + def is_equal_to( + self, other: Any, *, all_parameters: bool = False, form_key: bool = False + ) -> bool: + return self._is_equal_to(other, all_parameters, form_key) + + __eq__ = is_equal_to + + def _is_equal_to(self, other: Any, all_parameters: bool, form_key: bool) -> bool: + raise NotImplementedError + + def _is_equal_to_generic( + self, other: Any, all_parameters: bool, form_key: bool + ) -> bool: + compare_parameters = ( + parameters_are_equal if all_parameters else type_parameters_equal + ) + + return ( + isinstance(other, type(self)) + and not (form_key and self._form_key != other._form_key) + and compare_parameters(self._parameters, other._parameters) + ) diff --git a/src/awkward/forms/indexedform.py b/src/awkward/forms/indexedform.py index a04b38a556..e9936b4912 100644 --- a/src/awkward/forms/indexedform.py +++ b/src/awkward/forms/indexedform.py @@ -7,8 +7,10 @@ import awkward as ak from awkward._meta.indexedmeta import IndexedMeta from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._parameters import parameters_union, type_parameters_equal -from awkward._typing import DType, Iterator, Self, final +from awkward._parameters import ( + parameters_union, +) +from awkward._typing import Any, DType, Iterator, Self, final from awkward._util import UNSET from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype @@ -138,17 +140,6 @@ def type(self): return out - def __eq__(self, other): - if isinstance(other, IndexedForm): - return ( - self._form_key == other._form_key - and self._index == other._index - and type_parameters_equal(self._parameters, other._parameters) - and self._content == other._content - ) - else: - return False - def _columns(self, path, output, list_indicator): self._content._columns(path, output, list_indicator) @@ -186,3 +177,10 @@ def _expected_from_buffers( yield (getkey(self, "index"), index_to_dtype[self._index]) if recursive: yield from self._content._expected_from_buffers(getkey, recursive) + + def _is_equal_to(self, other: Any, all_parameters: bool, form_key: bool) -> bool: + return ( + self._is_equal_to_generic(other, all_parameters, form_key) + and self._index == other._index + and self._content._is_equal_to(other._content, all_parameters, form_key) + ) diff --git a/src/awkward/forms/indexedoptionform.py b/src/awkward/forms/indexedoptionform.py index c0c329349c..6d54be5e0c 100644 --- a/src/awkward/forms/indexedoptionform.py +++ b/src/awkward/forms/indexedoptionform.py @@ -7,8 +7,10 @@ import awkward as ak from awkward._meta.indexedoptionmeta import IndexedOptionMeta from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._parameters import parameters_union, type_parameters_equal -from awkward._typing import DType, Iterator, Self, final +from awkward._parameters import ( + parameters_union, +) +from awkward._typing import Any, DType, Iterator, Self, final from awkward._util import UNSET from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype @@ -119,17 +121,6 @@ def type(self): parameters=parameters, ).simplify_option_union() - def __eq__(self, other): - if isinstance(other, IndexedOptionForm): - return ( - self._form_key == other._form_key - and self._index == other._index - and type_parameters_equal(self._parameters, other._parameters) - and self._content == other._content - ) - else: - return False - def _columns(self, path, output, list_indicator): self._content._columns(path, output, list_indicator) @@ -167,3 +158,10 @@ def _expected_from_buffers( yield (getkey(self, "index"), index_to_dtype[self._index]) if recursive: yield from self._content._expected_from_buffers(getkey, recursive) + + def _is_equal_to(self, other: Any, all_parameters: bool, form_key: bool) -> bool: + return ( + self._is_equal_to_generic(other, all_parameters, form_key) + and self._index == other._index + and self._content._is_equal_to(other._content, all_parameters, form_key) + ) diff --git a/src/awkward/forms/listform.py b/src/awkward/forms/listform.py index e5f8ff37db..6ed46d5d41 100644 --- a/src/awkward/forms/listform.py +++ b/src/awkward/forms/listform.py @@ -7,8 +7,7 @@ import awkward as ak from awkward._meta.listmeta import ListMeta from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._parameters import type_parameters_equal -from awkward._typing import DType, Iterator, Self, final +from awkward._typing import Any, DType, Iterator, Self, final from awkward._util import UNSET from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype @@ -122,18 +121,6 @@ def type(self): parameters=self._parameters, ) - def __eq__(self, other): - if isinstance(other, ListForm): - return ( - self._form_key == other._form_key - and self._starts == other._starts - and self._stops == other._stops - and type_parameters_equal(self._parameters, other._parameters) - and self._content == other._content - ) - else: - return False - def _columns(self, path, output, list_indicator): if ( self.parameter("__array__") not in ("string", "bytestring") @@ -182,3 +169,11 @@ def _expected_from_buffers( yield (getkey(self, "stops"), index_to_dtype[self._stops]) if recursive: yield from self._content._expected_from_buffers(getkey, recursive) + + def _is_equal_to(self, other: Any, all_parameters: bool, form_key: bool) -> bool: + return ( + self._is_equal_to_generic(other, all_parameters, form_key) + and self._starts == other._starts + and self._stops == other._stops + and self._content._is_equal_to(other._content, all_parameters, form_key) + ) diff --git a/src/awkward/forms/listoffsetform.py b/src/awkward/forms/listoffsetform.py index ceb15dcd0b..75a495b666 100644 --- a/src/awkward/forms/listoffsetform.py +++ b/src/awkward/forms/listoffsetform.py @@ -7,8 +7,8 @@ import awkward as ak from awkward._meta.listoffsetmeta import ListOffsetMeta from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._parameters import type_parameters_equal from awkward._typing import ( + Any, DType, Iterator, JSONMapping, @@ -89,21 +89,7 @@ def _to_dict_part(self, verbose, toplevel): @property def type(self): - return ak.types.ListType( - self._content.type, - parameters=self._parameters, - ) - - def __eq__(self, other): - if isinstance(other, ListOffsetForm): - return ( - self._form_key == other._form_key - and self._offsets == other._offsets - and type_parameters_equal(self._parameters, other._parameters) - and self._content == other._content - ) - else: - return False + return ak.types.ListType(self._content.type, parameters=self._parameters) def _columns(self, path, output, list_indicator): if ( @@ -150,3 +136,10 @@ def _expected_from_buffers( yield (getkey(self, "offsets"), index_to_dtype[self._offsets]) if recursive: yield from self._content._expected_from_buffers(getkey, recursive) + + def _is_equal_to(self, other: Any, all_parameters: bool, form_key: bool) -> bool: + return ( + self._is_equal_to_generic(other, all_parameters, form_key) + and self._offsets == other._offsets + and self._content._is_equal_to(other._content, all_parameters, form_key) + ) diff --git a/src/awkward/forms/numpyform.py b/src/awkward/forms/numpyform.py index f36b17e44d..3af48d5e19 100644 --- a/src/awkward/forms/numpyform.py +++ b/src/awkward/forms/numpyform.py @@ -9,8 +9,7 @@ from awkward._meta.numpymeta import NumpyMeta from awkward._nplikes.numpy_like import NumpyMetadata from awkward._nplikes.shape import unknown_length -from awkward._parameters import type_parameters_equal -from awkward._typing import TYPE_CHECKING, DType, JSONMapping, Self, final +from awkward._typing import TYPE_CHECKING, Any, DType, JSONMapping, Self, final from awkward._util import UNSET, Sentinel from awkward.forms.form import Form, _SpecifierMatcher @@ -168,17 +167,6 @@ def type(self): return out - def __eq__(self, other): - if isinstance(other, NumpyForm): - return ( - self._form_key == other._form_key - and self._primitive == other._primitive - and self._inner_shape == other._inner_shape - and type_parameters_equal(self._parameters, other._parameters) - ) - else: - return False - def to_RegularForm(self) -> RegularForm | NumpyForm: out: RegularForm | NumpyForm = NumpyForm( self._primitive, (), parameters=None, form_key=None @@ -270,3 +258,8 @@ def _expected_from_buffers( from awkward.types.numpytype import primitive_to_dtype yield (getkey(self, "data"), primitive_to_dtype(self.primitive)) + + def _is_equal_to(self, other: Any, all_parameters: bool, form_key: bool) -> bool: + return self._is_equal_to_generic(other, all_parameters, form_key) and ( + self._primitive == other._primitive + ) diff --git a/src/awkward/forms/recordform.py b/src/awkward/forms/recordform.py index 33ec41564c..04f4d66a9d 100644 --- a/src/awkward/forms/recordform.py +++ b/src/awkward/forms/recordform.py @@ -7,8 +7,7 @@ import awkward as ak from awkward._meta.recordmeta import RecordMeta from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._parameters import type_parameters_equal -from awkward._typing import DType, Self, final +from awkward._typing import Any, DType, Self, final from awkward._util import UNSET from awkward.forms.form import Form, _SpecifierMatcher @@ -114,25 +113,6 @@ def type(self): parameters=self._parameters, ) - def __eq__(self, other): - if isinstance(other, RecordForm): - if ( - self._form_key == other._form_key - and self.is_tuple == other.is_tuple - and len(self._contents) == len(other._contents) - and type_parameters_equal(self._parameters, other._parameters) - ): - if self.is_tuple: - return self._contents == other._contents - else: - return dict(zip(self._fields, self._contents)) == dict( - zip(other._fields, other._contents) - ) - else: - return False - else: - return False - def _columns(self, path, output, list_indicator): for content, field in zip(self._contents, self.fields): content._columns((*path, field), output, list_indicator) @@ -197,3 +177,17 @@ def _expected_from_buffers( if recursive: for content in self._contents: yield from content._expected_from_buffers(getkey, recursive) + + def _is_equal_to(self, other: Any, all_parameters: bool, form_key: bool) -> bool: + computed_fields_set = set(self.fields) + + return ( + self._is_equal_to_generic(other, all_parameters, form_key) + and self.is_tuple == other.is_tuple + and len(self._contents) == len(other._contents) + and all(f in computed_fields_set for f in other.fields) + and all( + content._is_equal_to(other.content(field), all_parameters, form_key) + for field, content in zip(self.fields, self._contents) + ) + ) diff --git a/src/awkward/forms/regularform.py b/src/awkward/forms/regularform.py index 12f0bf5ed8..10596bad2a 100644 --- a/src/awkward/forms/regularform.py +++ b/src/awkward/forms/regularform.py @@ -8,9 +8,8 @@ from awkward._meta.regularmeta import RegularMeta from awkward._nplikes.numpy_like import NumpyMetadata from awkward._nplikes.shape import unknown_length -from awkward._parameters import type_parameters_equal from awkward._regularize import is_integer -from awkward._typing import DType, Self, final +from awkward._typing import Any, DType, Self, final from awkward._util import UNSET from awkward.forms.form import Form, _SpecifierMatcher @@ -83,17 +82,6 @@ def type(self): parameters=self._parameters, ) - def __eq__(self, other): - if isinstance(other, RegularForm): - return ( - self._form_key == other._form_key - and self._size == other._size - and type_parameters_equal(self._parameters, other._parameters) - and self._content == other._content - ) - else: - return False - def _columns(self, path, output, list_indicator): if ( self.parameter("__array__") not in ("string", "bytestring") @@ -138,3 +126,10 @@ def _expected_from_buffers( ) -> Iterator[tuple[str, DType]]: if recursive: yield from self._content._expected_from_buffers(getkey, recursive) + + def _is_equal_to(self, other: Any, all_parameters: bool, form_key: bool) -> bool: + return ( + self._is_equal_to_generic(other, all_parameters, form_key) + and (self._size == other._size) + and self._content._is_equal_to(other._content, all_parameters, form_key) + ) diff --git a/src/awkward/forms/unionform.py b/src/awkward/forms/unionform.py index a73a14f14a..fd6d439acc 100644 --- a/src/awkward/forms/unionform.py +++ b/src/awkward/forms/unionform.py @@ -3,12 +3,12 @@ from __future__ import annotations from collections.abc import Callable, Iterable +from itertools import permutations import awkward as ak from awkward._meta.unionmeta import UnionMeta from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._parameters import type_parameters_equal -from awkward._typing import DType, Iterator, Self, final +from awkward._typing import Any, DType, Iterator, Self, final from awkward._util import UNSET from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype @@ -138,19 +138,6 @@ def type(self): parameters=self._parameters, ) - def __eq__(self, other): - if ( - isinstance(other, UnionForm) - and self._form_key == other._form_key - and self._tags == other._tags - and self._index == other._index - and len(self._contents) == len(other._contents) - and type_parameters_equal(self._parameters, other._parameters) - ): - return self._contents == other._contents - - return False - def _columns(self, path, output, list_indicator): for content, field in zip(self._contents, self.fields): content._columns((*path, field), output, list_indicator) @@ -209,3 +196,18 @@ def _expected_from_buffers( if recursive: for content in self._contents: yield from content._expected_from_buffers(getkey, recursive) + + def _is_equal_to(self, other: Any, all_parameters: bool, form_key: bool) -> bool: + return ( + self._is_equal_to_generic(other, all_parameters, form_key) + and self._tags == other.tags + and self._index == other.index + and len(self._contents) == len(other.contents) + and any( + all( + x._is_equal_to(y, all_parameters, form_key) + for x, y in zip(self._contents, c) + ) + for c in permutations(other.contents) + ) + ) diff --git a/src/awkward/forms/unmaskedform.py b/src/awkward/forms/unmaskedform.py index cbfd96918a..99dd441c06 100644 --- a/src/awkward/forms/unmaskedform.py +++ b/src/awkward/forms/unmaskedform.py @@ -7,8 +7,10 @@ import awkward as ak from awkward._meta.unmaskedmeta import UnmaskedMeta from awkward._nplikes.numpy_like import NumpyMetadata -from awkward._parameters import parameters_union, type_parameters_equal -from awkward._typing import DType, Self, final +from awkward._parameters import ( + parameters_union, +) +from awkward._typing import Any, DType, Self, final from awkward._util import UNSET from awkward.forms.form import Form, _SpecifierMatcher @@ -95,16 +97,6 @@ def type(self): parameters=self._parameters, ).simplify_option_union() - def __eq__(self, other): - if isinstance(other, UnmaskedForm): - return ( - self._form_key == other._form_key - and type_parameters_equal(self._parameters, other._parameters) - and self._content == other._content - ) - else: - return False - def _columns(self, path, output, list_indicator): self._content._columns(path, output, list_indicator) @@ -141,3 +133,8 @@ def _expected_from_buffers( ) -> Iterator[tuple[str, DType]]: if recursive: yield from self._content._expected_from_buffers(getkey, recursive) + + def _is_equal_to(self, other: Any, all_parameters: bool, form_key: bool) -> bool: + return self._is_equal_to_generic( + other, all_parameters, form_key + ) and self._content._is_equal_to(other._content, all_parameters, form_key) diff --git a/src/awkward/types/recordtype.py b/src/awkward/types/recordtype.py index ff7329e21a..3cce687e40 100644 --- a/src/awkward/types/recordtype.py +++ b/src/awkward/types/recordtype.py @@ -4,7 +4,6 @@ import json from collections.abc import Iterable, Mapping -from itertools import permutations import awkward as ak import awkward._prettyprint @@ -212,16 +211,11 @@ def _is_equal_to(self, other: Any, all_parameters: bool) -> bool: if set(self._fields) != set(other._fields): return False - self_contents = [self.content(f) for f in self._fields] - other_contents = [other.content(f) for f in other._fields] - - return any( - all( - this._is_equal_to(that, all_parameters) - for this, that in zip(self_contents, contents) - ) - for contents in permutations(other_contents) + return all( + content._is_equal_to(other.content(field), all_parameters) + for field, content in zip(self._fields, self._contents) ) + # Mixed else: return False diff --git a/tests/test_2368_type_is_equal.py b/tests/test_2368_type_is_equal.py index 4ff3ed8ffe..b0222fc240 100644 --- a/tests/test_2368_type_is_equal.py +++ b/tests/test_2368_type_is_equal.py @@ -82,7 +82,7 @@ def test_record_tuple(): assert record_type != tuple_type -def test_record_mixed(): +def test_record_permuted(): record = ak.types.from_datashape("10 * var * {x: int64, y: int32}") permutation = ak.types.from_datashape("10 * var * {y: int64, x: int32}") - assert record == permutation + assert record != permutation