return xarray object from distributed_shuffle
dcherian committed Nov 19, 2024
1 parent 231533c commit c77d7c5
Showing 2 changed files with 60 additions and 32 deletions.
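For context, this commit changes what `distributed_shuffle` returns: previously it shuffled the underlying object and handed back a freshly constructed GroupBy; now it returns the shuffled DataArray or Dataset itself, and callers re-group explicitly. A minimal sketch of the new calling pattern, in the same doctest style as the docstring below (values mirror the docstring example; a dask-backed array is assumed):

>>> import dask.array
>>> import xarray as xr
>>> da = xr.DataArray(
...     dims="x",
...     data=dask.array.arange(10, chunks=1),
...     coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
...     name="a",
... )
>>> shuffled = da.groupby("x").distributed_shuffle()  # now returns a DataArray
>>> result = shuffled.groupby("x").quantile(q=0.5).compute()  # re-group to reduce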
18 changes: 4 additions & 14 deletions xarray/core/groupby.py
@@ -682,7 +682,7 @@ def sizes(self) -> Mapping[Hashable, int]:
         self._sizes = self._obj.isel({self._group_dim: index}).sizes
         return self._sizes

-    def distributed_shuffle(self, chunks: T_Chunks = None):
+    def distributed_shuffle(self, chunks: T_Chunks = None) -> T_Xarray:
         """
         Sort or "shuffle" the underlying object.
@@ -711,8 +711,8 @@ def distributed_shuffle(self, chunks: T_Chunks = None):
         ...     coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
         ...     name="a",
         ... )
-        >>> shuffled = da.groupby("x").shuffle()
-        >>> shuffled.quantile(q=0.5).compute()
+        >>> shuffled = da.groupby("x").distributed_shuffle()
+        >>> shuffled.groupby("x").quantile(q=0.5).compute()
         <xarray.DataArray 'a' (x: 4)> Size: 32B
         array([9., 3., 4., 5.])
         Coordinates:
@@ -725,17 +725,7 @@ def distributed_shuffle(self, chunks: T_Chunks = None):
         dask.array.shuffle
         """
         self._raise_if_by_is_chunked()
-        new_groupers = {
-            # Using group.name handles the BinGrouper case
-            # It does *not* handle the TimeResampler case,
-            # so we just override this method in Resample
-            grouper.group.name: grouper.grouper.reset()
-            for grouper in self.groupers
-        }
-        return self._shuffle_obj(chunks).groupby(
-            new_groupers,
-            restore_coord_dims=self._restore_coord_dims,
-        )
+        return self._shuffle_obj(chunks)

     def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
         from xarray.core.dataarray import DataArray
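With the internal re-grouping gone (the removed `new_groupers` block reset each grouper and called `groupby` on the shuffled object), callers that want the old behavior re-group themselves. A hedged sketch of that pattern, mirroring the updated tests below:

>>> import dask.array
>>> import xarray as xr
>>> from xarray.groupers import UniqueGrouper
>>> da = xr.DataArray(
...     dims="x",
...     data=dask.array.from_array([1, 2, 3, 4, 5, 6], chunks=2),
...     coords={"label": ("x", list("abcabc"))},
... )
>>> gb = da.groupby(label=UniqueGrouper())
>>> gb = gb.distributed_shuffle().groupby(label=UniqueGrouper())  # re-create the GroupBy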
74 changes: 56 additions & 18 deletions xarray/tests/test_groupby.py
@@ -672,7 +672,7 @@ def test_groupby_drops_nans(shuffle: bool, chunk: Literal[False] | dict) -> None:
ds["variable"] = ds["variable"].chunk(chunk)
grouped = ds.groupby(ds.id)
if shuffle:
grouped = grouped.distributed_shuffle()
grouped = grouped.distributed_shuffle().groupby(ds.id)

# non reduction operation
expected1 = ds.copy()
@@ -1418,7 +1418,7 @@ def test_groupby_reductions(
     with raise_if_dask_computes():
         grouped = array.groupby("abc")
         if shuffle:
-            grouped = grouped.distributed_shuffle()
+            grouped = grouped.distributed_shuffle().groupby("abc")

     with xr.set_options(use_flox=use_flox):
         actual = getattr(grouped, method)(dim="y")
@@ -1687,13 +1687,16 @@ def test_groupby_bins(

     with xr.set_options(use_flox=use_flox):
         gb = array.groupby_bins("dim_0", bins=bins, **cut_kwargs)
+        shuffled = gb.distributed_shuffle().groupby_bins(
+            "dim_0", bins=bins, **cut_kwargs
+        )
         actual = gb.sum()
         assert_identical(expected, actual)
-        assert_identical(expected, gb.distributed_shuffle().sum())
+        assert_identical(expected, shuffled.sum())

         actual = gb.map(lambda x: x.sum())
         assert_identical(expected, actual)
-        assert_identical(expected, gb.distributed_shuffle().map(lambda x: x.sum()))
+        assert_identical(expected, shuffled.map(lambda x: x.sum()))

     # make sure original array dims are unchanged
     assert len(array.dim_0) == 4
@@ -1877,17 +1880,18 @@ def resample_as_pandas(array, *args, **kwargs):
     array = DataArray(np.arange(10), [("time", times)])

     rs = array.resample(time=resample_freq)
+    shuffled = rs.distributed_shuffle().resample(time=resample_freq)
     actual = rs.mean()
     expected = resample_as_pandas(array, resample_freq)
     assert_identical(expected, actual)
-    assert_identical(expected, rs.distributed_shuffle().mean())
+    assert_identical(expected, shuffled.mean())

     assert_identical(expected, rs.reduce(np.mean))
-    assert_identical(expected, rs.distributed_shuffle().reduce(np.mean))
+    assert_identical(expected, shuffled.reduce(np.mean))

     rs = array.resample(time="24h", closed="right")
     actual = rs.mean()
-    shuffled = rs.distributed_shuffle().mean()
+    shuffled = rs.distributed_shuffle().resample(time="24h", closed="right").mean()
     expected = resample_as_pandas(array, "24h", closed="right")
     assert_identical(expected, actual)
     assert_identical(expected, shuffled)
@@ -2830,9 +2834,11 @@ def test_multiple_groupers(use_flox: bool, shuffle: bool) -> None:
name="foo",
)

gb = da.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper())
groupers: dict[str, Grouper]
groupers = dict(labels1=UniqueGrouper(), labels2=UniqueGrouper())
gb = da.groupby(groupers)
if shuffle:
gb = gb.distributed_shuffle()
gb = gb.distributed_shuffle().groupby(groupers)
repr(gb)

expected = DataArray(
@@ -2851,9 +2857,10 @@ def test_multiple_groupers(use_flox: bool, shuffle: bool) -> None:
     # -------
     coords = {"a": ("x", [0, 0, 1, 1]), "b": ("y", [0, 0, 1, 1])}
     square = DataArray(np.arange(16).reshape(4, 4), coords=coords, dims=["x", "y"])
-    gb = square.groupby(a=UniqueGrouper(), b=UniqueGrouper())
+    groupers = dict(a=UniqueGrouper(), b=UniqueGrouper())
+    gb = square.groupby(groupers)
     if shuffle:
-        gb = gb.distributed_shuffle()
+        gb = gb.distributed_shuffle().groupby(groupers)
     repr(gb)
     with xr.set_options(use_flox=use_flox):
         actual = gb.mean()
Expand Down Expand Up @@ -2883,9 +2890,10 @@ def test_multiple_groupers(use_flox: bool, shuffle: bool) -> None:
     with xr.set_options(use_flox=use_flox):
         assert_identical(gb.mean("z"), b.mean("z"))

-    gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper())
+    groupers = dict(x=UniqueGrouper(), xy=UniqueGrouper())
+    gb = b.groupby(groupers)
     if shuffle:
-        gb = gb.distributed_shuffle()
+        gb = gb.distributed_shuffle().groupby(groupers)
     repr(gb)
     with xr.set_options(use_flox=use_flox):
         actual = gb.mean()
@@ -2937,9 +2945,12 @@ def test_multiple_groupers_mixed(use_flox: bool, shuffle: bool) -> None:
{"foo": (("x", "y"), np.arange(12).reshape((4, 3)))},
coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))},
)
gb = ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper())
groupers: dict[str, Grouper] = dict(
x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()
)
gb = ds.groupby(groupers)
if shuffle:
gb = gb.distributed_shuffle()
gb = gb.distributed_shuffle().groupby(groupers)
expected_data = np.array(
[
[[0.0, np.nan], [np.nan, 3.0]],
@@ -3168,20 +3179,47 @@ def test_groupby_multiple_bin_grouper_missing_groups():


 @requires_dask_ge_2024_08_1
-def test_shuffle_by_simple() -> None:
+def test_shuffle_simple() -> None:
     import dask

     da = xr.DataArray(
         dims="x",
         data=dask.array.from_array([1, 2, 3, 4, 5, 6], chunks=2),
         coords={"label": ("x", "a b c a b c".split(" "))},
     )
-    actual = da.distributed_shuffle_by(label=UniqueGrouper())
+    actual = da.groupby(label=UniqueGrouper()).distributed_shuffle()
     expected = da.isel(x=[0, 3, 1, 4, 2, 5])
     assert_identical(actual, expected)

     with pytest.raises(ValueError):
-        da.chunk(x=2, eagerly_load_group=False).distributed_shuffle_by("label")
+        da.chunk(x=2, eagerly_load_group=False).groupby("label").distributed_shuffle()


+@requires_dask_ge_2024_08_1
+@pytest.mark.parametrize(
+    "chunks, expected_chunks",
+    [
+        ((1,), (1, 3, 3, 3)),
+        ((10,), (10,)),
+    ],
+)
+def test_shuffle_by(chunks, expected_chunks):
+    import dask.array
+
+    from xarray.groupers import UniqueGrouper
+
+    da = xr.DataArray(
+        dims="x",
+        data=dask.array.arange(10, chunks=chunks),
+        coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
+        name="a",
+    )
+    ds = da.to_dataset()
+
+    for obj in [ds, da]:
+        actual = obj.groupby(x=UniqueGrouper()).distributed_shuffle()
+        assert_identical(actual, obj.sortby("x"))
+        assert actual.chunksizes["x"] == expected_chunks
+
 @requires_dask
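The resample tests above follow the same convention: shuffle first, then call `resample` again on the returned object before reducing. A short sketch under those assumptions (dask-backed array, dask >= 2024.08.1 per the test decorators; frequency and chunking are illustrative):

>>> import numpy as np
>>> import pandas as pd
>>> import xarray as xr
>>> times = pd.date_range("2000-01-01", freq="6h", periods=10)
>>> array = xr.DataArray(np.arange(10), [("time", times)]).chunk(time=2)
>>> rs = array.resample(time="24h")
>>> shuffled = rs.distributed_shuffle().resample(time="24h")  # re-resample after shuffling
>>> shuffled.mean().compute()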
