Skip to content

Commit

Permalink
fix: touching of unions
Browse files: browse the repository at this point in the history
  • Loading branch information
agoose77 committed Dec 16, 2023
1 parent accbce3 commit bed2688
Showing 1 changed file with 62 additions and 109 deletions.
171 changes: 62 additions & 109 deletions src/awkward/_broadcasting.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -773,121 +773,74 @@ def broadcast_any_union():
else:
nextparameters.append(NO_PARAMETERS)

if not backend.nplike.known_data:
# assert False
union_num_contents = []
length = None
for x in contents:
if x.is_union:
x._touch_data(recursive=False)
union_num_contents.append(len(x.contents))
if length is None:
length = x.length

all_combos = list(
itertools.product(*[range(x) for x in union_num_contents])
)

tags = backend.index_nplike.empty(length, dtype=np.int8)
index = backend.index_nplike.empty(length, dtype=np.int64)
numoutputs = None
outcontents = []
for combo in all_combos:
nextinputs = []
i = 0
for x in inputs:
if isinstance(x, UnionArray):
nextinputs.append(x._contents[combo[i]])
i += 1
else:
nextinputs.append(x)
assert len(nextinputs) == len(nextparameters)
outcontents.append(
apply_step(
backend,
nextinputs,
action,
depth,
copy.copy(depth_context),
lateral_context,
options,
union_tags, union_num_contents, length = [], [], unknown_length
for x in contents:
if x.is_union:
tags = x.tags.raw(backend.index_nplike)
union_tags.append(tags)
union_num_contents.append(len(x.contents))

if length is unknown_length:
length = tags.shape[0]
elif tags.shape[0] is unknown_length:
continue
elif length != tags.shape[0]:
raise ValueError(
"cannot broadcast UnionArray of length {} "
"with UnionArray of length {}{}".format(
length,
tags.shape[0],
in_function(options),
)
)
)
assert isinstance(outcontents[-1], tuple)
if numoutputs is not None:
assert numoutputs == len(outcontents[-1])
numoutputs = len(outcontents[-1])

assert numoutputs is not None
tags = backend.index_nplike.empty(length, dtype=np.int8)
index = backend.index_nplike.empty(length, dtype=np.int64)

else:
union_tags, union_num_contents, length = [], [], None
for x in contents:
if x.is_union:
tags = x.tags.raw(backend.index_nplike)
union_tags.append(tags)
union_num_contents.append(len(x.contents))
if tags.shape[0] is unknown_length:
continue

if length is None:
length = tags.shape[0]
elif length != tags.shape[0]:
raise ValueError(
"cannot broadcast UnionArray of length {} "
"with UnionArray of length {}{}".format(
length,
tags.shape[0],
in_function(options),
)
)
assert length is not unknown_length

# Stack all union tags
combos = backend.index_nplike.stack(union_tags, axis=-1)
# Build array of indices (c1, c2, c3, ..., cn) of contents in
# (union 1, union 2, union 3, ..., union n)
all_combos = backend.index_nplike.asarray(
list(itertools.product(*[range(x) for x in union_num_contents]))
)
# Stack all union tags
combos = backend.index_nplike.stack(union_tags, axis=-1)

tags = backend.index_nplike.empty(length, dtype=np.int8)
index = backend.index_nplike.empty(length, dtype=np.int64)
numoutputs, outcontents = None, []
for tag, combo in enumerate(all_combos):
mask = backend.index_nplike.all(combos == combo, axis=-1)
tags[mask] = tag
index[mask] = backend.index_nplike.arange(
backend.index_nplike.count_nonzero(mask), dtype=np.int64
)
nextinputs = []
i = 0
for x in inputs:
if isinstance(x, UnionArray):
nextinputs.append(x[mask].project(combo[i]))
i += 1
elif isinstance(x, Content):
nextinputs.append(x[mask])
else:
nextinputs.append(x)
outcontents.append(
apply_step(
backend,
nextinputs,
action,
depth,
copy.copy(depth_context),
lateral_context,
options,
)
)
assert isinstance(outcontents[-1], tuple)
if numoutputs is None:
numoutputs = len(outcontents[-1])
# Build array of indices (c1, c2, c3, ..., cn) of contents in
# (union 1, union 2, union 3, ..., union n)
all_combos = list(itertools.product(*[range(x) for x in union_num_contents]))

numoutputs = None
outcontents = []

for tag, j_contents in enumerate(all_combos):
combo = backend.index_nplike.asarray(j_contents, dtype=np.int64)
mask = backend.index_nplike.all(combos == combo, axis=-1)
tags[mask] = tag
index[mask] = backend.index_nplike.arange(
backend.index_nplike.count_nonzero(mask), dtype=np.int64
)
nextinputs = []
it_j_contents = iter(j_contents)
for x in inputs:
if isinstance(x, UnionArray):
nextinputs.append(x[mask].project(next(it_j_contents)))
elif isinstance(x, Content):
nextinputs.append(x[mask])
else:
assert numoutputs == len(outcontents[-1])
nextinputs.append(x)
outcontents.append(
apply_step(
backend,
nextinputs,
action,
depth,
copy.copy(depth_context),
lateral_context,
options,
)
)
assert isinstance(outcontents[-1], tuple)
if numoutputs is None:
numoutputs = len(outcontents[-1])
else:
assert numoutputs == len(outcontents[-1])

assert numoutputs is not None
assert numoutputs is not None

parameters = parameters_factory(nextparameters, numoutputs)

Expand Down

0 comments on commit bed2688

Please sign in to comment.