Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add xr.unify_chunks() top level method #5445

Merged
merged 25 commits into from
Jun 16, 2021
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
0d20cde
Add top level method
malmans2 Jun 6, 2021
3902a38
move in computation.py
malmans2 Jun 6, 2021
334e38d
no need to enumerate
malmans2 Jun 7, 2021
4515208
Update xarray/core/computation.py
malmans2 Jun 11, 2021
e940f8c
Update xarray/core/computation.py
malmans2 Jun 11, 2021
6d1810b
Update xarray/core/computation.py
malmans2 Jun 11, 2021
f11496b
Update xarray/core/computation.py
malmans2 Jun 11, 2021
0a39035
Merge branch 'pydata:master' into add_unify_chunks
malmans2 Jun 11, 2021
b8cbaed
implement @crusaderky suggestions
malmans2 Jun 11, 2021
fde1ea1
import annotations
malmans2 Jun 11, 2021
96a6e61
test transposed dimensions
malmans2 Jun 11, 2021
38563f3
check transposed and raise mismatch error
malmans2 Jun 11, 2021
0090148
Update xarray/core/computation.py
malmans2 Jun 14, 2021
f4052ee
Update xarray/core/computation.py
malmans2 Jun 14, 2021
ccc60f4
Update xarray/core/computation.py
malmans2 Jun 14, 2021
4ff8051
Update xarray/core/computation.py
malmans2 Jun 14, 2021
3360ae8
@crusaderky suggestions
malmans2 Jun 14, 2021
a0e1820
Update xarray/tests/test_dask.py
malmans2 Jun 14, 2021
a6db6ec
Update xarray/tests/test_dask.py
malmans2 Jun 14, 2021
a248880
Update xarray/tests/test_dask.py
crusaderky Jun 14, 2021
1765c95
fix typos
malmans2 Jun 14, 2021
0d50c44
Merge branch 'add_unify_chunks' of https://github.com/malmans2/xarr…
malmans2 Jun 14, 2021
10a174a
use tuple for da chunks
malmans2 Jun 14, 2021
305d138
Update xarray/tests/test_dask.py
malmans2 Jun 14, 2021
0e42ace
Update xarray/tests/test_dask.py
crusaderky Jun 15, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Top-level functions
map_blocks
show_versions
set_options
unify_chunks

Dataset
=======
Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ v0.18.3 (unreleased)

New Features
~~~~~~~~~~~~
- New top-level function :py:func:`unify_chunks`.
By `Mattia Almansi <https://github.com/malmans2>`_.
- Allow assigning values to a subset of a dataset using positional or label-based
indexing (:issue:`3015`, :pull:`5362`).
By `Matthias Göbel <https://github.com/matzegoebel>`_.
Expand Down
3 changes: 2 additions & 1 deletion xarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .core.alignment import align, broadcast
from .core.combine import combine_by_coords, combine_nested
from .core.common import ALL_DIMS, full_like, ones_like, zeros_like
from .core.computation import apply_ufunc, corr, cov, dot, polyval, where
from .core.computation import apply_ufunc, corr, cov, dot, polyval, unify_chunks, where
from .core.concat import concat
from .core.dataarray import DataArray
from .core.dataset import Dataset
Expand Down Expand Up @@ -74,6 +74,7 @@
"save_mfdataset",
"set_options",
"show_versions",
"unify_chunks",
"where",
"zeros_like",
# Classes
Expand Down
62 changes: 62 additions & 0 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Functions for applying functions that act on arrays to xarray's labeled data.
"""
from __future__ import annotations

import functools
import itertools
import operator
Expand All @@ -19,6 +21,7 @@
Optional,
Sequence,
Tuple,
TypeVar,
Union,
)

Expand All @@ -34,8 +37,11 @@

if TYPE_CHECKING:
from .coordinates import Coordinates # noqa
from .dataarray import DataArray
from .dataset import Dataset

T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)

_NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
_DEFAULT_NAME = utils.ReprObject("<default-name>")
_JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"})
Expand Down Expand Up @@ -1721,3 +1727,59 @@ def _calc_idxminmax(
res.attrs = indx.attrs

return res


def unify_chunks(*objects: T_DSorDA) -> Tuple[T_DSorDA, ...]:
    """
    Given any number of Dataset and/or DataArray objects, returns
    new objects with unified chunk size along all chunked dimensions.

    Parameters
    ----------
    *objects : Dataset or DataArray
        Objects whose dask-backed variables should share chunk sizes.

    Returns
    -------
    unified : tuple of DataArray or Dataset
        Tuple of objects with the same types as ``*objects``, with
        consistent chunk sizes for all dask-array variables.

    Raises
    ------
    ValueError
        If a dimension has inconsistent lengths across the chunked objects.

    See Also
    --------
    dask.array.core.unify_chunks
    """
    from .dataarray import DataArray

    # Normalize everything to Dataset so variables can be handled uniformly.
    # DataArrays round-trip through a temporary dataset; Datasets are copied
    # so the caller's inputs are never mutated.
    datasets = [
        obj._to_temp_dataset() if isinstance(obj, DataArray) else obj.copy()
        for obj in objects
    ]

    # Collect arguments to pass into dask.array.core.unify_chunks:
    # alternating (array, dims) pairs for every chunked variable.
    unify_chunks_args = []
    sizes: dict = {}
    for ds in datasets:
        for v in ds._variables.values():
            if v.chunks is not None:
                # Accumulate dimension sizes seen so far. Earlier entries take
                # precedence in the merge, so a later object that disagrees on
                # a dimension length is caught by the membership check below.
                sizes = {**ds.sizes, **sizes}
                if not all(item in sizes.items() for item in ds.sizes.items()):
                    # NOTE(review): the "lenghts" typo is load-bearing — the
                    # test suite matches this exact message; fix both together.
                    raise ValueError(
                        "Dimension lenghts are not consistent across chunked objects."
                    )
                unify_chunks_args += [v._data, v._dims]

    # Fast path: no dask-backed variables anywhere — return the inputs as-is.
    if not unify_chunks_args:
        return objects

    # Delegate the actual rechunking to dask.
    from dask.array.core import unify_chunks

    _, dask_data = unify_chunks(*unify_chunks_args)
    dask_data_iter = iter(dask_data)
    out = []
    for obj, ds in zip(objects, datasets):
        for k, v in ds._variables.items():
            if v.chunks is not None:
                # Rechunked arrays come back in the same order they were
                # appended above, so a single iterator stays in sync.
                ds._variables[k] = v.copy(data=next(dask_data_iter))
        out.append(obj._from_temp_dataset(ds) if isinstance(obj, DataArray) else ds)

    return tuple(out)
5 changes: 3 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
)
from .arithmetic import DataArrayArithmetic
from .common import AbstractArray, DataWithCoords
from .computation import unify_chunks
from .coordinates import (
DataArrayCoordinates,
assert_coordinate_consistent,
Expand Down Expand Up @@ -3686,8 +3687,8 @@ def unify_chunks(self) -> "DataArray":
--------
dask.array.core.unify_chunks
"""
ds = self._to_temp_dataset().unify_chunks()
return self._from_temp_dataset(ds)

