Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GroupBy.shuffle_to_chunks() #9320

Merged
merged 61 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
3bc51bd
Add GroupBy.shuffle()
dcherian Aug 7, 2024
60d7619
Cleanup
dcherian Aug 7, 2024
d1429cd
Cleanup
dcherian Aug 7, 2024
31fc00e
fix
dcherian Aug 7, 2024
4583853
return groupby instance from shuffle
dcherian Aug 13, 2024
abd9dd2
Fix nD by
dcherian Aug 13, 2024
6b820aa
Merge branch 'main' into groupby-shuffle
dcherian Aug 14, 2024
0d70656
Skip if no dask
dcherian Aug 14, 2024
fafb937
fix tests
dcherian Aug 14, 2024
939db9a
Merge branch 'main' into groupby-shuffle
dcherian Aug 14, 2024
a08450e
Add `chunks` to signature
dcherian Aug 14, 2024
d0cd218
FIx self
dcherian Aug 14, 2024
4edc976
Another Self fix
dcherian Aug 14, 2024
0b42be4
Forward chunks too
dcherian Aug 14, 2024
c52734d
[revert]
dcherian Aug 14, 2024
8180625
undo flox limit
dcherian Aug 14, 2024
7897c91
[revert]
dcherian Aug 14, 2024
7773548
fix types
dcherian Aug 14, 2024
51a7723
Add DataArray.shuffle_by, Dataset.shuffle_by
dcherian Aug 15, 2024
cc95513
Add doctest
dcherian Aug 15, 2024
18f4a40
Refactor
dcherian Aug 15, 2024
f489bcf
tweak docstrings
dcherian Aug 15, 2024
ead1bb4
fix typing
dcherian Aug 15, 2024
75115d0
Fix
dcherian Aug 15, 2024
390863a
fix docstring
dcherian Aug 15, 2024
a408cb0
bump min version to dask>=2024.08.1
dcherian Aug 17, 2024
7038f37
Merge branch 'main' into groupby-shuffle
dcherian Aug 17, 2024
05a0fb4
Fix typing
dcherian Aug 17, 2024
b8e7f62
Fix types
dcherian Aug 17, 2024
6d9ed1c
Merge branch 'main' into groupby-shuffle
dcherian Aug 22, 2024
20a8cd9
Merge branch 'main' into groupby-shuffle
dcherian Aug 30, 2024
7a99c8f
remove shuffle_by for now.
dcherian Aug 30, 2024
5e2fdfb
Add tests
dcherian Aug 30, 2024
a22c7ed
Support shuffling with multiple groupers
dcherian Aug 30, 2024
2d48690
Revert "remove shuffle_by for now."
dcherian Sep 11, 2024
0679d2b
Merge branch 'main' into groupby-shuffle
dcherian Sep 12, 2024
63b3e77
Merge branch 'main' into groupby-shuffle
dcherian Sep 17, 2024
7dc5dd1
bad merge
dcherian Sep 17, 2024
bad0744
Merge branch 'main' into groupby-shuffle
dcherian Sep 18, 2024
91e4bd8
Add a test
dcherian Sep 18, 2024
0542944
Merge branch 'main' into groupby-shuffle
dcherian Nov 1, 2024
1e4f805
Add docs
dcherian Nov 1, 2024
ad502aa
bugfix
dcherian Nov 1, 2024
4b0c143
Refactor out Dataset._shuffle
dcherian Nov 2, 2024
2b2c4ab
Merge branch 'main' into groupby-shuffle
dcherian Nov 3, 2024
f624c8f
fix types
dcherian Nov 3, 2024
888e780
fix tests
dcherian Nov 3, 2024
47e5c17
Merge branch 'main' into groupby-shuffle
dcherian Nov 4, 2024
b100fb1
Handle by is chunked
dcherian Nov 4, 2024
978fad9
Merge branch 'main' into groupby-shuffle
dcherian Nov 7, 2024
d1a3fc1
Some refactoring
dcherian Nov 7, 2024
23b0cac
Merge branch 'main' into groupby-shuffle
dcherian Nov 12, 2024
d533638
Merge branch 'main' into groupby-shuffle
dcherian Nov 19, 2024
d467bc6
Remove shuffle_by
dcherian Nov 19, 2024
231533c
shuffle -> distributed_shuffle
dcherian Nov 19, 2024
c77d7c5
return xarray object from distributed_shuffle
dcherian Nov 19, 2024
bccacfe
fix
dcherian Nov 19, 2024
2d4392a
fix doctest
dcherian Nov 19, 2024
003e9f2
fix api
dcherian Nov 19, 2024
0f80c81
Rename to `shuffle_to_chunks`
dcherian Nov 20, 2024
88bef5d
update docs
dcherian Nov 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,7 @@ Dataset
DatasetGroupBy.var
DatasetGroupBy.dims
DatasetGroupBy.groups
DatasetGroupBy.shuffle_to_chunks

