Skip to content

Commit

Permalink
Manually fuse reindexing intermediates with blockwise reduction for c…
Browse files Browse the repository at this point in the history
…ohorts. (#300)

* Manually fuse reindexing intermediates with blockwise reduction for cohorts.

```
| Change   | Before [627bf2b] <main>   | After [9d710529] <optimize-cohorts-graph>   |   Ratio | Benchmark (Parameter)                           |
|----------|----------------------------|---------------------------------------------|---------|-------------------------------------------------|
| -        | 3.39±0.02ms                | 2.98±0.01ms                                 |    0.88 | cohorts.PerfectMonthly.time_graph_construct     |
| -        | 20                         | 17                                          |    0.85 | cohorts.PerfectMonthly.track_num_layers         |
| -        | 23.0±0.07ms                | 19.0±0.1ms                                  |    0.83 | cohorts.ERA5Google.time_graph_construct         |
| -        | 4878                       | 3978                                        |    0.82 | cohorts.ERA5Google.track_num_tasks              |
| -        | 179±0.8ms                  | 147±0.5ms                                   |    0.82 | cohorts.OISST.time_graph_construct              |
| -        | 159                        | 128                                         |    0.81 | cohorts.ERA5Google.track_num_layers             |
| -        | 936                        | 762                                         |    0.81 | cohorts.PerfectMonthly.track_num_tasks          |
| -        | 1221                       | 978                                         |    0.8  | cohorts.OISST.track_num_layers                  |
| -        | 4929                       | 3834                                        |    0.78 | cohorts.ERA5DayOfYear.track_num_tasks           |
| -        | 351                        | 274                                         |    0.78 | cohorts.NWMMidwest.track_num_layers             |
| -        | 4562                       | 3468                                        |    0.76 | cohorts.ERA5DayOfYear.track_num_tasks_optimized |
| -        | 164±1ms                    | 118±0.4ms                                   |    0.72 | cohorts.ERA5DayOfYear.time_graph_construct      |
| -        | 1100                       | 735                                         |    0.67 | cohorts.ERA5DayOfYear.track_num_layers          |
| -        | 3930                       | 2605                                        |    0.66 | cohorts.NWMMidwest.track_num_tasks              |
| -        | 3715                       | 2409                                        |    0.65 | cohorts.NWMMidwest.track_num_tasks_optimized    |
| -        | 28952                      | 18798                                       |    0.65 | cohorts.OISST.track_num_tasks                   |
| -        | 27010                      | 16858                                       |    0.62 | cohorts.OISST.track_num_tasks_optimized         |
```

* fix typing
  • Loading branch information
dcherian authored May 2, 2024
1 parent 2439c5c commit eb3c0ef
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 19 deletions.
12 changes: 6 additions & 6 deletions asv_bench/benchmarks/cohorts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def setup(self, *args, **kwargs):
raise NotImplementedError

@cached_property
def dask(self):
return flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)[0].dask
def result(self):
return flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)[0]

def containment(self):
asfloat = self.bitmask().astype(float)
Expand Down Expand Up @@ -52,14 +52,14 @@ def time_graph_construct(self):
flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)

def track_num_tasks(self):
return len(self.dask.to_dict())
return len(self.result.dask.to_dict())

def track_num_tasks_optimized(self):
(opt,) = dask.optimize(self.dask)
return len(opt.to_dict())
(opt,) = dask.optimize(self.result)
return len(opt.dask.to_dict())

def track_num_layers(self):
return len(self.dask.layers)
return len(self.result.dask.layers)

track_num_tasks.unit = "tasks" # type: ignore[attr-defined] # Lazy
track_num_tasks_optimized.unit = "tasks" # type: ignore[attr-defined] # Lazy
Expand Down
32 changes: 21 additions & 11 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Callable,
Literal,
TypedDict,
TypeVar,
Union,
overload,
)
Expand Down Expand Up @@ -96,6 +97,7 @@
T_MethodOpt = None | Literal["map-reduce", "blockwise", "cohorts"]
T_IsBins = Union[bool | Sequence[bool]]

