Manually fuse reindexing intermediates with blockwise reduction for cohorts.
dcherian committed May 2, 2024
1 parent 627bf2b commit aeb1f9e
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions flox/core.py
@@ -17,6 +17,7 @@
     Callable,
     Literal,
     TypedDict,
+    TypeVar,
     Union,
     overload,
 )
@@ -96,6 +97,7 @@
 T_MethodOpt = None | Literal["map-reduce", "blockwise", "cohorts"]
 T_IsBins = Union[bool | Sequence[bool]]

+T = TypeVar("T")

 IntermediateDict = dict[Union[str, Callable], Any]
 FinalResultsDict = dict[str, Union["DaskArray", "CubedArray", np.ndarray]]
@@ -140,6 +142,10 @@ def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
     return result


+def identity(x: T) -> T:
+    return x
+
+
 def _issorted(arr: np.ndarray) -> bool:
     return bool((arr[:-1] <= arr[1:]).all())

@@ -1438,8 +1444,11 @@ def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple:


 def subset_to_blocks(
-    array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None
-) -> DaskArray:
+    array: DaskArray,
+    flatblocks: Sequence[int],
+    blkshape: tuple[int] | None = None,
+    reindexer=identity,
+) -> Graph:
     """
     Advanced indexing of .blocks such that we always get a regular array back.
@@ -1464,20 +1473,21 @@ def subset_to_blocks(
     index = _normalize_indexes(array, flatblocks, blkshape)

     if all(not isinstance(i, np.ndarray) and i == slice(None) for i in index):
-        return array
+        return dask.array.map_blocks(reindexer, array, meta=array._meta)

     # These rest is copied from dask.array.core.py with slight modifications
     index = normalize_index(index, array.numblocks)
     index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index)

-    name = "blocks-" + tokenize(array, index)
+    name = "groupby-cohort-" + tokenize(array, index)
     new_keys = array._key_array[index]

     squeezed = tuple(np.squeeze(i) if isinstance(i, np.ndarray) else i for i in index)
     chunks = tuple(tuple(np.array(c)[i].tolist()) for c, i in zip(array.chunks, squeezed))

     keys = itertools.product(*(range(len(c)) for c in chunks))
-    layer: Graph = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}
+    layer: Graph = {(name,) + key: (reindexer, tuple(new_keys[key].tolist())) for key in keys}

     graph = HighLevelGraph.from_collections(name, layer, dependencies=[array])

     return dask.array.Array(graph, name, chunks, meta=array)
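As an aside, here is a minimal, hypothetical sketch of the mechanism the new reindexer argument uses (the helper subset_blocks_fused and the example values are illustrative stand-ins, not part of flox): each entry in the hand-built layer is a (func, input_key) task, so the function runs in the very task that picks out a block, instead of in a separate map_blocks layer. It assumes the classic tuple task spec that the diff above also relies on.

import numpy as np
import dask.array as da
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph


def subset_blocks_fused(arr, block_ids, func):
    # Select blocks of a 1-D dask array and apply `func` inside the same
    # task: each graph entry is (func, input_key), so no extra layer runs.
    name = "subset-fused-" + tokenize(arr, block_ids, func)
    layer = {(name, out): (func, (arr.name, in_)) for out, in_ in enumerate(block_ids)}
    chunks = (tuple(arr.chunks[0][i] for i in block_ids),)
    graph = HighLevelGraph.from_collections(name, layer, dependencies=[arr])
    return da.Array(graph, name, chunks, meta=arr._meta)


x = da.arange(12, chunks=3)                      # four blocks of length 3
y = subset_blocks_fused(x, [0, 2], np.negative)  # keep blocks 0 and 2, negate them
print(y.compute())                               # [ 0 -1 -2 -6 -7 -8]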
@@ -1651,26 +1661,26 @@ def dask_groupby_agg(

     elif method == "cohorts":
         assert chunks_cohorts
+        block_shape = array.blocks.shape[-len(axis) :]

         reduced_ = []
         groups_ = []
         for blks, cohort in chunks_cohorts.items():
-            index = pd.Index(cohort)
-            subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
-            reindexed = dask.array.map_blocks(
-                reindex_intermediates, subset, agg, index, meta=subset._meta
-            )
+            cohort_index = pd.Index(cohort)
+            reindexer = partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
+            reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer)
+            # now that we have reindexed, we can set reindex=True explicitly
             reduced_.append(
                 tree_reduce(
                     reindexed,
                     combine=partial(combine, agg=agg, reindex=True),
-                    aggregate=partial(aggregate, expected_groups=index, reindex=True),
+                    aggregate=partial(aggregate, expected_groups=cohort_index, reindex=True),
                 )
             )
             # This is done because pandas promotes to 64-bit types when an Index is created
             # So we use the index to generate the return value for consistency with "map-reduce"
             # This is important on windows
-            groups_.append(index.values)
+            groups_.append(cohort_index.values)

         reduced = dask.array.concatenate(reduced_, axis=-1)
         groups = (np.concatenate(groups_),)
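To round out the picture, a minimal, hypothetical sketch of what a per-cohort reindexer built with functools.partial accomplishes (reindex_block, the group labels, and the running-sum example are illustrative stand-ins, not flox's reindex_intermediates): each block's per-group intermediate is expanded to the cohort's full group index, so the tree reduction can combine blocks positionally with reindex=True.

from functools import partial

import numpy as np
import pandas as pd


def reindex_block(groups, sums, *, unique_groups, fill_value=0):
    # Expand one block's (group label, partial sum) pairs to the full
    # cohort index, filling groups absent from this block with fill_value.
    out = np.full(len(unique_groups), fill_value, dtype=sums.dtype)
    out[unique_groups.get_indexer(groups)] = sums
    return out


cohort_index = pd.Index([10, 20, 30])
reindexer = partial(reindex_block, unique_groups=cohort_index)

a = reindexer(np.array([10, 30]), np.array([1.0, 4.0]))  # [1., 0., 4.]
b = reindexer(np.array([20, 30]), np.array([2.0, 5.0]))  # [0., 2., 5.]
print(a + b)  # combined per-cohort sums: [1. 2. 9.]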
