Manually fuse reindexing intermediates with blockwise reduction for cohorts. #300

Merged · 2 commits · May 2, 2024
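The core change: instead of subsetting blocks and then running a separate `dask.array.map_blocks(reindex_intermediates, ...)` per cohort, the reindexing function is spliced directly into the block-subsetting graph layer, so each task becomes `(reindexer, source_key)`. A minimal sketch of that fusion idea, using a toy `double` function as a stand-in for flox's `reindex_intermediates` (all names here are illustrative, not flox code):

```python
import dask.array as da
import numpy as np
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph

def double(x):  # stand-in for the per-block reindexing function
    return 2 * x

arr = da.ones((4,), chunks=(2,))

# Pre-PR pattern: a separate map_blocks layer on top of the source array.
unfused = da.map_blocks(double, arr)

# Fused pattern: one layer whose tasks apply `double` to the source blocks.
name = "fused-double-" + tokenize(arr)
layer = {(name, i): (double, (arr.name, i)) for i in range(arr.numblocks[0])}
graph = HighLevelGraph.from_collections(name, layer, dependencies=[arr])
fused = da.Array(graph, name, arr.chunks, meta=arr._meta)

# Both produce the same values; the fused version has one fewer layer.
assert np.array_equal(fused.compute(), unfused.compute())
```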
12 changes: 6 additions & 6 deletions asv_bench/benchmarks/cohorts.py
@@ -14,8 +14,8 @@ def setup(self, *args, **kwargs):
         raise NotImplementedError

     @cached_property
-    def dask(self):
-        return flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)[0].dask
+    def result(self):
+        return flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)[0]

     def containment(self):
         asfloat = self.bitmask().astype(float)
@@ -52,14 +52,14 @@ def time_graph_construct(self):
         flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)

     def track_num_tasks(self):
-        return len(self.dask.to_dict())
+        return len(self.result.dask.to_dict())

     def track_num_tasks_optimized(self):
-        (opt,) = dask.optimize(self.dask)
-        return len(opt.to_dict())
+        (opt,) = dask.optimize(self.result)
+        return len(opt.dask.to_dict())

     def track_num_layers(self):
-        return len(self.dask.layers)
+        return len(self.result.dask.layers)

     track_num_tasks.unit = "tasks"  # type: ignore[attr-defined] # Lazy
     track_num_tasks_optimized.unit = "tasks"  # type: ignore[attr-defined] # Lazy
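The benchmark now caches the result array rather than its graph, because `dask.optimize` operates on collections, not on bare graphs. For context on what these benchmarks measure, here is a sketch with a toy reduction standing in for `flox.groupby_reduce`:

```python
import dask
import dask.array as da

result = da.ones((100,), chunks=(10,)).sum()

num_tasks = len(result.dask.to_dict())  # tasks before optimization
num_layers = len(result.dask.layers)   # HighLevelGraph layers

(opt,) = dask.optimize(result)                 # optimize the collection...
num_tasks_optimized = len(opt.dask.to_dict())  # ...then count tasks again
```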
32 changes: 21 additions & 11 deletions flox/core.py
@@ -17,6 +17,7 @@
     Callable,
     Literal,
     TypedDict,
+    TypeVar,
     Union,
     overload,
 )
@@ -96,6 +97,7 @@
 T_MethodOpt = None | Literal["map-reduce", "blockwise", "cohorts"]
 T_IsBins = Union[bool | Sequence[bool]]

+T = TypeVar("T")

 IntermediateDict = dict[Union[str, Callable], Any]
 FinalResultsDict = dict[str, Union["DaskArray", "CubedArray", np.ndarray]]
@@ -140,6 +142,10 @@ def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
     return result


+def identity(x: T) -> T:
+    return x
+
+
 def _issorted(arr: np.ndarray) -> bool:
     return bool((arr[:-1] <= arr[1:]).all())
@@ -1438,7 +1444,10 @@ def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple:


 def subset_to_blocks(
-    array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None
+    array: DaskArray,
+    flatblocks: Sequence[int],
+    blkshape: tuple[int] | None = None,
+    reindexer=identity,
 ) -> DaskArray:
     """
     Advanced indexing of .blocks such that we always get a regular array back.
@@ -1464,20 +1473,21 @@
     index = _normalize_indexes(array, flatblocks, blkshape)

     if all(not isinstance(i, np.ndarray) and i == slice(None) for i in index):
-        return array
+        return dask.array.map_blocks(reindexer, array, meta=array._meta)

     # The rest is copied from dask.array.core.py with slight modifications
     index = normalize_index(index, array.numblocks)
     index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index)

-    name = "blocks-" + tokenize(array, index)
+    name = "groupby-cohort-" + tokenize(array, index)
     new_keys = array._key_array[index]

     squeezed = tuple(np.squeeze(i) if isinstance(i, np.ndarray) else i for i in index)
     chunks = tuple(tuple(np.array(c)[i].tolist()) for c, i in zip(array.chunks, squeezed))

     keys = itertools.product(*(range(len(c)) for c in chunks))
-    layer: Graph = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}
+    layer: Graph = {(name,) + key: (reindexer, tuple(new_keys[key].tolist())) for key in keys}

     graph = HighLevelGraph.from_collections(name, layer, dependencies=[array])

     return dask.array.Array(graph, name, chunks, meta=array)
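A usage sketch of the new signature, mirroring the test at the bottom of this PR (it assumes only what the diff shows): with the default `identity` reindexer and a full slice, the function degenerates to `map_blocks(identity, array)`; with a genuine subset, each emitted task applies the reindexer inline.

```python
import dask.array as da
import numpy as np
from flox.core import identity, subset_to_blocks

array = da.ones((5, 5), chunks=1)

# Full slice: passthrough, same graph as a plain identity map_blocks.
subset = subset_to_blocks(array, np.arange(25))
assert subset.name == da.map_blocks(identity, array).name

# Partial selection (the first row of blocks): the "groupby-cohort-*"
# layer runs the reindexer inside each task, with no extra layer.
row = subset_to_blocks(array, np.arange(5), array.numblocks)
assert row.name.startswith("groupby-cohort-")
```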
@@ -1651,26 +1661,26 @@ def dask_groupby_agg(

     elif method == "cohorts":
         assert chunks_cohorts
+        block_shape = array.blocks.shape[-len(axis) :]
+
         reduced_ = []
         groups_ = []
         for blks, cohort in chunks_cohorts.items():
-            index = pd.Index(cohort)
-            subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
-            reindexed = dask.array.map_blocks(
-                reindex_intermediates, subset, agg, index, meta=subset._meta
-            )
+            cohort_index = pd.Index(cohort)
+            reindexer = partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
+            reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer)
             # now that we have reindexed, we can set reindex=True explicitly
             reduced_.append(
                 tree_reduce(
                     reindexed,
                     combine=partial(combine, agg=agg, reindex=True),
-                    aggregate=partial(aggregate, expected_groups=index, reindex=True),
+                    aggregate=partial(aggregate, expected_groups=cohort_index, reindex=True),
                 )
             )
             # This is done because pandas promotes to 64-bit types when an Index is created.
             # So we use the index to generate the return value for consistency with "map-reduce".
             # This is important on Windows.
-            groups_.append(index.values)
+            groups_.append(cohort_index.values)

         reduced = dask.array.concatenate(reduced_, axis=-1)
         groups = (np.concatenate(groups_),)
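The comment about pandas promoting to 64-bit types refers to pandas' historical `Index` behavior: before pandas 2.0, integer indexes only supported 64-bit dtypes (the version detail is my addition, not the PR's). Building the returned group labels from `cohort_index.values` keeps dtypes consistent with the "map-reduce" path, which matters on platforms where the default integer is 32-bit, such as Windows.

```python
import numpy as np
import pandas as pd

labels = np.array([1, 2, 3], dtype=np.int32)
idx = pd.Index(labels)

# On pandas < 2.0 this prints "int64 int64": the Index promoted int32.
# Using idx.values everywhere keeps group labels consistent regardless.
print(idx.dtype, idx.values.dtype)
```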
8 changes: 6 additions & 2 deletions tests/test_core.py
@@ -1465,14 +1465,18 @@ def test_normalize_block_indexing_2d(flatblocks, expected):

 @requires_dask
 def test_subset_block_passthrough():
+    from flox.core import identity
+
     # full slice pass through
     array = dask.array.ones((5,), chunks=(1,))
+    expected = dask.array.map_blocks(identity, array)
     subset = subset_to_blocks(array, np.arange(5))
-    assert subset.name == array.name
+    assert subset.name == expected.name

     array = dask.array.ones((5, 5), chunks=1)
+    expected = dask.array.map_blocks(identity, array)
     subset = subset_to_blocks(array, np.arange(25))
-    assert subset.name == array.name
+    assert subset.name == expected.name


 @requires_dask
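These tests compare `.name` attributes rather than computed values. That works because dask names are deterministic tokens of the operation and its inputs (a general dask property, not something this PR adds), so equal names imply the passthrough branch built the same graph as a plain `map_blocks`:

```python
import dask.array as da

def noop(x):
    return x

a = da.ones((5,), chunks=(1,))

# Rebuilding the same operation on the same input yields the same name,
# so name equality is a cheap proxy for graph equality.
assert da.map_blocks(noop, a).name == da.map_blocks(noop, a).name
```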