Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Dec 27, 2023
1 parent 20b662a commit fa65721
Showing 1 changed file with 19 additions and 12 deletions.
31 changes: 19 additions & 12 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def slices_from_chunks(chunks):
return product(*slices)


@memoize
# @memoize
def find_group_cohorts(labels, chunks, merge: bool = True) -> dict:
"""
Finds groups labels that occur together aka "cohorts"
Expand Down Expand Up @@ -1330,8 +1330,11 @@ def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple:


def subset_to_blocks(
array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None
) -> DaskArray:
array: DaskArray,
flatblocks: Sequence[int],
blkshape: tuple[int] | None = None,
reindexer=None,
) -> Graph:
"""
Advanced indexing of .blocks such that we always get a regular array back.
Expand Down Expand Up @@ -1362,14 +1365,18 @@ def subset_to_blocks(
index = normalize_index(index, array.numblocks)
index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index)

name = "blocks-" + tokenize(array, index)
name = "groupby-cohort-" + tokenize(array, index)
new_keys = array._key_array[index]

squeezed = tuple(np.squeeze(i) if isinstance(i, np.ndarray) else i for i in index)
chunks = tuple(tuple(np.array(c)[i].tolist()) for c, i in zip(array.chunks, squeezed))

keys = itertools.product(*(range(len(c)) for c in chunks))
layer: Graph = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}
if reindexer is None:
layer: Graph = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}
else:
layer: Graph = {(name,) + key: (reindexer, tuple(new_keys[key].tolist())) for key in keys}

graph = HighLevelGraph.from_collections(name, layer, dependencies=[array])

return dask.array.Array(graph, name, chunks, meta=array)
Expand Down Expand Up @@ -1551,26 +1558,26 @@ def dask_groupby_agg(
chunks_cohorts = find_group_cohorts(
by_input, [array.chunks[ax] for ax in axis], merge=True
)
block_shape = array.blocks.shape[-len(axis) :]

reduced_ = []
groups_ = []
for blks, cohort in chunks_cohorts.items():
index = pd.Index(cohort)
subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
reindexed = dask.array.map_blocks(
reindex_intermediates, subset, agg, index, meta=subset._meta
)
cohort_index = pd.Index(cohort)
reindexer = partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer)
# now that we have reindexed, we can set reindex=True explicitlly
reduced_.append(
tree_reduce(
reindexed,
combine=partial(combine, agg=agg, reindex=True),
aggregate=partial(aggregate, expected_groups=index, reindex=True),
aggregate=partial(aggregate, expected_groups=cohort_index, reindex=True),
)
)
# This is done because pandas promotes to 64-bit types when an Index is created
# So we use the index to generate the return value for consistency with "map-reduce"
# This is important on windows
groups_.append(index.values)
groups_.append(cohort_index.values)

reduced = dask.array.concatenate(reduced_, axis=-1)
groups = (np.concatenate(groups_),)
Expand Down

0 comments on commit fa65721

Please sign in to comment.