diff --git a/src/awkward/contents/bitmaskedarray.py b/src/awkward/contents/bitmaskedarray.py index 6790ac7bc0..047a97ea70 100644 --- a/src/awkward/contents/bitmaskedarray.py +++ b/src/awkward/contents/bitmaskedarray.py @@ -15,9 +15,6 @@ from awkward._nplikes.numpy_like import IndexType, NumpyMetadata from awkward._nplikes.shape import ShapeItem, unknown_length from awkward._nplikes.typetracer import MaybeNone, TypeTracer -from awkward._parameters import ( - type_parameters_equal, -) from awkward._regularize import is_integer, is_integer_like from awkward._slicing import NO_HEAD from awkward._typing import ( @@ -577,11 +574,9 @@ def _mergeable_next(self, other: Content, mergebool: bool) -> bool: # Is the other content is an identity, or a union? if other.is_identity_like or other.is_union: return True - # We can only combine option types whose array-record parameters agree + # Is the other array indexed or optional? elif other.is_option or other.is_indexed: - return self._content._mergeable_next( - other.content, mergebool - ) and type_parameters_equal(self._parameters, other._parameters) + return self._content._mergeable_next(other.content, mergebool) else: return self._content._mergeable_next(other, mergebool) diff --git a/src/awkward/contents/bytemaskedarray.py b/src/awkward/contents/bytemaskedarray.py index 5738ae3803..84742101d3 100644 --- a/src/awkward/contents/bytemaskedarray.py +++ b/src/awkward/contents/bytemaskedarray.py @@ -18,7 +18,6 @@ from awkward._nplikes.typetracer import MaybeNone, TypeTracer from awkward._parameters import ( parameters_intersect, - type_parameters_equal, ) from awkward._regularize import is_integer_like from awkward._slicing import NO_HEAD @@ -721,11 +720,9 @@ def _mergeable_next(self, other: Content, mergebool: bool) -> bool: # Is the other content is an identity, or a union? if other.is_identity_like or other.is_union: return True - # We can only combine option types whose array-record parameters agree + # Is the other array indexed or optional? elif other.is_option or other.is_indexed: - return self._content._mergeable_next( - other.content, mergebool - ) and type_parameters_equal(self._parameters, other._parameters) + return self._content._mergeable_next(other.content, mergebool) else: return self._content._mergeable_next(other, mergebool) diff --git a/src/awkward/contents/content.py b/src/awkward/contents/content.py index 3812fec414..7315a32ff8 100644 --- a/src/awkward/contents/content.py +++ b/src/awkward/contents/content.py @@ -23,7 +23,6 @@ from awkward._nplikes.numpy import Numpy from awkward._nplikes.numpy_like import IndexType, NumpyMetadata from awkward._nplikes.shape import ShapeItem, unknown_length -from awkward._nplikes.typetracer import TypeTracer from awkward._parameters import ( parameters_are_equal, type_parameters_equal, @@ -762,41 +761,22 @@ def _merging_strategy( head = [self] tail = [] - i = 0 - - while i < len(others): - other = others[i] - if isinstance( - other, - ( - ak.contents.IndexedArray, - ak.contents.IndexedOptionArray, - ak.contents.ByteMaskedArray, - ak.contents.BitMaskedArray, - ak.contents.UnmaskedArray, - ak.contents.UnionArray, - ), - ): + + it_others = iter(others) + + for other in it_others: + if other.is_indexed or other.is_option or other.is_union: + tail.append(other) + tail.extend(it_others) break else: head.append(other) - i = i + 1 - - while i < len(others): - tail.append(others[i]) - i = i + 1 - - if any(isinstance(x.backend.nplike, TypeTracer) for x in head + tail): - head = [ - x if isinstance(x.backend.nplike, TypeTracer) else x.to_typetracer() - for x in head - ] - tail = [ - x if isinstance(x.backend.nplike, TypeTracer) else x.to_typetracer() - for x in tail - ] - - return (head, tail) + + assert not any(x.backend.nplike.known_data for x in head + tail) or all( + x.backend.nplike.known_data for x in head + tail + ) + + return head, tail def _local_index(self, axis: int, depth: int): raise NotImplementedError diff --git a/src/awkward/contents/indexedarray.py b/src/awkward/contents/indexedarray.py index ee5cb28c10..7f30634e52 100644 --- a/src/awkward/contents/indexedarray.py +++ b/src/awkward/contents/indexedarray.py @@ -17,7 +17,6 @@ from awkward._parameters import ( parameters_intersect, parameters_union, - type_parameters_equal, ) from awkward._regularize import is_integer_like from awkward._slicing import NO_HEAD @@ -500,11 +499,9 @@ def _mergeable_next(self, other: Content, mergebool: bool) -> bool: # Is the other content is an identity, or a union? if other.is_identity_like or other.is_union: return True - # We can only combine option/indexed types whose array-record parameters agree + # Is the other array indexed or optional? elif other.is_option or other.is_indexed: - return self._content._mergeable_next( - other.content, mergebool - ) and type_parameters_equal(self._parameters, other._parameters) + return self._content._mergeable_next(other.content, mergebool) else: return self._content._mergeable_next(other, mergebool) @@ -517,32 +514,38 @@ def _merging_strategy(self, others): head = [self] tail = [] - i = 0 - while i < len(others): - other = others[i] + it_others = iter(others) + for other in it_others: if isinstance(other, ak.contents.UnionArray): + tail.append(other) + tail.extend(it_others) break else: head.append(other) - i = i + 1 - while i < len(others): - tail.append(others[i]) - i = i + 1 - - if any(isinstance(x.backend.nplike, TypeTracer) for x in head + tail): - head = [ - x if isinstance(x.backend.nplike, TypeTracer) else x.to_typetracer() - for x in head - ] - tail = [ - x if isinstance(x.backend.nplike, TypeTracer) else x.to_typetracer() - for x in tail - ] + if any(x.backend.nplike.known_data for x in head + tail) and not all( + x.backend.nplike.known_data for x in head + tail + ): + raise RuntimeError - return (head, tail) + return head, tail def _reverse_merge(self, other): + if isinstance(other, ak.contents.EmptyArray): + return self + + # FIXME: support categorical-categorical merging + if ( + other.is_indexed + and other.parameter("__array__") + == self.parameter("__array__") + == "categorical" + ): + raise NotImplementedError( + "merging categorical arrays is currently not implemented. " + "Use `ak.enforce_type` to drop the categorical type and use general merging." + ) + theirlength = other.length mylength = self.length index = ak.index.Index64.empty( @@ -663,9 +666,24 @@ def _mergemany(self, others: Sequence[Content]) -> Content: contentlength_so_far += array.length length_so_far += array.length + # Categoricals may only survive if all contents are categorical + if ( + parameters is not None + and parameters.get("__array__") == "categorical" + ): + parameters = {**parameters} + del parameters["__array__"] + tail_contents = contents[1:] nextcontent = contents[0]._mergemany(tail_contents) + # FIXME: support categorical merging? + if parameters is not None and parameters.get("__array__") == "categorical": + raise NotImplementedError( + "merging categorical arrays is currently not implemented. " + "Use `ak.enforce_type` to drop the categorical type and use general merging." + ) + # Options win out! if any(x.is_option for x in head): next = ak.contents.IndexedOptionArray( diff --git a/src/awkward/contents/indexedoptionarray.py b/src/awkward/contents/indexedoptionarray.py index 396aece3a1..6373e61161 100644 --- a/src/awkward/contents/indexedoptionarray.py +++ b/src/awkward/contents/indexedoptionarray.py @@ -17,7 +17,6 @@ from awkward._parameters import ( parameters_intersect, parameters_union, - type_parameters_equal, ) from awkward._regularize import is_integer_like from awkward._slicing import NO_HEAD @@ -633,11 +632,9 @@ def _mergeable_next(self, other: Content, mergebool: bool) -> bool: # Is the other content is an identity, or a union? if other.is_identity_like or other.is_union: return True - # We can only combine option/indexed types whose array-record parameters agree + # Is the other array indexed or optional? elif other.is_option or other.is_indexed: - return self._content._mergeable_next( - other.content, mergebool - ) and type_parameters_equal(self._parameters, other._parameters) + return self._content._mergeable_next(other.content, mergebool) else: return self._content._mergeable_next(other, mergebool) @@ -650,35 +647,38 @@ def _merging_strategy(self, others): head = [self] tail = [] - i = 0 - while i < len(others): - other = others[i] + it_others = iter(others) + for other in it_others: if isinstance(other, ak.contents.UnionArray): + tail.append(other) + tail.extend(it_others) break else: head.append(other) - i = i + 1 - while i < len(others): - tail.append(others[i]) - i = i + 1 - - if any(isinstance(x.backend.nplike, TypeTracer) for x in head + tail): - head = [ - x if isinstance(x.backend.nplike, TypeTracer) else x.to_typetracer() - for x in head - ] - tail = [ - x if isinstance(x.backend.nplike, TypeTracer) else x.to_typetracer() - for x in tail - ] + if any(x.backend.nplike.known_data for x in head + tail) and not all( + x.backend.nplike.known_data for x in head + tail + ): + raise RuntimeError - return (head, tail) + return head, tail def _reverse_merge(self, other): if isinstance(other, ak.contents.EmptyArray): return self + # FIXME: support categorical-categorical merging + if ( + other.is_indexed + and other.parameter("__array__") + == self.parameter("__array__") + == "categorical" + ): + raise NotImplementedError( + "merging categorical arrays is currently not implemented. " + "Use `ak.enforce_type` to drop the categorical type and use general merging." + ) + theirlength = other.length mylength = self.length index = ak.index.Index64.empty( @@ -688,6 +688,7 @@ def _reverse_merge(self, other): content = other._mergemany([self._content]) + # Fill index::0→theirlength with arange(theirlength) assert index.nplike is self._backend.index_nplike self._backend.maybe_kernel_error( self._backend["awkward_IndexedArray_fill_count", index.dtype.type]( @@ -697,30 +698,27 @@ def _reverse_merge(self, other): 0, ) ) - reinterpreted_index = ak.index.Index( - self._backend.index_nplike.asarray(self.index.data), - nplike=self._backend.index_nplike, - ) + # Fill index::theirlength->end with self.index[:mylength]+theirlength assert ( index.nplike is self._backend.index_nplike - and reinterpreted_index.nplike is self._backend.index_nplike + and self.index.nplike is self._backend.index_nplike ) self._backend.maybe_kernel_error( self._backend[ "awkward_IndexedArray_fill", index.dtype.type, - reinterpreted_index.dtype.type, + self.index.dtype.type, ]( index.data, theirlength, - reinterpreted_index.data, + self.index.data, mylength, theirlength, ) ) - # We can directly merge with other options, but we must merge parameters - if other.is_option: + # We can directly merge with other options and indexed types, but we must merge parameters + if other.is_option or other.is_indexed: parameters = parameters_union(self._parameters, other._parameters) # Otherwise, this option parameters win out else: @@ -806,12 +804,27 @@ def _mergemany(self, others: Sequence[Content]) -> Content: length_so_far += array.length + # Categoricals may only survive if all contents are categorical + if ( + parameters is not None + and parameters.get("__array__") == "categorical" + ): + parameters = {**parameters} + del parameters["__array__"] + tail_contents = contents[1:] nextcontent = contents[0]._mergemany(tail_contents) next = ak.contents.IndexedOptionArray( nextindex, nextcontent, parameters=parameters ) + # FIXME: support categorical merging? + if parameters is not None and parameters.get("__array__") == "categorical": + raise NotImplementedError( + "merging categorical arrays is currently not implemented. " + "Use `ak.enforce_type` to drop the categorical type and use general merging." + ) + if len(tail) == 0: return next diff --git a/src/awkward/contents/listarray.py b/src/awkward/contents/listarray.py index d4751cef9d..da0ef351bf 100644 --- a/src/awkward/contents/listarray.py +++ b/src/awkward/contents/listarray.py @@ -1076,8 +1076,8 @@ def _mergeable_next(self, other: Content, mergebool: bool) -> bool: # Is the other content is an identity, or a union? if other.is_identity_like or other.is_union: return True - # Check against option contents - elif other.is_option or other.is_indexed: + # Is the other array indexed or optional? + elif other.is_indexed or other.is_option: return self._mergeable_next(other.content, mergebool) # Otherwise, do the parameters match? If not, we can't merge. elif not type_parameters_equal(self._parameters, other._parameters): @@ -1124,6 +1124,8 @@ def _mergemany(self, others: Sequence[Content]) -> Content: ), ): contents.append(array.content) + elif array.is_numpy: + contents.append(array.to_RegularArray().content) else: raise ValueError( "cannot merge " @@ -1143,6 +1145,11 @@ def _mergemany(self, others: Sequence[Content]) -> Content: length_so_far = 0 for array in head: + # We need contiguous content, so let's just convert to RegularArray + # immediately. + if array.is_numpy: + array = array.to_RegularArray() + if isinstance( array, ( @@ -1218,6 +1225,9 @@ def _mergemany(self, others: Sequence[Content]) -> Content: elif isinstance(array, ak.contents.EmptyArray): pass + else: + raise AssertionError + next = ak.contents.ListArray( nextstarts, nextstops, nextcontent, parameters=parameters ) diff --git a/src/awkward/contents/listoffsetarray.py b/src/awkward/contents/listoffsetarray.py index 1b54e108a9..e35b32320e 100644 --- a/src/awkward/contents/listoffsetarray.py +++ b/src/awkward/contents/listoffsetarray.py @@ -785,8 +785,8 @@ def _mergeable_next(self, other: Content, mergebool: bool) -> bool: # Is the other content is an identity, or a union? if other.is_identity_like or other.is_union: return True - # Check against option contents - elif other.is_option or other.is_indexed: + # Is the other array indexed or optional? + elif other.is_indexed or other.is_option: return self._mergeable_next(other.content, mergebool) # Otherwise, do the parameters match? If not, we can't merge. elif not type_parameters_equal(self._parameters, other._parameters): diff --git a/src/awkward/contents/numpyarray.py b/src/awkward/contents/numpyarray.py index 5bf111d2db..0adba74ba4 100644 --- a/src/awkward/contents/numpyarray.py +++ b/src/awkward/contents/numpyarray.py @@ -451,8 +451,8 @@ def _mergeable_next(self, other: Content, mergebool: bool) -> bool: # Is the other content is an identity, or a union? if other.is_identity_like or other.is_union: return True - # Check against option contents - elif other.is_option or other.is_indexed: + # Is the other array indexed or optional? + elif other.is_indexed or other.is_option: return self._mergeable_next(other.content, mergebool) # Otherwise, do the parameters match? If not, we can't merge. elif not type_parameters_equal(self._parameters, other._parameters): diff --git a/src/awkward/contents/regulararray.py b/src/awkward/contents/regulararray.py index eed2ae428e..c5748d7f37 100644 --- a/src/awkward/contents/regulararray.py +++ b/src/awkward/contents/regulararray.py @@ -730,8 +730,8 @@ def _mergeable_next(self, other: Content, mergebool: bool) -> bool: # Is the other content is an identity, or a union? if other.is_identity_like or other.is_union: return True - # Check against option contents - elif other.is_option or other.is_indexed: + # Is the other array indexed or optional? + elif other.is_indexed or other.is_option: return self._mergeable_next(other.content, mergebool) # Otherwise, do the parameters match? If not, we can't merge. elif not type_parameters_equal(self._parameters, other._parameters): diff --git a/src/awkward/contents/unionarray.py b/src/awkward/contents/unionarray.py index aa81b495da..431312c21e 100644 --- a/src/awkward/contents/unionarray.py +++ b/src/awkward/contents/unionarray.py @@ -994,23 +994,15 @@ def _merging_strategy(self, others): "to merge this array with 'others', at least one other must be provided" ) - head = [self] + head = [self, *others] tail = [] - for i in range(len(others)): - head.append(others[i]) - - if any(isinstance(x.backend.nplike, TypeTracer) for x in head + tail): - head = [ - x if isinstance(x.backend.nplike, TypeTracer) else x.to_typetracer() - for x in head - ] - tail = [ - x if isinstance(x.backend.nplike, TypeTracer) else x.to_typetracer() - for x in tail - ] + if any(x.backend.nplike.known_data for x in head + tail) and not all( + x.backend.nplike.known_data for x in head + tail + ): + raise RuntimeError - return (head, tail) + return head, tail def _reverse_merge(self, other): theirlength = other.length diff --git a/src/awkward/contents/unmaskedarray.py b/src/awkward/contents/unmaskedarray.py index 2f560827ad..8afdc6af5d 100644 --- a/src/awkward/contents/unmaskedarray.py +++ b/src/awkward/contents/unmaskedarray.py @@ -18,7 +18,6 @@ from awkward._parameters import ( parameters_intersect, parameters_union, - type_parameters_equal, ) from awkward._regularize import is_integer_like from awkward._slicing import NO_HEAD @@ -127,7 +126,7 @@ def simplified(cls, content, *, parameters=None): parameters=parameters_union(content._parameters, parameters), ) elif content.is_indexed or content.is_option: - return content.copy( + return content.to_IndexedOptionArray64().copy( parameters=parameters_union(content._parameters, parameters) ) else: @@ -340,11 +339,9 @@ def _mergeable_next(self, other: Content, mergebool: bool) -> bool: # Is the other content is an identity, or a union? if other.is_identity_like or other.is_union: return True - # We can only combine option types whose array-record parameters agree + # Is the other array indexed or optional? elif other.is_option or other.is_indexed: - return self._mergeable_next( - other.content, mergebool - ) and type_parameters_equal(self._parameters, other._parameters) + return self._content._mergeable_next(other.content, mergebool) else: return self._content._mergeable_next(other, mergebool) diff --git a/src/awkward/operations/ak_concatenate.py b/src/awkward/operations/ak_concatenate.py index 18143b6840..c53a24bd19 100644 --- a/src/awkward/operations/ak_concatenate.py +++ b/src/awkward/operations/ak_concatenate.py @@ -2,17 +2,22 @@ from __future__ import annotations +from itertools import permutations + import awkward as ak from awkward._backends.dispatch import backend_of_obj from awkward._backends.numpy import NumpyBackend from awkward._dispatch import high_level_function +from awkward._do import mergeable from awkward._layout import HighLevelContext, ensure_same_backend, maybe_posaxis from awkward._nplikes.numpy_like import NumpyMetadata from awkward._nplikes.shape import unknown_length +from awkward._parameters import type_parameters_equal from awkward._regularize import regularize_axis from awkward._typing import Sequence from awkward.contents import Content from awkward.operations.ak_fill_none import fill_none +from awkward.types.numpytype import primitive_to_dtype __all__ = ("concatenate",) @@ -332,3 +337,350 @@ def action(inputs, depth, backend, **kwargs): )[0] return ctx.wrap(out, highlevel=highlevel) + + +def _form_has_type(form, type_): + """ + Args: + form: content object + type_: low-level type object + + Returns True if the form satisfies the given type; otherwise False. + """ + if not type_parameters_equal(form._parameters, type_._parameters): + return False + + if form.is_unknown: + return isinstance(type_, ak.types.UnknownType) + elif form.is_option: + return isinstance(type_, ak.types.OptionType) and _form_has_type( + form.content, type_.content + ) + elif form.is_indexed: + return _form_has_type(form.content, type_) + elif form.is_regular: + return ( + isinstance(type_, ak.types.RegularType) + and ( + form.size is unknown_length + or type_.size is unknown_length + or form.size == type_.size + ) + and _form_has_type(form.content, type_.content) + ) + elif form.is_list: + return isinstance(type_, ak.types.ListType) and _form_has_type( + form.content, type_.content + ) + elif form.is_numpy: + for _ in range(form.purelist_depth - 1): + if not isinstance(type_, ak.types.RegularType): + return False + type_ = type_.content + return ( + isinstance(type_, ak.types.NumpyType) and form.primitive == type_.primitive + ) + elif form.is_record: + if ( + not isinstance(type_, ak.types.RecordType) + or type_.is_tuple != form.is_tuple + ): + return False + + if form.is_tuple: + return all( + _form_has_type(c, t) for c, t in zip(form.contents, type_.contents) + ) + else: + return (frozenset(form.fields) == frozenset(type_.fields)) and all( + _form_has_type(form.content(f), type_.content(f)) for f in type_.fields + ) + elif form.is_union: + if len(form.contents) != len(type_.contents): + return False + + for contents in permutations(form.contents): + if all( + _form_has_type(form, type_) + for form, type_ in zip(contents, type_.contents) + ): + return True + return False + else: + raise TypeError(form) + + +# This routine should not try to replicate the merge logic, +# but we can make use of assumptions w.r.t to what the merge will do. +# e.g., merging can add new unions, promote to options, change dtypes of NumPy arrays +def enforce_concatenated_form(layout, form): + # Merge invariant (drop known-ness) + if not layout.is_unknown and form.is_unknown: + raise AssertionError( + "merge result should never be of an unknown type unless the layout is unknown" + ) + # Unknowns become canonical forms + elif layout.is_unknown and not form.is_unknown: + return form.length_zero_array(highlevel=False).to_backend(layout.backend) + + ############## Unions ##################################################### + # Merge invariant (drop union) + elif layout.is_union and not form.is_union: + raise AssertionError("merge result should be a union if layout is a union") + # Add a union + elif not layout.is_union and form.is_union: + # Merge invariant (unions are i8-i64) + if not (form.tags == "i8" and form.index == "i64"): + raise AssertionError( + "merge result that forms a union should have i8 tags and i64 index" + ) + + # Non-categoricals can be merged into union + if ( + layout.is_indexed + and not layout.is_option + and layout.parameter("__array__") != "categorical" + ): + index = layout.index.to64() + # Take the content and drop the parameters (we're taking parameters from form!) + layout_to_merge = layout.content + # Otherwise, we move into the contents + else: + index = ak.index.Index64( + layout.backend.index_nplike.arange(layout.length, dtype=np.int64) + ) + layout_to_merge = layout + + type_ = layout_to_merge.form.type + + # First assume this type is exactly represented in the union. + # This won't hold true if any (and not all) of the contents are an option + # Or if there were mergeable (but non-equal type) pairs in the original + # concatenation that formed this union + union_has_exact_type = False + contents = [] + for content_form in form.contents: + if _form_has_type(content_form, type_): + contents.insert( + 0, enforce_concatenated_form(layout_to_merge, content_form) + ) + union_has_exact_type = True + else: + contents.append( + content_form.length_zero_array(highlevel=False).to_backend( + layout.backend + ) + ) + + # Otherwise, find anything we can merge with + if not union_has_exact_type: + contents.clear() + + for content_form in form.contents: + # TODO check forms mergeable + content_layout = content_form.length_zero_array( + highlevel=False + ).to_backend(layout.backend) + if mergeable(content_layout, layout_to_merge): + contents.insert( + 0, enforce_concatenated_form(layout_to_merge, content_form) + ) + else: + contents.append( + content_form.length_zero_array(highlevel=False).to_backend( + layout.backend + ) + ) + + return ak.contents.UnionArray( + ak.index.Index8( + layout.backend.index_nplike.zeros(layout.length, dtype=np.int8) + ), + index, + contents, + parameters=form._parameters, + ) + # Preserve union + elif layout.is_union and form.is_union: + # Merge invariant (unions are i8-i64) + if not (form.tags == "i8" and form.index == "i64"): + raise AssertionError( + "merge result that forms a union should have i8 tags and i64 index" + ) + if len(form.contents) < len(layout.contents): + raise AssertionError( + "merge result should only grow or preserve a union's cardinality" + ) + form_contents = [ + f.length_zero_array(highlevel=False).to_backend(layout.backend) + for f in form.contents + ] + form_indices = range(len(form_contents)) + for form_projection_indices in permutations(form_indices, len(layout.contents)): + if all( + mergeable(c, form_contents[i]) + for c, i in zip(layout.contents, form_projection_indices) + ): + break + else: + raise AssertionError( + "merge result should be mergeable against some permutation of the layout" + ) + + next_contents = [ + enforce_concatenated_form(c, form.contents[i]) + for c, i in zip(layout.contents, form_projection_indices) + ] + next_contents.extend( + [ + form_contents[i] + for i in (set(form_indices) - set(form_projection_indices)) + ] + ) + return ak.contents.UnionArray( + ak.index.Index8( + layout.backend.index_nplike.astype(layout.tags.data, np.int8) + ), + layout.index.to64(), + next_contents, + parameters=form._parameters, + ) + + ############## Options #################################################### + # Merge invariant (drop option) + elif layout.is_option and not form.is_option: + raise AssertionError("merge result should be an option if layout is an option") + # Add option + elif not layout.is_option and form.is_option: + return enforce_concatenated_form( + ak.contents.UnmaskedArray.simplified(layout), form + ) + # Preserve option + elif layout.is_option and form.is_option: + if isinstance(form, ak.forms.IndexedOptionForm): + if form.index != "i64": + raise AssertionError( + "IndexedOptionForm should have i64 for merge results" + ) + return layout.to_IndexedOptionArray64().copy( + content=enforce_concatenated_form(layout.content, form.content), + parameters=form._parameters, + ) + # Non IndexedOptionArray types require all merge candidates to have same form + elif isinstance( + form, + (ak.forms.ByteMaskedForm, ak.forms.BitMaskedForm, ak.forms.UnmaskedForm), + ): + return layout.copy( + content=enforce_concatenated_form(layout.content, form.content), + parameters=form._parameters, + ) + else: + raise AssertionError + + ############## Indexed #################################################### + # Merge invariant (drop indexed) + elif layout.is_indexed and not form.is_indexed: + raise AssertionError("merge result must be indexed if layout is indexed") + # Add index + elif not layout.is_indexed and form.is_indexed: + return ak.contents.IndexedArray( + ak.index.Index64(layout.backend.index_nplike.arange(layout.length)), + enforce_concatenated_form(layout, form.content), + parameters=form._parameters, + ) + # Preserve index + elif layout.is_indexed and form.is_indexed: + if form.index != "i64": + raise AssertionError("merge result must be i64") + return ak.contents.IndexedArray( + layout.index.to64(), + content=enforce_concatenated_form(layout.content, form.content), + parameters=form._parameters, + ) + + ############## NumPy ###################################################### + elif layout.is_numpy and form.is_numpy: + if layout.inner_shape != form.inner_shape: + raise AssertionError("layout must have same inner_shape as merge result") + + return ak.values_astype( + # HACK: drop parameters from type so that character arrays are supported + layout.copy(parameters=None), + to=primitive_to_dtype(form.primitive), + highlevel=False, + ).copy(parameters=form._parameters) + + ############## Lists ###################################################### + # Merge invariant (regular to numpy) + elif layout.is_regular and form.is_numpy: + raise AssertionError("layout cannot be regular for NumpyForm merge result") + # Merge invariant (ragged to regular) + elif not (layout.is_regular or layout.is_numpy) and form.is_regular: + raise AssertionError("merge result should be ragged if any input is ragged") + elif layout.is_numpy and form.is_list: + if len(layout.inner_shape) == 0: + raise AssertionError("layout must be at least 2D if merge result is a list") + return enforce_concatenated_form(layout.to_RegularArray(), form) + elif layout.is_regular and form.is_regular: + # regular → regular requires same size! + if layout.size != form.size: + raise AssertionError( + "RegularForm must have same size as layout for merge result" + ) + return layout.copy( + content=enforce_concatenated_form(layout.content, form.content), + parameters=form._parameters, + ) + elif layout.is_regular and form.is_list: + if isinstance(form, (ak.forms.ListOffsetForm, ak.forms.ListForm)): + return enforce_concatenated_form(layout.to_ListOffsetArray64(False), form) + else: + raise AssertionError + elif layout.is_list and form.is_list: + if isinstance(form, ak.forms.ListOffsetForm): + layout = layout.to_ListOffsetArray64(False) + return layout.copy( + content=enforce_concatenated_form(layout.content, form.content), + parameters=form._parameters, + ) + elif isinstance(form, ak.forms.ListForm): + if not (form.starts == "i64" and form.stops == "i64"): + raise TypeError("ListForm should have i64 for merge results") + return ak.contents.ListArray( + layout.starts.to64(), + layout.stops.to64(), + enforce_concatenated_form(layout.content, form.content), + parameters=form._parameters, + ) + else: + raise AssertionError + + ############## Records #################################################### + # Merge invariant (mix record-tuple) + elif layout.is_record and not form.is_record: + raise AssertionError("merge result should be a record if layout is a record") + # Merge invariant (mix record-tuple) + elif not layout.is_record and form.is_record: + raise AssertionError( + "layout result should be a record if merge result is a record" + ) + elif layout.is_record and form.is_record: + if frozenset(layout.fields) != frozenset(form.fields): + raise AssertionError("merge result and form must have matching fields") + elif layout.is_tuple != form.is_tuple: + raise AssertionError( + "merge result and form must both be tuple or record-like" + ) + return ak.contents.RecordArray( + [ + enforce_concatenated_form(layout.content(f), form.content(f)) + for f in layout.fields + ], + layout._fields, + length=layout.length, + parameters=form._parameters, + backend=layout.backend, + ) + else: + raise NotImplementedError diff --git a/tests/test_0093_simplify_uniontypes_and_optiontypes.py b/tests/test_0093_simplify_uniontypes_and_optiontypes.py index f65a64d454..987e8b3f7a 100644 --- a/tests/test_0093_simplify_uniontypes_and_optiontypes.py +++ b/tests/test_0093_simplify_uniontypes_and_optiontypes.py @@ -26,7 +26,10 @@ def test_numpyarray_merge(): == ak1._mergemany([ak2]).form ) assert ( - ak1[1:, :-1, ::-1].to_typetracer()._mergemany([ak2[1:, :-1, ::-1]]).form + ak1[1:, :-1, ::-1] + .to_typetracer() + ._mergemany([ak2[1:, :-1, ::-1].to_typetracer()]) + .form == ak1[1:, :-1, ::-1]._mergemany([ak2[1:, :-1, ::-1]]).form ) diff --git a/tests/test_0449_merge_many_arrays_in_one_pass.py b/tests/test_0449_merge_many_arrays_in_one_pass.py index 381898db07..d44db8bdfc 100644 --- a/tests/test_0449_merge_many_arrays_in_one_pass.py +++ b/tests/test_0449_merge_many_arrays_in_one_pass.py @@ -38,9 +38,9 @@ def test_numpyarray(): .to_typetracer() ._mergemany( [ - ak.contents.NumpyArray(two), - ak.contents.NumpyArray(three), - ak.contents.NumpyArray(four), + ak.contents.NumpyArray(two).to_typetracer(), + ak.contents.NumpyArray(three).to_typetracer(), + ak.contents.NumpyArray(four).to_typetracer(), ] ) .form @@ -71,9 +71,9 @@ def test_numpyarray(): .to_typetracer() ._mergemany( [ - ak.contents.NumpyArray(two), - ak.contents.EmptyArray(), - ak.contents.NumpyArray(four), + ak.contents.NumpyArray(two).to_typetracer(), + ak.contents.EmptyArray().to_typetracer(), + ak.contents.NumpyArray(four).to_typetracer(), ] ) .form @@ -115,11 +115,15 @@ def test_lists(): [4.0, 5.0], ] assert ( - one.to_typetracer()._mergemany([two, three, four]).form + one.to_typetracer() + ._mergemany([two.to_typetracer(), three.to_typetracer(), four.to_typetracer()]) + .form == one._mergemany([two, three, four]).form ) assert ( - four.to_typetracer()._mergemany([three, four, one]).form + four.to_typetracer() + ._mergemany([three.to_typetracer(), four.to_typetracer(), one.to_typetracer()]) + .form == four._mergemany([three, four, one]).form ) @@ -147,12 +151,16 @@ def test_lists(): ] assert ( - one.to_typetracer()._mergemany([two, three, four]).form + one.to_typetracer() + ._mergemany([two.to_typetracer(), three.to_typetracer(), four.to_typetracer()]) + .form == one._mergemany([two, three, four]).form ) assert ( - four.to_typetracer()._mergemany([three, four, one]).form + four.to_typetracer() + ._mergemany([three.to_typetracer(), four.to_typetracer(), one.to_typetracer()]) + .form == four._mergemany([three, four, one]).form ) @@ -175,7 +183,9 @@ def test_records(): {"x": 7, "y": [1, 2]}, ] assert ( - one.to_typetracer()._mergemany([two, three, four]).form + one.to_typetracer() + ._mergemany([two.to_typetracer(), three.to_typetracer(), four.to_typetracer()]) + .form == one._mergemany([two, three, four]).form ) @@ -190,7 +200,9 @@ def test_records(): {"x": 7, "y": [1, 2]}, ] assert ( - one.to_typetracer()._mergemany([two, three, four]).form + one.to_typetracer() + ._mergemany([two.to_typetracer(), three.to_typetracer(), four.to_typetracer()]) + .form == one._mergemany([two, three, four]).form ) @@ -210,7 +222,9 @@ def test_tuples(): (7, [1, 2]), ] assert ( - one.to_typetracer()._mergemany([two, three, four]).form + one.to_typetracer() + ._mergemany([two.to_typetracer(), three.to_typetracer(), four.to_typetracer()]) + .form == one._mergemany([two, three, four]).form ) @@ -225,7 +239,9 @@ def test_tuples(): (7, [1, 2]), ] assert ( - one.to_typetracer()._mergemany([two, three, four]).form + one.to_typetracer() + ._mergemany([two.to_typetracer(), three.to_typetracer(), four.to_typetracer()]) + .form == one._mergemany([two, three, four]).form ) @@ -252,7 +268,9 @@ def test_indexed(): None, ] assert ( - one.to_typetracer()._mergemany([two, three, four]).form + one.to_typetracer() + ._mergemany([two.to_typetracer(), three.to_typetracer(), four.to_typetracer()]) + .form == one._mergemany([two, three, four]).form ) @@ -263,7 +281,9 @@ def test_reverse_indexed(): three = ak.highlevel.Array([None, 6, None]).layout assert to_list(one._mergemany([two, three])) == [1, 2, 3, 4, 5, None, 6, None] assert ( - one.to_typetracer()._mergemany([two, three]).form + one.to_typetracer() + ._mergemany([two.to_typetracer(), three.to_typetracer()]) + .form == one._mergemany([two, three]).form ) @@ -284,7 +304,9 @@ def test_reverse_indexed(): 9, ] assert ( - one.to_typetracer()._mergemany([two, three, four]).form + one.to_typetracer() + ._mergemany([two.to_typetracer(), three.to_typetracer(), four.to_typetracer()]) + .form == one._mergemany([two, three, four]).form ) @@ -415,27 +437,39 @@ def test_bytemasked(): ] assert ( - one.to_typetracer()._mergemany([two, three, four]).form + one.to_typetracer() + ._mergemany([two.to_typetracer(), three.to_typetracer(), four.to_typetracer()]) + .form == one._mergemany([two, three, four]).form ) assert ( - four.to_typetracer()._mergemany([three, two, one]).form + four.to_typetracer() + ._mergemany([three.to_typetracer(), two.to_typetracer(), one.to_typetracer()]) + .form == four._mergemany([three, two, one]).form ) assert ( - three.to_typetracer()._mergemany([four, one]).form + three.to_typetracer() + ._mergemany([four.to_typetracer(), one.to_typetracer()]) + .form == three._mergemany([four, one]).form ) assert ( - three.to_typetracer()._mergemany([four, one, two]).form + three.to_typetracer() + ._mergemany([four.to_typetracer(), one.to_typetracer(), two.to_typetracer()]) + .form == three._mergemany([four, one, two]).form ) assert ( - three.to_typetracer()._mergemany([two, one]).form + three.to_typetracer() + ._mergemany([two.to_typetracer(), one.to_typetracer()]) + .form == three._mergemany([two, one]).form ) assert ( - three.to_typetracer()._mergemany([two, one, four]).form + three.to_typetracer() + ._mergemany([two.to_typetracer(), one.to_typetracer(), four.to_typetracer()]) + .form == three._mergemany([two, one, four]).form ) @@ -460,15 +494,29 @@ def test_empty(): == one._mergemany([two]).form ) assert ( - one.to_typetracer()._mergemany([two, one, two, one, two]).form + one.to_typetracer() + ._mergemany( + [ + two.to_typetracer(), + one.to_typetracer(), + two.to_typetracer(), + one.to_typetracer(), + two.to_typetracer(), + ] + ) + .form == one._mergemany([two, one, two, one, two]).form ) assert ( - one.to_typetracer()._mergemany([two, three]).form + one.to_typetracer() + ._mergemany([two.to_typetracer(), three.to_typetracer()]) + .form == one._mergemany([two, three]).form ) assert ( - one.to_typetracer()._mergemany([two, three, four]).form + one.to_typetracer() + ._mergemany([two.to_typetracer(), three.to_typetracer(), four.to_typetracer()]) + .form == one._mergemany([two, three, four]).form ) assert ( @@ -476,19 +524,27 @@ def test_empty(): == one._mergemany([three]).form ) assert ( - one.to_typetracer()._mergemany([three, four]).form + one.to_typetracer() + ._mergemany([three.to_typetracer(), four.to_typetracer()]) + .form == one._mergemany([three, four]).form ) assert ( - one.to_typetracer()._mergemany([three, two]).form + one.to_typetracer() + ._mergemany([three.to_typetracer(), two.to_typetracer()]) + .form == one._mergemany([three, two]).form ) assert ( - one.to_typetracer()._mergemany([three, two, four]).form + one.to_typetracer() + ._mergemany([three.to_typetracer(), two.to_typetracer(), four.to_typetracer()]) + .form == one._mergemany([three, two, four]).form ) assert ( - one.to_typetracer()._mergemany([three, four, two]).form + one.to_typetracer() + ._mergemany([three.to_typetracer(), four.to_typetracer(), two.to_typetracer()]) + .form == one._mergemany([three, four, two]).form ) @@ -584,27 +640,39 @@ def test_union(): ] assert ( - one.to_typetracer()._mergemany([two, three]).form + one.to_typetracer() + ._mergemany([two.to_typetracer(), three.to_typetracer()]) + .form == one._mergemany([two, three]).form ) assert ( - one.to_typetracer()._mergemany([three, two]).form + one.to_typetracer() + ._mergemany([three.to_typetracer(), two.to_typetracer()]) + .form == one._mergemany([three, two]).form ) assert ( - two.to_typetracer()._mergemany([one, three]).form + two.to_typetracer() + ._mergemany([one.to_typetracer(), three.to_typetracer()]) + .form == two._mergemany([one, three]).form ) assert ( - two.to_typetracer()._mergemany([three, one]).form + two.to_typetracer() + ._mergemany([three.to_typetracer(), one.to_typetracer()]) + .form == two._mergemany([three, one]).form ) assert ( - three.to_typetracer()._mergemany([one, two]).form + three.to_typetracer() + ._mergemany([one.to_typetracer(), two.to_typetracer()]) + .form == three._mergemany([one, two]).form ) assert ( - three.to_typetracer()._mergemany([two, one]).form + three.to_typetracer() + ._mergemany([two.to_typetracer(), one.to_typetracer()]) + .form == three._mergemany([two, one]).form ) @@ -700,27 +768,39 @@ def test_union_option(): ] assert ( - one.to_typetracer()._mergemany([two, three]).form + one.to_typetracer() + ._mergemany([two.to_typetracer(), three.to_typetracer()]) + .form == one._mergemany([two, three]).form ) assert ( - one.to_typetracer()._mergemany([three, two]).form + one.to_typetracer() + ._mergemany([three.to_typetracer(), two.to_typetracer()]) + .form == one._mergemany([three, two]).form ) assert ( - two.to_typetracer()._mergemany([one, three]).form + two.to_typetracer() + ._mergemany([one.to_typetracer(), three.to_typetracer()]) + .form == two._mergemany([one, three]).form ) assert ( - two.to_typetracer()._mergemany([three, one]).form + two.to_typetracer() + ._mergemany([three.to_typetracer(), one.to_typetracer()]) + .form == two._mergemany([three, one]).form ) assert ( - three.to_typetracer()._mergemany([one, two]).form + three.to_typetracer() + ._mergemany([one.to_typetracer(), two.to_typetracer()]) + .form == three._mergemany([one, two]).form ) assert ( - three.to_typetracer()._mergemany([two, one]).form + three.to_typetracer() + ._mergemany([two.to_typetracer(), one.to_typetracer()]) + .form == three._mergemany([two, one]).form ) @@ -814,27 +894,39 @@ def test_union_option(): ] assert ( - one.to_typetracer()._mergemany([two, three]).form + one.to_typetracer() + ._mergemany([two.to_typetracer(), three.to_typetracer()]) + .form == one._mergemany([two, three]).form ) assert ( - one.to_typetracer()._mergemany([three, two]).form + one.to_typetracer() + ._mergemany([three.to_typetracer(), two.to_typetracer()]) + .form == one._mergemany([three, two]).form ) assert ( - two.to_typetracer()._mergemany([one, three]).form + two.to_typetracer() + ._mergemany([one.to_typetracer(), three.to_typetracer()]) + .form == two._mergemany([one, three]).form ) assert ( - two.to_typetracer()._mergemany([three, one]).form + two.to_typetracer() + ._mergemany([three.to_typetracer(), one.to_typetracer()]) + .form == two._mergemany([three, one]).form ) assert ( - three.to_typetracer()._mergemany([one, two]).form + three.to_typetracer() + ._mergemany([one.to_typetracer(), two.to_typetracer()]) + .form == three._mergemany([one, two]).form ) assert ( - three.to_typetracer()._mergemany([two, one]).form + three.to_typetracer() + ._mergemany([two.to_typetracer(), one.to_typetracer()]) + .form == three._mergemany([two, one]).form ) @@ -928,27 +1020,39 @@ def test_union_option(): ] assert ( - one.to_typetracer()._mergemany([two, three]).form + one.to_typetracer() + ._mergemany([two.to_typetracer(), three.to_typetracer()]) + .form == one._mergemany([two, three]).form ) assert ( - one.to_typetracer()._mergemany([three, two]).form + one.to_typetracer() + ._mergemany([three.to_typetracer(), two.to_typetracer()]) + .form == one._mergemany([three, two]).form ) assert ( - two.to_typetracer()._mergemany([one, three]).form + two.to_typetracer() + ._mergemany([one.to_typetracer(), three.to_typetracer()]) + .form == two._mergemany([one, three]).form ) assert ( - two.to_typetracer()._mergemany([three, one]).form + two.to_typetracer() + ._mergemany([three.to_typetracer(), one.to_typetracer()]) + .form == two._mergemany([three, one]).form ) assert ( - three.to_typetracer()._mergemany([one, two]).form + three.to_typetracer() + ._mergemany([one.to_typetracer(), two.to_typetracer()]) + .form == three._mergemany([one, two]).form ) assert ( - three.to_typetracer()._mergemany([two, one]).form + three.to_typetracer() + ._mergemany([two.to_typetracer(), one.to_typetracer()]) + .form == three._mergemany([two, one]).form ) @@ -971,7 +1075,9 @@ def test_strings(): ] assert ( - one.to_typetracer()._mergemany([two, three]).form + one.to_typetracer() + ._mergemany([two.to_typetracer(), three.to_typetracer()]) + .form == one._mergemany([two, three]).form ) diff --git a/tests/test_2192_union_absorb_indexed.py b/tests/test_2192_union_absorb_indexed.py index 51914f3162..07da097ae5 100644 --- a/tests/test_2192_union_absorb_indexed.py +++ b/tests/test_2192_union_absorb_indexed.py @@ -3,7 +3,7 @@ from __future__ import annotations import numpy as np -import pytest # noqa: F401 +import pytest import awkward as ak @@ -144,35 +144,64 @@ def test_merge_indexed_categorical(): ), ], ) + with pytest.raises( + NotImplementedError, + match=r"merging categorical arrays is currently not implemented", + ): + ak.concatenate((union, records), highlevel=False) + + +def test_merge_indexed_mixed_categorical(): + records = ak.contents.IndexedArray( + ak.index.Index64([0, 2, 3]), + ak.contents.RecordArray( + [ + ak.contents.NumpyArray( + np.array([4.0, 3.0, 1.0, 9.0, 8.0, 7.0], dtype=np.int64) + ) + ], + ["x"], + parameters={"inner": "bar", "drop": "this"}, + ), + parameters={"outer": "foo", "ignore": "me"}, + ) + union = ak.contents.UnionArray( + ak.index.Index8([0, 0, 0, 1, 1, 1]), + ak.index.Index64([0, 1, 2, 0, 1, 2]), + [ + ak.contents.NumpyArray(np.arange(10, dtype=np.int64)), + ak.contents.IndexedArray( + ak.index.Index64([0, 1, 2]), + ak.contents.RecordArray( + [ak.contents.NumpyArray(np.array([4.0, 3.0, 1.0], dtype=np.int64))], + ["x"], + parameters={"inner": "bar"}, + ), + parameters={"outer": "foo", "__array__": "categorical"}, + ), + ], + ) result = ak.concatenate((union, records), highlevel=False) assert result.is_equal_to( ak.contents.UnionArray( ak.index.Index8([0, 0, 0, 1, 1, 1, 1, 1, 1]), - ak.index.Index64([0, 1, 2, 0, 1, 2, 3, 4, 5]), + ak.index.Index64([0, 1, 2, 0, 1, 2, 3, 5, 6]), [ ak.contents.NumpyArray(np.arange(10, dtype=np.int64)), - ak.contents.IndexedArray( - ak.index.Index64([0, 1, 2, 3, 5, 6]), - ak.contents.RecordArray( - [ - ak.contents.NumpyArray( - np.array( - [4.0, 3.0, 1.0, 4.0, 3.0, 1.0, 9.0, 8.0, 7.0], - dtype=np.int64, - ) + ak.contents.RecordArray( + [ + ak.contents.NumpyArray( + np.array( + [4.0, 3.0, 1.0, 4.0, 3.0, 1.0, 9.0, 8.0, 7.0], + dtype=np.int64, ) - ], - ["x"], - ), - parameters={"__array__": "categorical"}, + ) + ], + ["x"], ), ], - ) + ), ) # This test might be a bit strict; any code that views `layout.parameters` will change this result to `{}` assert result.contents[0]._parameters is None - assert result.contents[1]._parameters == { - "__array__": "categorical", - "outer": "foo", - } - assert result.contents[1].content._parameters == {"inner": "bar"} + assert result.contents[1].parameters == {"outer": "foo", "inner": "bar"} diff --git a/tests/test_2860_enforce_concatenated_form.py b/tests/test_2860_enforce_concatenated_form.py new file mode 100644 index 0000000000..27685f01fc --- /dev/null +++ b/tests/test_2860_enforce_concatenated_form.py @@ -0,0 +1,120 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import numpy as np +import pytest + +import awkward as ak +from awkward.operations.ak_concatenate import enforce_concatenated_form + +layouts = [ + # ListArray + ak.contents.ListArray( + ak.index.Index64([0, 3]), + ak.index.Index64([3, 6]), + ak.contents.NumpyArray(np.arange(6, dtype=np.int64)), + ), + # ListArray + ak.contents.ListOffsetArray( + ak.index.Index64([0, 3, 6]), + ak.contents.NumpyArray(np.arange(6, dtype=np.int64)), + ), + # RegularArray + ak.contents.RegularArray(ak.contents.NumpyArray(np.arange(6, dtype=np.int64)), 3), + ak.contents.RegularArray(ak.contents.NumpyArray(np.arange(6, dtype=np.int64)), 2), + # ByteMaskedArray + ak.contents.ByteMaskedArray( + ak.index.Index8([True, False, False, True]), + ak.contents.NumpyArray(np.arange(6, dtype=np.int32)), + valid_when=True, + ), + # ByteMaskedArray + ak.contents.BitMaskedArray( + ak.index.IndexU8([1 << 0 | 1 << 1 | 0 << 2 | 0 << 3 | 1 << 4 | 0 << 5]), + ak.contents.NumpyArray(np.arange(6, dtype=np.int32)), + valid_when=True, + lsb_order=True, + length=6, + ), + # UnmaskedArray + ak.contents.UnmaskedArray(ak.contents.NumpyArray(np.arange(6, dtype=np.int32))), + # IndexedOptionArray + ak.contents.IndexedOptionArray( + ak.index.Index64([3, 1, -1, -1, 2, 0, -1]), + ak.contents.NumpyArray(np.arange(6, dtype=np.int32)), + ), + # NumpyArray + ak.contents.NumpyArray(np.arange(6, dtype=np.int16)), + ak.contents.NumpyArray(np.arange(6 * 4, dtype=np.float32).reshape(6, 4)), + # IndexedArray + ak.contents.IndexedArray( + ak.index.Index64([3, 1, 1, 0, 2, 0, 0]), + ak.contents.NumpyArray(np.arange(6, dtype=np.int32)), + ), + # RecordArray + ak.contents.RecordArray( + [ak.contents.NumpyArray(np.arange(6, dtype=np.int16))], ["x"] + ), + ak.contents.RecordArray( + [ak.contents.NumpyArray(np.arange(6, dtype=np.float64))], ["y"] + ), + ak.contents.RecordArray( + [ak.contents.NumpyArray(np.arange(6, dtype=np.float32))], None + ), + # UnionArray + ak.contents.UnionArray( + ak.index.Index8([0, 0, 1]), + ak.index.Index64([0, 1, 0]), + [ + ak.contents.NumpyArray(np.arange(6, dtype=np.int16)), + ak.contents.RecordArray( + [ak.contents.NumpyArray(np.arange(6, dtype=np.float32))], None + ), + ], + ), +] + + +@pytest.mark.parametrize("left", layouts) +@pytest.mark.parametrize("right", layouts) +def test_symmetric(left, right): + result = ak.concatenate([left, right], axis=0, highlevel=False) + part_0_result = enforce_concatenated_form(left, result.form) + assert part_0_result.form == result.form + + part_1_result = enforce_concatenated_form(right, result.form) + assert part_1_result.form == result.form + + assert part_0_result.to_list() == result[: part_0_result.length].to_list() + assert part_1_result.to_list() == result[part_0_result.length :].to_list() + + +@pytest.mark.parametrize( + "left, right", + [ + ( + # IndexedOptionArray + ak.contents.IndexedOptionArray( + ak.index.Index64([3, 1, -1, -1, 2, 0, -1]), + ak.contents.NumpyArray(np.arange(6, dtype=np.int32)), + parameters={"__array__": "categorical"}, + ), + # NumpyArray + ak.contents.NumpyArray(np.arange(6, dtype=np.int64)), + ), + ], +) +def test_non_diagonal(left, right): + result = ak.concatenate([left, right], axis=0, highlevel=False) + part_0_result = enforce_concatenated_form(left, result.form) + assert part_0_result.form == result.form + + part_1_result = enforce_concatenated_form(right, result.form) + assert part_1_result.form == result.form + + assert part_0_result.to_list() == result[: part_0_result.length].to_list() + assert part_1_result.to_list() == result[part_0_result.length :].to_list() + + +# def test_union_