return unify_chunks(self)[0]

def map_blocks(
self,
Expand Down
33 changes: 2 additions & 31 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from .alignment import _broadcast_helper, _get_broadcast_dims_map_common_coords, align
from .arithmetic import DatasetArithmetic
from .common import DataWithCoords, _contains_datetime_like_objects
from .computation import unify_chunks
from .coordinates import (
DatasetCoordinates,
assert_coordinate_consistent,
Expand Down Expand Up @@ -6563,37 +6564,7 @@ def unify_chunks(self) -> "Dataset":
dask.array.core.unify_chunks
"""

try:
self.chunks
except ValueError: # "inconsistent chunks"
pass
else:
# No variables with dask backend, or all chunks are already aligned
return self.copy()

# import dask is placed after the quick exit test above to allow
# running this method if dask isn't installed and there are no chunks
import dask.array

ds = self.copy()

dims_pos_map = {dim: index for index, dim in enumerate(ds.dims)}

dask_array_names = []
dask_unify_args = []
for name, variable in ds.variables.items():
if isinstance(variable.data, dask.array.Array):
dims_tuple = [dims_pos_map[dim] for dim in variable.dims]
dask_array_names.append(name)
dask_unify_args.append(variable.data)
dask_unify_args.append(dims_tuple)

_, rechunked_arrays = dask.array.core.unify_chunks(*dask_unify_args)

for name, new_array in zip(dask_array_names, rechunked_arrays):
ds.variables[name]._data = new_array

return ds
return unify_chunks(self)[0]

def map_blocks(
self,
Expand Down
25 changes: 23 additions & 2 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,12 +1069,33 @@ def test_unify_chunks(map_ds):
with pytest.raises(ValueError, match=r"inconsistent chunks"):
ds_copy.chunks

expected_chunks = {"x": (4, 4, 2), "y": (5, 5, 5, 5), "z": (4,)}
expected_chunks = {"x": (4, 4, 2), "y": (5, 5, 5, 5)}
with raise_if_dask_computes():
actual_chunks = ds_copy.unify_chunks().chunks
expected_chunks == actual_chunks
assert expected_chunks == actual_chunks
crusaderky marked this conversation as resolved.
Show resolved Hide resolved
assert_identical(map_ds, ds_copy.unify_chunks())

actual = [
obj.chunks for obj in xr.unify_chunks(ds_copy.cxy, ds_copy.drop_vars("cxy"))
]
expected = [tuple(expected_chunks.values()), expected_chunks]
assert expected == actual
malmans2 marked this conversation as resolved.
Show resolved Hide resolved

crusaderky marked this conversation as resolved.
Show resolved Hide resolved
# Test unordered dims
da = ds_copy["cxy"]
da_transposed = da.transpose(*da.dims[::-1])
unified = xr.unify_chunks(da.chunk({"x": -1}), da_transposed.chunk({"y": -1}))
assert (
unified[0].chunks == unified[1].chunks[::-1] == tuple(expected_chunks.values())
)
malmans2 marked this conversation as resolved.
Show resolved Hide resolved

# Test mismatch
with pytest.raises(
ValueError,
match=r"Dimension lenghts are not consistent across chunked objects.",
malmans2 marked this conversation as resolved.
Show resolved Hide resolved
):
xr.unify_chunks(da, da.isel(x=slice(2)))


@pytest.mark.parametrize("obj", [make_ds(), make_da()])
@pytest.mark.parametrize(
Expand Down