DataArray
---------
Expand Down Expand Up @@ -1241,6 +1242,7 @@ DataArray
DataArrayGroupBy.var
DataArrayGroupBy.dims
DataArrayGroupBy.groups
DataArrayGroupBy.shuffle_to_chunks

Grouper Objects
---------------
Expand Down
21 changes: 21 additions & 0 deletions doc/user-guide/groupby.rst
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,24 @@ Different groupers can be combined to construct sophisticated GroupBy operations
from xarray.groupers import BinGrouper

ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum()


Shuffling
~~~~~~~~~

Shuffling is a generalization of sorting a DataArray or Dataset by another DataArray, named ``label`` for example, that follows from the idea of grouping by ``label``.
Shuffling reorders the DataArray or the DataArrays in a Dataset such that all members of a group occur sequentially. For example,
Shuffle the object using either :py:class:`DatasetGroupBy` or :py:class:`DataArrayGroupBy` as appropriate.

.. ipython:: python

da = xr.DataArray(
dims="x",
data=[1, 2, 3, 4, 5, 6],
coords={"label": ("x", "a b c a b c".split(" "))},
)
da.groupby("label").shuffle_to_chunks()


For chunked array types (e.g. dask or cubed), shuffle may result in a more optimized communication pattern when compared to direct indexing by the appropriate indexer.
Shuffling also makes GroupBy operations on chunked arrays an embarrassingly parallel problem, and may significantly improve workloads that use :py:meth:`DatasetGroupBy.map` or :py:meth:`DataArrayGroupBy.map`.
8 changes: 8 additions & 0 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
Bins,
DaCompatible,
NetcdfWriteModes,
T_Chunks,
T_DataArray,
T_DataArrayOrSet,
ZarrWriteModes,
Expand Down Expand Up @@ -105,6 +106,7 @@
Dims,
ErrorOptions,
ErrorOptionsWithWarn,
GroupIndices,
GroupInput,
InterpOptions,
PadModeOptions,
Expand Down Expand Up @@ -1687,6 +1689,12 @@ def sel(
)
return self._from_temp_dataset(ds)

def _shuffle(
self, dim: Hashable, *, indices: GroupIndices, chunks: T_Chunks
) -> Self:
ds = self._to_temp_dataset()._shuffle(dim=dim, indices=indices, chunks=chunks)
return self._from_temp_dataset(ds)

