From 21eca160a5bf4032167bd973e9789420266b6c70 Mon Sep 17 00:00:00 2001
From: Martin Durant
Date: Wed, 12 Jun 2024 21:30:50 -0400
Subject: [PATCH] Make simple types of repartition

---
Review revision of this patch: simple_repartition_layer now numbers the
n_to_one output partitions consecutively and keeps the final division when
npartitions is not a multiple of n; the repartition token is derived from
the resolved divisions; an n_to_one test is added. Hunk line counts and the
diffstat are updated to match. The post-image blob ids on the "index" lines
still refer to the original revision of the patch.

 src/dask_awkward/lib/core.py      | 88 +++++++++++++++++++++++++-------
 src/dask_awkward/lib/structure.py | 60 ++++++++++++++++++++++
 tests/test_structure.py           | 12 ++++
 3 files changed, 141 insertions(+), 19 deletions(-)

diff --git a/src/dask_awkward/lib/core.py b/src/dask_awkward/lib/core.py
index 40898b11..204ccc3f 100644
--- a/src/dask_awkward/lib/core.py
+++ b/src/dask_awkward/lib/core.py
@@ -942,29 +942,79 @@ def repartition(
         npartitions: int | None = None,
         divisions: tuple[int, ...] | None = None,
         rows_per_partition: int | None = None,
+        one_to_n: int | None = None,
+        n_to_one: int | None = None,
     ) -> Array:
+        """Restructure the partitioning of the whole array
+
+        Various schemes are possible, with one of the mutually exclusive
+        optional arguments for each. Of these, the first three require
+        knowledge of the number of rows in each existing partition, which
+        will be eagerly computed if not already known, and some shuffling of
+        data between partitions.
+
+        - npartitions: split all the rows as evenly as possible into this
+          many output partitions.
+        - divisions: exact row count offsets of each output partition
+        - rows_per_partition: each partition will have this many rows,
+          except the last, which will have this number or fewer
+        - one_to_n: each input partition becomes n output partitions
+        - n_to_one: every n adjacent input partitions become one
+          output partition. Note that exactly one output partition
+          (npartitions=1) is a special case of this.
+
+        Raises ValueError unless exactly one of the arguments is given.
+        """
         from dask_awkward.layers import AwkwardMaterializedLayer
-        from dask_awkward.lib.structure import repartition_layer
+        from dask_awkward.lib.structure import (
+            repartition_layer,
+            simple_repartition_layer,
+        )
 
-        if sum(bool(_) for _ in [npartitions, divisions, rows_per_partition]) != 1:
+        if (
+            sum(
+                bool(_)
+                for _ in (
+                    npartitions,
+                    divisions,
+                    rows_per_partition,
+                    one_to_n,
+                    n_to_one,
+                )
+            )
+            != 1
+        ):
             raise ValueError("Please specify exactly one of the inputs")