T = TypeVar("T")

IntermediateDict = dict[Union[str, Callable], Any]
FinalResultsDict = dict[str, Union["DaskArray", "CubedArray", np.ndarray]]
Expand Down Expand Up @@ -140,6 +142,10 @@ def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
return result


def identity(x: T) -> T:
return x


def _issorted(arr: np.ndarray) -> bool:
return bool((arr[:-1] <= arr[1:]).all())

Expand Down Expand Up @@ -1438,7 +1444,10 @@ def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple:


def subset_to_blocks(
array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None
array: DaskArray,
flatblocks: Sequence[int],
blkshape: tuple[int] | None = None,
reindexer=identity,
) -> DaskArray:
"""
Advanced indexing of .blocks such that we always get a regular array back.
Expand All @@ -1464,20 +1473,21 @@ def subset_to_blocks(
index = _normalize_indexes(array, flatblocks, blkshape)

if all(not isinstance(i, np.ndarray) and i == slice(None) for i in index):
return array
return dask.array.map_blocks(reindexer, array, meta=array._meta)

# These rest is copied from dask.array.core.py with slight modifications
index = normalize_index(index, array.numblocks)
index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index)

name = "blocks-" + tokenize(array, index)
name = "groupby-cohort-" + tokenize(array, index)
new_keys = array._key_array[index]

squeezed = tuple(np.squeeze(i) if isinstance(i, np.ndarray) else i for i in index)
chunks = tuple(tuple(np.array(c)[i].tolist()) for c, i in zip(array.chunks, squeezed))

keys = itertools.product(*(range(len(c)) for c in chunks))
layer: Graph = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}
layer: Graph = {(name,) + key: (reindexer, tuple(new_keys[key].tolist())) for key in keys}

graph = HighLevelGraph.from_collections(name, layer, dependencies=[array])

return dask.array.Array(graph, name, chunks, meta=array)
Expand Down Expand Up @@ -1651,26 +1661,26 @@ def dask_groupby_agg(

elif method == "cohorts":
assert chunks_cohorts
block_shape = array.blocks.shape[-len(axis) :]

reduced_ = []
groups_ = []
for blks, cohort in chunks_cohorts.items():
index = pd.Index(cohort)
subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
reindexed = dask.array.map_blocks(
reindex_intermediates, subset, agg, index, meta=subset._meta
)
cohort_index = pd.Index(cohort)
reindexer = partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer)
# now that we have reindexed, we can set reindex=True explicitlly
reduced_.append(
tree_reduce(
reindexed,
combine=partial(combine, agg=agg, reindex=True),
aggregate=partial(aggregate, expected_groups=index, reindex=True),
aggregate=partial(aggregate, expected_groups=cohort_index, reindex=True),
)
)
# This is done because pandas promotes to 64-bit types when an Index is created
# So we use the index to generate the return value for consistency with "map-reduce"
# This is important on windows
groups_.append(index.values)
groups_.append(cohort_index.values)

reduced = dask.array.concatenate(reduced_, axis=-1)
groups = (np.concatenate(groups_),)
Expand Down
8 changes: 6 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1465,14 +1465,18 @@ def test_normalize_block_indexing_2d(flatblocks, expected):

@requires_dask
def test_subset_block_passthrough():
from flox.core import identity

# full slice pass through
array = dask.array.ones((5,), chunks=(1,))
expected = dask.array.map_blocks(identity, array)
subset = subset_to_blocks(array, np.arange(5))
assert subset.name == array.name
assert subset.name == expected.name

array = dask.array.ones((5, 5), chunks=1)
expected = dask.array.map_blocks(identity, array)
subset = subset_to_blocks(array, np.arange(25))
assert subset.name == array.name
assert subset.name == expected.name


@requires_dask
Expand Down

0 comments on commit eb3c0ef

Please sign in to comment.