def head(
self,
indexers: Mapping[Any, int] | int | None = None,
Expand Down
34 changes: 34 additions & 0 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
DsCompatible,
ErrorOptions,
ErrorOptionsWithWarn,
GroupIndices,
GroupInput,
InterpOptions,
JoinOptions,
Expand All @@ -166,6 +167,7 @@
ResampleCompatible,
SideOptions,
T_ChunkDimFreq,
T_Chunks,
T_DatasetPadConstantValues,
T_Xarray,
)
Expand Down Expand Up @@ -3237,6 +3239,38 @@ def sel(
result = self.isel(indexers=query_results.dim_indexers, drop=drop)
return result._overwrite_indexes(*query_results.as_tuple()[1:])

def _shuffle(self, dim, *, indices: GroupIndices, chunks: T_Chunks) -> Self:
# Shuffling is only different from `isel` for chunked arrays.
# Extract them out, and treat them specially. The rest, we route through isel.
# This makes it easy to ensure correct handling of indexes.
is_chunked = {
name: var
for name, var in self._variables.items()
if is_chunked_array(var._data)
}
subset = self[[name for name in self._variables if name not in is_chunked]]

no_slices: list[list[int]] = [
list(range(*idx.indices(self.sizes[dim])))
if isinstance(idx, slice)
else idx
for idx in indices
]
no_slices = [idx for idx in no_slices if idx]

shuffled = (
subset
if dim not in subset.dims
else subset.isel({dim: np.concatenate(no_slices)})
)
for name, var in is_chunked.items():
shuffled[name] = var._shuffle(
indices=no_slices,
dim=dim,
chunks=chunks,
)
return shuffled

def head(
self,
indexers: Mapping[Any, int] | int | None = None,
Expand Down
82 changes: 80 additions & 2 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,13 @@

from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import GroupIndex, GroupIndices, GroupInput, GroupKey
from xarray.core.types import (
GroupIndex,
GroupIndices,
GroupInput,
GroupKey,
T_Chunks,
)
from xarray.core.utils import Frozen
from xarray.groupers import EncodedGroups, Grouper

Expand Down Expand Up @@ -676,6 +682,76 @@ def sizes(self) -> Mapping[Hashable, int]:
self._sizes = self._obj.isel({self._group_dim: index}).sizes
return self._sizes

def shuffle_to_chunks(self, chunks: T_Chunks = None) -> T_Xarray:
"""
Sort or "shuffle" the underlying object.

"Shuffle" means the object is sorted so that all group members occur sequentially,
in the same chunk. Multiple groups may occur in the same chunk.
This method is particularly useful for chunked arrays (e.g. dask, cubed).
particularly when you need to map a function that requires all members of a group
to be present in a single chunk. For chunked array types, the order of appearance
is not guaranteed, but will depend on the input chunking.
Comment on lines +689 to +694
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, chunks will then be different sizes from each other. So when writing to Zarr we'll need to re-chunk? (asking for my clarification, feel free to not respond if it's obvious / respond with a single word :) )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes but that's a zarr limitation :)

Comment on lines +689 to +694
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is a single group limited to a single chunk? Assuming so, if we get one giant chuck, could that present any performance problems?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other chunks get "auto" reshaped. This is controlled by the chunks kwarg, which only takes "auto" at the moment.

https://docs.dask.org/en/latest/generated/dask.array.shuffle.html


Parameters
----------
chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional
How to adjust chunks along dimensions not present in the array being grouped by.

Returns
-------
DataArrayGroupBy or DatasetGroupBy

Examples
--------
>>> import dask.array
>>> da = xr.DataArray(
... dims="x",
... data=dask.array.arange(10, chunks=3),
... coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
... name="a",
... )
>>> shuffled = da.groupby("x").shuffle_to_chunks()
>>> shuffled
<xarray.DataArray 'a' (x: 10)> Size: 80B
dask.array<shuffle, shape=(10,), dtype=int64, chunksize=(3,), chunktype=numpy.ndarray>
Coordinates:
* x (x) int64 80B 0 1 1 1 2 2 2 3 3 3

>>> shuffled.groupby("x").quantile(q=0.5).compute()
<xarray.DataArray 'a' (x: 4)> Size: 32B
array([9., 3., 4., 5.])
Coordinates:
quantile float64 8B 0.5
* x (x) int64 32B 0 1 2 3

See Also
--------
dask.dataframe.DataFrame.shuffle
dask.array.shuffle
"""
self._raise_if_by_is_chunked()
return self._shuffle_obj(chunks)

def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
from xarray.core.dataarray import DataArray

was_array = isinstance(self._obj, DataArray)
as_dataset = self._obj._to_temp_dataset() if was_array else self._obj

for grouper in self.groupers:
if grouper.name not in as_dataset._variables:
as_dataset.coords[grouper.name] = grouper.group

shuffled = as_dataset._shuffle(
dim=self._group_dim, indices=self.encoded.group_indices, chunks=chunks
)
unstacked: Dataset = self._maybe_unstack(shuffled)
if was_array:
return self._obj._from_temp_dataset(unstacked)
else:
return unstacked # type: ignore[return-value]

def map(
self,
func: Callable,
Expand Down Expand Up @@ -896,7 +972,9 @@ def _maybe_unstack(self, obj):
# and `inserted_dims`
# if multiple groupers all share the same single dimension, then
# we don't stack/unstack. Do that manually now.
obj = obj.unstack(*self.encoded.unique_coord.dims)
dims_to_unstack = self.encoded.unique_coord.dims
if all(dim in obj.dims for dim in dims_to_unstack):
obj = obj.unstack(*dims_to_unstack)
to_drop = [
grouper.name
for grouper in self.groupers
Expand Down
45 changes: 45 additions & 0 deletions xarray/core/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
if TYPE_CHECKING:
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import T_Chunks

from xarray.groupers import RESAMPLE_DIM

Expand Down Expand Up @@ -58,6 +59,50 @@ def _flox_reduce(
result = result.rename({RESAMPLE_DIM: self._group_dim})
return result

def shuffle_to_chunks(self, chunks: T_Chunks = None):
"""
Sort or "shuffle" the underlying object.

"Shuffle" means the object is sorted so that all group members occur sequentially,
in the same chunk. Multiple groups may occur in the same chunk.
This method is particularly useful for chunked arrays (e.g. dask, cubed).
particularly when you need to map a function that requires all members of a group
to be present in a single chunk. For chunked array types, the order of appearance
is not guaranteed, but will depend on the input chunking.

Parameters
----------
chunks : int, tuple of int, "auto" or mapping of hashable to int or tuple of int, optional
How to adjust chunks along dimensions not present in the array being grouped by.

Returns
-------
DataArrayGroupBy or DatasetGroupBy

Examples
--------
>>> import dask.array
>>> da = xr.DataArray(
... dims="time",
... data=dask.array.arange(10, chunks=1),
... coords={"time": xr.date_range("2001-01-01", freq="12h", periods=10)},
... name="a",
... )
>>> shuffled = da.resample(time="2D").shuffle_to_chunks()
>>> shuffled
<xarray.DataArray 'a' (time: 10)> Size: 80B
dask.array<shuffle, shape=(10,), dtype=int64, chunksize=(4,), chunktype=numpy.ndarray>
Coordinates:
* time (time) datetime64[ns] 80B 2001-01-01 ... 2001-01-05T12:00:00

See Also
--------
dask.dataframe.DataFrame.shuffle
dask.array.shuffle
"""
(grouper,) = self.groupers
return self._shuffle_obj(chunks).drop_vars(RESAMPLE_DIM)

def _drop_coords(self) -> T_Xarray:
"""Drop non-dimension coordinates along the resampled dimension."""
obj = self._obj
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def read(self, __n: int = ...) -> AnyStr_co:
ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"]

GroupKey = Any
GroupIndex = Union[int, slice, list[int]]
GroupIndex = Union[slice, list[int]]
GroupIndices = tuple[GroupIndex, ...]
Bins = Union[
int, Sequence[int], Sequence[float], Sequence[pd.Timestamp], np.ndarray, pd.Index
Expand Down
26 changes: 25 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,13 @@
maybe_coerce_to_str,
)
from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions
from xarray.namedarray.pycompat import integer_types, is_0d_dask_array, to_duck_array
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import (
integer_types,
is_0d_dask_array,
is_chunked_array,
to_duck_array,
)
from xarray.namedarray.utils import module_available
from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims

Expand Down Expand Up @@ -1019,6 +1025,24 @@ def compute(self, **kwargs):
new = self.copy(deep=False)
return new.load(**kwargs)

def _shuffle(
self, indices: list[list[int]], dim: Hashable, chunks: T_Chunks
) -> Self:
# TODO (dcherian): consider making this public API
array = self._data
if is_chunked_array(array):
chunkmanager = get_chunked_array_type(array)
return self._replace(
data=chunkmanager.shuffle(
array,
indexer=indices,
axis=self.get_axis_num(dim),
chunks=chunks,
)
)
else:
return self.isel({dim: np.concatenate(indices)})

def isel(
self,
indexers: Mapping[Any, Any] | None = None,
Expand Down
Loading
Loading