-        if not self.known_divisions:
-            self.eager_compute_divisions()
-        nrows = self.defined_divisions[-1]
-        new_divisions: tuple[int, ...] = tuple()
-        if divisions:
-            new_divisions = divisions
-        elif npartitions:
-            rows_per_partition = math.ceil(nrows / npartitions)
-        if rows_per_partition:
-            new_divs = list(range(0, nrows, rows_per_partition))
-            new_divs.append(nrows)
-            new_divisions = tuple(new_divs)
-
-        token = tokenize(self, divisions)
-        key = f"repartition-{token}"
-
-        new_layer_raw = repartition_layer(self, key, new_divisions)
+        new_divisions: tuple[int, ...] = ()
+        if npartitions == 1:
+            npartitions, n_to_one = None, self.npartitions
+        if n_to_one or one_to_n:
+            token = tokenize(self, n_to_one, one_to_n)
+            key = f"repartition-{token}"
+            new_layer_raw, new_divisions = simple_repartition_layer(
+                self, n_to_one, one_to_n, key
+            )
+        else:
+            if not self.known_divisions:
+                self.eager_compute_divisions()
+            nrows = self.defined_divisions[-1]
+            if divisions:
+                if divisions == self.divisions:
+                    # noop
+                    return self
+                new_divisions = divisions
+            elif npartitions:
+                rows_per_partition = math.ceil(nrows / npartitions)
+            if rows_per_partition:
+                new_divs = list(range(0, nrows, rows_per_partition))
+                new_divs.append(nrows)
+                new_divisions = tuple(new_divs)
+            # hash the resolved divisions: npartitions= and rows_per_partition=
+            # leave ``divisions`` as None and would otherwise share one key
+            token = tokenize(self, new_divisions)
+            key = f"repartition-{token}"
+            new_layer_raw = repartition_layer(self, key, new_divisions)
+
         new_layer = AwkwardMaterializedLayer(
             new_layer_raw,
             previous_layer_names=[self.name],
diff --git a/src/dask_awkward/lib/structure.py b/src/dask_awkward/lib/structure.py
index ae5f104f..84b5ea8a 100644
--- a/src/dask_awkward/lib/structure.py
+++ b/src/dask_awkward/lib/structure.py
@@ -1350,6 +1350,66 @@ def repartition_layer(arr: Array, key: str, divisions: tuple[int, ...]) -> dict:
     return layer
 
 
+def _subpart(data: ak.Array, parts: int, part: int) -> ak.Array:
+    """Return the ``part``-th of ``parts`` contiguous row slices of ``data``."""
+    from dask_awkward.lib.core import is_typetracer
+
+    if is_typetracer(data):
+        # typetracer input is passed through without slicing
+        return data
+    rows_per = len(data) // parts
+    # the final slice is open-ended so it absorbs any remainder rows
+    return data[
+        part * rows_per : None if part == (parts - 1) else (part + 1) * rows_per
+    ]
+
+
+def _subcat(*arrs: ak.Array) -> ak.Array:
+    """Concatenate the given partitions, in order, into a single array."""
+    return ak.concatenate(arrs)
+
+
+def simple_repartition_layer(
+    arr: Array, n_to_one: int | None, one_to_n: int | None, key: str
+) -> tuple[dict, tuple[Any, ...]]:
+    """Build the task layer for a merge or split of whole partitions
+
+    Exactly one of ``n_to_one`` (concatenate every n adjacent input
+    partitions) or ``one_to_n`` (slice each input partition into n pieces)
+    should be given; if neither is truthy, ValueError is raised.
+
+    Returns the layer, keyed by ``(key, output partition index)``, and the
+    divisions of the output, which are all None in the one_to_n case.
+    """
+    layer: dict[tuple[str, int], tuple[Any, ...]] = {}
+    new_divisions: tuple[Any, ...]
+    if n_to_one:
+        # output partitions must be numbered consecutively from zero
+        for out, i in enumerate(range(0, arr.npartitions, n_to_one)):
+            layer[(key, out)] = (_subcat,) + tuple(
+                (arr.name, part)
+                for part in range(i, min(i + n_to_one, arr.npartitions))
+            )
+        new_divisions = arr.divisions[::n_to_one]
+        if arr.npartitions % n_to_one:
+            # the last, shorter group still ends at the final division
+            new_divisions += arr.divisions[-1:]
+    elif one_to_n:
+        for i in range(arr.npartitions):
+            for part in range(one_to_n):
+                layer[(key, (i * one_to_n + part))] = (
+                    _subpart,
+                    (arr.name, i),
+                    one_to_n,
+                    part,
+                )
+        # TODO: if arr.known_divisions:
+        new_divisions = (None,) * (arr.npartitions * one_to_n + 1)
+    else:
+        raise ValueError
+    return layer, new_divisions
+
+
 @borrow_docstring(ak.enforce_type)
 def enforce_type(
     array: Array,
diff --git a/tests/test_structure.py b/tests/test_structure.py
index d196ee35..3c9b1673 100644
--- a/tests/test_structure.py
+++ b/tests/test_structure.py
@@ -539,6 +539,18 @@ def test_repartition_whole(daa):
     assert_eq(daa, daa1, check_divisions=False)
 
 
+def test_repartition_one_to_n(daa):
+    daa1 = daa.repartition(one_to_n=2)
+    assert daa1.npartitions == daa.npartitions * 2
+    assert_eq(daa, daa1, check_divisions=False)
+
+
+def test_repartition_n_to_one(daa):
+    daa1 = daa.repartition(n_to_one=2)
+    assert daa1.npartitions == (daa.npartitions + 1) // 2
+    assert_eq(daa, daa1, check_divisions=False)
+
+
 def test_repartition_no_change(daa):
     daa1 = daa.repartition(divisions=(0, 5, 10, 15))
     assert daa1.npartitions == 3