Skip to content

Commit

Permalink
Merge pull request #517 from martindurant/n_to_one
Browse files Browse the repository at this point in the history
fix: n_to_one repartition
  • Loading branch information
martindurant authored Jun 19, 2024
2 parents 1c4b97a + 5a0af7b commit 38c065b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/dask_awkward/lib/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1412,12 +1412,17 @@ def simple_repartition_layer(
layer: dict[tuple[str, int], tuple[Any, ...]] = {}
new_divisions: tuple[Any, ...]
if n_to_one:
for i in range(0, arr.npartitions, n_to_one):
layer[(key, i)] = (_subcat,) + tuple(
for i0, i in enumerate(range(0, arr.npartitions, n_to_one)):
layer[(key, i0)] = (_subcat,) + tuple(
(arr.name, part)
for part in range(i, min(i + n_to_one, arr.npartitions))
)
new_divisions = arr.divisions[::n_to_one]
if arr.npartitions % n_to_one:
new_divisions = new_divisions + (arr.divisions[-1],)
layer[(key, i0 + 1)] = (_subcat,) + tuple(
(arr.name, part) for part in range(new_divisions[-2], new_divisions[-1])
)
elif one_to_n:
for i in range(arr.npartitions):
for part in range(one_to_n):
Expand Down
16 changes: 16 additions & 0 deletions tests/test_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,22 @@ def test_repartition_one_to_n(daa):
assert_eq(daa, daa1, check_divisions=False)


def test_repartition_n_to_one():
daa = dak.from_lists([[[1, 2, 3], [], [4, 5]]] * 52)
daa2 = daa.repartition(n_to_one=52)
assert daa2.npartitions == 1
assert daa.compute().to_list() == daa2.compute().to_list()
daa2 = daa.repartition(n_to_one=53)
assert daa2.npartitions == 1
assert daa.compute().to_list() == daa2.compute().to_list()
daa2 = daa.repartition(n_to_one=2)
assert daa2.npartitions == 26
assert daa.compute().to_list() == daa2.compute().to_list()
daa2 = daa.repartition(n_to_one=10)
assert daa2.npartitions == 6
assert daa.compute().to_list() == daa2.compute().to_list()


def test_repartition_no_change(daa):
daa1 = daa.repartition(divisions=(0, 5, 10, 15))
assert daa1.npartitions == 3
Expand Down

0 comments on commit 38c065b

Please sign in to comment.