diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b22a7217568..1cf9780492e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -70,6 +70,8 @@ New Features - :py:meth:`map_blocks` now accepts a ``template`` kwarg. This allows use cases where the result of a computation could not be inferred automatically. By `Deepak Cherian `_ +- :py:meth:`map_blocks` can now handle dask-backed xarray objects in ``args``. (:pull:`3818`) + By `Deepak Cherian `_ Bug fixes ~~~~~~~~~ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 236938bac74..3451ff14c8f 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3262,45 +3262,91 @@ def map_blocks( ---------- func: callable User-provided function that accepts a DataArray as its first - parameter. The function will receive a subset, i.e. one block, of this DataArray - (see below), corresponding to one chunk along each chunked dimension. ``func`` will be - executed as ``func(block_subset, *args, **kwargs)``. + parameter. The function will receive a subset or 'block' of this DataArray (see below), + corresponding to one chunk along each chunked dimension. ``func`` will be + executed as ``func(subset_dataarray, *subset_args, **kwargs)``. This function must return either a single DataArray or a single Dataset. This function cannot add a new chunked dimension. + + obj: DataArray, Dataset + Passed to the function as its first argument, one block at a time. args: Sequence - Passed verbatim to func after unpacking, after the sliced DataArray. xarray - objects, if any, will not be split by chunks. Passing dask collections is - not allowed. + Passed to func after unpacking and subsetting any xarray objects by blocks. + xarray objects in args must be aligned with obj, otherwise an error is raised. kwargs: Mapping Passed verbatim to func after unpacking. xarray objects, if any, will not be - split by chunks. Passing dask collections is not allowed. + subset to blocks. Passing dask collections in kwargs is not allowed. template: (optional) DataArray, Dataset xarray object representing the final result after compute is called. If not provided, - the function will be first run on mocked-up data, that looks like 'obj' but + the function will be first run on mocked-up data, that looks like ``obj`` but has sizes 0, to determine properties of the returned object such as dtype, - variable names, new dimensions and new indexes (if any). - 'template' must be provided if the function changes the size of existing dimensions. + variable names, attributes, new dimensions and new indexes (if any). + ``template`` must be provided if the function changes the size of existing dimensions. + When provided, ``attrs`` on variables in `template` are copied over to the result. Any + ``attrs`` set by ``func`` will be ignored. + Returns ------- - A single DataArray or Dataset with dask backend, reassembled from the outputs of - the function. + A single DataArray or Dataset with dask backend, reassembled from the outputs of the + function. Notes ----- - This method is designed for when one needs to manipulate a whole xarray object - within each chunk. In the more common case where one can work on numpy arrays, - it is recommended to use apply_ufunc. + This function is designed for when ``func`` needs to manipulate a whole xarray object + subset to each block. In the more common case where ``func`` can work on numpy arrays, it is + recommended to use ``apply_ufunc``. 
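As a quick illustration of the new behaviour described above (dask-backed xarray objects in ``args``), here is a minimal sketch adapted from the tests added in this patch; the names ``da1``, ``da2`` and ``block_sum`` are illustrative only::

    import numpy as np
    import xarray as xr

    # Two dask-backed DataArrays with identical chunking.
    da1 = xr.DataArray(
        np.ones((10, 20)),
        dims=["x", "y"],
        coords={"x": np.arange(10), "y": np.arange(20)},
    ).chunk({"x": 5, "y": 4})
    da2 = da1 + 1  # also dask-backed; no longer needs to be computed first

    # func receives one block of da1 together with the matching block of da2.
    def block_sum(a, b):
        return a + b

    result = xr.map_blocks(block_sum, da1, args=[da2])
    xr.testing.assert_equal(result.compute(), da1 + da2)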
- If none of the variables in this DataArray is backed by dask, calling this - method is equivalent to calling ``func(self, *args, **kwargs)``. + If none of the variables in ``obj`` is backed by dask arrays, calling this function is + equivalent to calling ``func(obj, *args, **kwargs)``. See Also -------- - dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks, - xarray.Dataset.map_blocks + dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks, + xarray.DataArray.map_blocks + + Examples + -------- + + Calculate an anomaly from climatology using ``.groupby()``. Using + ``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``, + its indices, and its methods like ``.groupby()``. + + >>> def calculate_anomaly(da, groupby_type="time.month"): + ... gb = da.groupby(groupby_type) + ... clim = gb.mean(dim="time") + ... return gb - clim + >>> time = xr.cftime_range("1990-01", "1992-01", freq="M") + >>> np.random.seed(123) + >>> array = xr.DataArray( + ... np.random.rand(len(time)), dims="time", coords=[time] + ... ).chunk() + >>> array.map_blocks(calculate_anomaly, template=array).compute() + + array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862, + 0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714, + -0.19063865, 0.0590131 , -0.12894847, -0.11323072, 0.0855964 , + 0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108, + 0.07673453, 0.22865714, 0.19063865, -0.0590131 ]) + Coordinates: + * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + + Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments + to the function being applied in ``xr.map_blocks()``: + + >>> array.map_blocks( + ... calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=array, + ... ) + + array([ 0.15361741, -0.25671244, -0.31600032, 0.008463 , 0.1766172 , + -0.11974531, 0.43791243, 0.14197797, -0.06191987, -0.15073425, + -0.19967375, 0.18619794, -0.05100474, -0.42989909, -0.09153273, + 0.24841842, -0.30708526, -0.31412523, 0.04197439, 0.0422506 , + 0.14482397, 0.35985481, 0.23487834, 0.12144652]) + Coordinates: + * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 """ from .parallel import map_blocks diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 3a55f3eca27..29cecae55b0 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5721,45 +5721,92 @@ def map_blocks( ---------- func: callable User-provided function that accepts a Dataset as its first - parameter. The function will receive a subset, i.e. one block, of this Dataset - (see below), corresponding to one chunk along each chunked dimension. ``func`` will be - executed as ``func(block_subset, *args, **kwargs)``. + parameter. The function will receive a subset or 'block' of this Dataset (see below), + corresponding to one chunk along each chunked dimension. ``func`` will be + executed as ``func(subset_dataset, *subset_args, **kwargs)``. This function must return either a single DataArray or a single Dataset. This function cannot add a new chunked dimension. + + obj: DataArray, Dataset + Passed to the function as its first argument, one block at a time. args: Sequence - Passed verbatim to func after unpacking, after the sliced DataArray. xarray - objects, if any, will not be split by chunks. Passing dask collections is - not allowed. + Passed to func after unpacking and subsetting any xarray objects by blocks. + xarray objects in args must be aligned with obj, otherwise an error is raised. 
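Since xarray objects in ``args`` are aligned with ``obj`` using ``join="exact"``, mismatched indexes are rejected before any computation starts. A sketch of that failure mode, adapted from the tests added in this patch (variable names illustrative)::

    import operator

    import numpy as np
    import xarray as xr

    da1 = xr.DataArray(
        np.ones((10, 20)),
        dims=["x", "y"],
        coords={"x": np.arange(10), "y": np.arange(20)},
    ).chunk({"x": 5, "y": 4})

    # The "x" index of this argument differs from obj's, so map_blocks raises.
    bad = da1.reindex(x=np.arange(20))
    try:
        xr.map_blocks(operator.add, da1, args=[bad])
    except ValueError as err:
        print(err)  # indexes along dimension 'x' are not equal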
kwargs: Mapping Passed verbatim to func after unpacking. xarray objects, if any, will not be - split by chunks. Passing dask collections is not allowed. + subset to blocks. Passing dask collections in kwargs is not allowed. template: (optional) DataArray, Dataset xarray object representing the final result after compute is called. If not provided, - the function will be first run on mocked-up data, that looks like 'obj' but + the function will be first run on mocked-up data, that looks like ``obj`` but has sizes 0, to determine properties of the returned object such as dtype, - variable names, new dimensions and new indexes (if any). - 'template' must be provided if the function changes the size of existing dimensions. + variable names, attributes, new dimensions and new indexes (if any). + ``template`` must be provided if the function changes the size of existing dimensions. + When provided, ``attrs`` on variables in `template` are copied over to the result. Any + ``attrs`` set by ``func`` will be ignored. + Returns ------- - A single DataArray or Dataset with dask backend, reassembled from the outputs of - the function. + A single DataArray or Dataset with dask backend, reassembled from the outputs of the + function. Notes ----- - This method is designed for when one needs to manipulate a whole xarray object - within each chunk. In the more common case where one can work on numpy arrays, - it is recommended to use apply_ufunc. + This function is designed for when ``func`` needs to manipulate a whole xarray object + subset to each block. In the more common case where ``func`` can work on numpy arrays, it is + recommended to use ``apply_ufunc``. - If none of the variables in this Dataset is backed by dask, calling this method - is equivalent to calling ``func(self, *args, **kwargs)``. + If none of the variables in ``obj`` is backed by dask arrays, calling this function is + equivalent to calling ``func(obj, *args, **kwargs)``. See Also -------- - dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks, + dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks, xarray.DataArray.map_blocks + + Examples + -------- + + Calculate an anomaly from climatology using ``.groupby()``. Using + ``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``, + its indices, and its methods like ``.groupby()``. + + >>> def calculate_anomaly(da, groupby_type="time.month"): + ... gb = da.groupby(groupby_type) + ... clim = gb.mean(dim="time") + ... return gb - clim + >>> time = xr.cftime_range("1990-01", "1992-01", freq="M") + >>> np.random.seed(123) + >>> array = xr.DataArray( + ... np.random.rand(len(time)), dims="time", coords=[time] + ... ).chunk() + >>> ds = xr.Dataset({"a": array}) + >>> ds.map_blocks(calculate_anomaly, template=ds).compute() + + array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862, + 0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714, + -0.19063865, 0.0590131 , -0.12894847, -0.11323072, 0.0855964 , + 0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108, + 0.07673453, 0.22865714, 0.19063865, -0.0590131 ]) + Coordinates: + * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + + Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments + to the function being applied in ``xr.map_blocks()``: + + >>> ds.map_blocks( + ... calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=ds, + ... 
) + + array([ 0.15361741, -0.25671244, -0.31600032, 0.008463 , 0.1766172 , + -0.11974531, 0.43791243, 0.14197797, -0.06191987, -0.15073425, + -0.19967375, 0.18619794, -0.05100474, -0.42989909, -0.09153273, + 0.24841842, -0.30708526, -0.31412523, 0.04197439, 0.0422506 , + 0.14482397, 0.35985481, 0.23487834, 0.12144652]) + Coordinates: + * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 """ from .parallel import map_blocks diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index d91dfb4a275..522c5b36ff5 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -16,6 +16,8 @@ DefaultDict, Dict, Hashable, + Iterable, + List, Mapping, Sequence, Tuple, @@ -25,12 +27,29 @@ import numpy as np +from .alignment import align from .dataarray import DataArray from .dataset import Dataset T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) +def to_object_array(iterable): + # using empty_like calls compute + npargs = np.empty((len(iterable),), dtype=np.object) + npargs[:] = iterable + return npargs + + +def assert_chunks_compatible(a: Dataset, b: Dataset): + a = a.unify_chunks() + b = b.unify_chunks() + + for dim in set(a.chunks).intersection(set(b.chunks)): + if a.chunks[dim] != b.chunks[dim]: + raise ValueError(f"Chunk sizes along dimension {dim!r} are not equal.") + + def check_result_variables( result: Union[DataArray, Dataset], expected: Mapping[str, Any], kind: str ): @@ -67,6 +86,17 @@ def dataset_to_dataarray(obj: Dataset) -> DataArray: return next(iter(obj.data_vars.values())) +def dataarray_to_dataset(obj: DataArray) -> Dataset: + # only using _to_temp_dataset would break + # func = lambda x: x.to_dataset() + # since that relies on preserving name. + if obj.name is None: + dataset = obj._to_temp_dataset() + else: + dataset = obj.to_dataset() + return dataset + + def make_meta(obj): """If obj is a DataArray or Dataset, return a new object of the same type and with the same variables and dtypes, but where all variables have size 0 and numpy @@ -150,30 +180,30 @@ def map_blocks( ---------- func: callable User-provided function that accepts a DataArray or Dataset as its first - parameter. The function will receive a subset of 'obj' (see below), + parameter ``obj``. The function will receive a subset or 'block' of ``obj`` (see below), corresponding to one chunk along each chunked dimension. ``func`` will be - executed as ``func(obj_subset, *args, **kwargs)``. + executed as ``func(subset_obj, *subset_args, **kwargs)``. This function must return either a single DataArray or a single Dataset. This function cannot add a new chunked dimension. obj: DataArray, Dataset - Passed to the function as its first argument, one dask chunk at a time. + Passed to the function as its first argument, one block at a time. args: Sequence - Passed verbatim to func after unpacking, after the sliced obj. xarray objects, - if any, will not be split by chunks. Passing dask collections is not allowed. + Passed to func after unpacking and subsetting any xarray objects by blocks. + xarray objects in args must be aligned with obj, otherwise an error is raised. kwargs: Mapping Passed verbatim to func after unpacking. xarray objects, if any, will not be - split by chunks. Passing dask collections is not allowed. + subset to blocks. Passing dask collections in kwargs is not allowed. template: (optional) DataArray, Dataset xarray object representing the final result after compute is called. 
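The new ``assert_chunks_compatible`` helper enforces that chunk sizes of every dask-backed xarray argument match those of ``obj`` along shared dimensions. A sketch of the error it produces, adapted from the tests added in this patch::

    import operator

    import numpy as np
    import xarray as xr

    da1 = xr.DataArray(
        np.ones((10, 20)),
        dims=["x", "y"],
        coords={"x": np.arange(10), "y": np.arange(20)},
    ).chunk({"x": 5, "y": 4})

    # The same data rechunked differently along "x" cannot be paired up
    # block-for-block with obj, so map_blocks raises before building the graph.
    try:
        xr.map_blocks(operator.add, da1, args=[da1.chunk({"x": 1})])
    except ValueError as err:
        print(err)  # Chunk sizes along dimension 'x' are not equal.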
If not provided, - the function will be first run on mocked-up data, that looks like 'obj' but + the function will be first run on mocked-up data, that looks like ``obj`` but has sizes 0, to determine properties of the returned object such as dtype, variable names, attributes, new dimensions and new indexes (if any). - 'template' must be provided if the function changes the size of existing dimensions. - When provided, `attrs` on variables in `template` are copied over to the result. Any - `attrs` set by `func` will be ignored. + ``template`` must be provided if the function changes the size of existing dimensions. + When provided, ``attrs`` on variables in `template` are copied over to the result. Any + ``attrs`` set by ``func`` will be ignored. Returns @@ -183,11 +213,11 @@ def map_blocks( Notes ----- - This function is designed for when one needs to manipulate a whole xarray object - within each chunk. In the more common case where one can work on numpy arrays, it is - recommended to use apply_ufunc. + This function is designed for when ``func`` needs to manipulate a whole xarray object + subset to each block. In the more common case where ``func`` can work on numpy arrays, it is + recommended to use ``apply_ufunc``. - If none of the variables in obj is backed by dask, calling this function is + If none of the variables in ``obj`` is backed by dask arrays, calling this function is equivalent to calling ``func(obj, *args, **kwargs)``. See Also @@ -203,10 +233,6 @@ def map_blocks( its indices, and its methods like ``.groupby()``. >>> def calculate_anomaly(da, groupby_type="time.month"): - ... # Necessary workaround to xarray's check with zero dimensions - ... # https://github.com/pydata/xarray/issues/3575 - ... if sum(da.shape) == 0: - ... return da ... gb = da.groupby(groupby_type) ... clim = gb.mean(dim="time") ... return gb - clim @@ -215,7 +241,7 @@ def map_blocks( >>> array = xr.DataArray( ... np.random.rand(len(time)), dims="time", coords=[time] ... ).chunk() - >>> xr.map_blocks(calculate_anomaly, array).compute() + >>> xr.map_blocks(calculate_anomaly, array, template=array).compute() array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862, 0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714, @@ -229,7 +255,7 @@ def map_blocks( to the function being applied in ``xr.map_blocks()``: >>> xr.map_blocks( - ... calculate_anomaly, array, kwargs={"groupby_type": "time.year"}, + ... calculate_anomaly, array, kwargs={"groupby_type": "time.year"}, template=array, ... ) array([ 0.15361741, -0.25671244, -0.31600032, 0.008463 , 0.1766172 , @@ -241,14 +267,24 @@ def map_blocks( * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 """ - def _wrapper(func, obj, to_array, args, kwargs, expected): - check_shapes = dict(obj.dims) - check_shapes.update(expected["shapes"]) - - if to_array: - obj = dataset_to_dataarray(obj) - - result = func(obj, *args, **kwargs) + def _wrapper( + func: Callable, + args: List, + kwargs: dict, + arg_is_array: Iterable[bool], + expected: dict, + ): + """ + Wrapper function that receives datasets in args; converts to dataarrays when necessary; + passes these to the user function `func` and checks returned objects for expected shapes/sizes/etc. 
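Combining ``template`` with dask-backed ``args`` also covers reductions: if the reduced dimension is held in a single chunk, the applied function may consume it, and ``template`` describes the reduced result. A sketch adapted from the tests added in this patch::

    import numpy as np
    import xarray as xr

    da1 = xr.DataArray(
        np.ones((10, 20)),
        dims=["x", "y"],
        coords={"x": np.arange(10), "y": np.arange(20)},
    ).chunk({"x": -1, "y": 4})  # "x" must be a single chunk to reduce over it
    da2 = da1 + 1

    mapped = xr.map_blocks(
        lambda a, b: (a + b).sum("x"), da1, args=[da2], template=da1.sum("x")
    )
    xr.testing.assert_equal(mapped, (da1 + da2).sum("x"))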
+ """ + + converted_args = [ + dataset_to_dataarray(arg) if is_array else arg + for is_array, arg in zip(arg_is_array, args) + ] + + result = func(*converted_args, **kwargs) # check all dims are present missing_dimensions = set(expected["shapes"]) - set(result.sizes) @@ -259,10 +295,10 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): # check that index lengths and values are as expected for name, index in result.indexes.items(): - if name in check_shapes: - if len(index) != check_shapes[name]: + if name in expected["shapes"]: + if len(index) != expected["shapes"][name]: raise ValueError( - f"Received dimension {name!r} of length {len(index)}. Expected length {check_shapes[name]}." + f"Received dimension {name!r} of length {len(index)}. Expected length {expected['shapes'][name]}." ) if name in expected["indexes"]: expected_index = expected["indexes"][name] @@ -289,38 +325,44 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): elif not isinstance(kwargs, Mapping): raise TypeError("kwargs must be a mapping (for example, a dict)") - for value in list(args) + list(kwargs.values()): + for value in kwargs.values(): if dask.is_dask_collection(value): raise TypeError( - "Cannot pass dask collections in args or kwargs yet. Please compute or " + "Cannot pass dask collections in kwargs yet. Please compute or " "load values before passing to map_blocks." ) if not dask.is_dask_collection(obj): return func(obj, *args, **kwargs) - if isinstance(obj, DataArray): - # only using _to_temp_dataset would break - # func = lambda x: x.to_dataset() - # since that relies on preserving name. - if obj.name is None: - dataset = obj._to_temp_dataset() - else: - dataset = obj.to_dataset() - input_is_array = True - else: - dataset = obj - input_is_array = False + npargs = to_object_array([obj] + list(args)) + is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in npargs] + is_array = [isinstance(arg, DataArray) for arg in npargs] + + # all xarray objects must be aligned. This is consistent with apply_ufunc. 
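The ``to_object_array`` helper above exists so that boolean masks such as ``is_xarray`` can pick out and reassign the xarray arguments in place, without numpy trying to coerce them element-wise. A small illustration of the idiom, with hypothetical names, not part of the patch itself::

    import numpy as np
    import xarray as xr

    user_args = [xr.DataArray(np.arange(3), dims="x"), 5, "not-xarray"]

    # Allocate an empty object array and fill it, rather than calling
    # np.array(user_args), which would try to convert the contents.
    npargs = np.empty((len(user_args),), dtype=object)
    npargs[:] = user_args

    is_xarray = [isinstance(a, (xr.Dataset, xr.DataArray)) for a in user_args]
    print(npargs[is_xarray])  # object array holding just the DataArray

    # Assignment through the mask behaves best when the right-hand side is
    # itself an object array (see the Stack Overflow link in the patch).
    replacement = np.empty((1,), dtype=object)
    replacement[:] = [npargs[is_xarray][0].to_dataset(name="a")]
    npargs[is_xarray] = replacement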
+ aligned = align(*npargs[is_xarray], join="exact") + # assigning to object arrays works better when RHS is object array + # https://stackoverflow.com/questions/43645135/boolean-indexing-assignment-of-a-numpy-array-to-a-numpy-array + npargs[is_xarray] = to_object_array(aligned) + npargs[is_array] = to_object_array( + [dataarray_to_dataset(da) for da in npargs[is_array]] + ) + + # check that chunk sizes are compatible + input_chunks = dict(npargs[0].chunks) + input_indexes = dict(npargs[0].indexes) + for arg in npargs[1:][is_xarray[1:]]: + assert_chunks_compatible(npargs[0], arg) + input_chunks.update(arg.chunks) + input_indexes.update(arg.indexes) - input_chunks = dataset.chunks - dataset_indexes = set(dataset.indexes) if template is None: # infer template by providing zero-shaped arrays - template = infer_template(func, obj, *args, **kwargs) + template = infer_template(func, aligned[0], *args, **kwargs) template_indexes = set(template.indexes) - preserved_indexes = template_indexes & dataset_indexes - new_indexes = template_indexes - dataset_indexes - indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes} + preserved_indexes = template_indexes & set(input_indexes) + new_indexes = template_indexes - set(input_indexes) + indexes = {dim: input_indexes[dim] for dim in preserved_indexes} indexes.update({k: template.indexes[k] for k in new_indexes}) output_chunks = { dim: input_chunks[dim] for dim in template.dims if dim in input_chunks @@ -328,13 +370,11 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): else: # template xarray object has been provided with proper sizes and chunk shapes - template_indexes = set(template.indexes) - indexes = {dim: dataset.indexes[dim] for dim in dataset_indexes} - indexes.update({k: template.indexes[k] for k in template_indexes}) + indexes = dict(template.indexes) if isinstance(template, DataArray): output_chunks = dict(zip(template.dims, template.chunks)) # type: ignore else: - output_chunks = template.chunks # type: ignore + output_chunks = dict(template.chunks) for dim in output_chunks: if dim in input_chunks and len(input_chunks[dim]) != len(output_chunks[dim]): @@ -363,7 +403,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): graph: Dict[Any, Any] = {} new_layers: DefaultDict[str, Dict[Any, Any]] = collections.defaultdict(dict) gname = "{}-{}".format( - dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs) + dask.utils.funcname(func), dask.base.tokenize(npargs[0], args, kwargs) ) # map dims to list of chunk indexes @@ -376,9 +416,14 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items() } - # iterate over all possible chunk combinations - for v in itertools.product(*ichunk.values()): - chunk_index = dict(zip(dataset.dims, v)) + def subset_dataset_to_block( + graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index + ): + """ + Creates a task that subsets an xarray dataset to a block determined by chunk_index. + Block extents are determined by input_chunk_bounds. + Also subtasks that subset the constituent variables of a dataset. 
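Per-block subsetting is driven by cumulative chunk boundaries: for each dimension, ``np.cumsum((0,) + chunks)`` gives the integer bounds of every block, and block ``i`` spans ``bounds[i]:bounds[i + 1]``. A sketch of that mapping, with hypothetical names::

    import numpy as np

    chunks = {"x": (5, 5), "y": (4, 4, 4, 4, 4)}

    # Chunk i along a dimension covers bounds[i]:bounds[i + 1].
    bounds = {dim: np.cumsum((0,) + c) for dim, c in chunks.items()}

    def chunk_slices(chunk_index):
        """Integer slices selecting one block, given a chunk index per dimension."""
        return {
            dim: slice(bounds[dim][i], bounds[dim][i + 1])
            for dim, i in chunk_index.items()
        }

    print(chunk_slices({"x": 1, "y": 2}))
    # e.g. {'x': slice(5, 10), 'y': slice(8, 12)}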
+ """ # this will become [[name1, variable1], # [name2, variable2], @@ -387,6 +432,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): data_vars = [] coords = [] + chunk_tuple = tuple(chunk_index.values()) for name, variable in dataset.variables.items(): # make a task that creates tuple of (dims, chunk) if dask.is_dask_collection(variable.data): @@ -395,13 +441,13 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): for dim in variable.dims: chunk = chunk[chunk_index[dim]] - chunk_variable_task = (f"{gname}-{name}-{chunk[0]}",) + v + chunk_variable_task = (f"{gname}-{name}-{chunk[0]}",) + chunk_tuple graph[chunk_variable_task] = ( tuple, [variable.dims, chunk, variable.attrs], ) else: - # non-dask array with possibly chunked dimensions + # non-dask array possibly with dimensions chunked on other variables # index into variable appropriately subsetter = { dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds) @@ -410,7 +456,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): subset = variable.isel(subsetter) chunk_variable_task = ( "{}-{}".format(gname, dask.base.tokenize(subset)), - ) + v + ) + chunk_tuple graph[chunk_variable_task] = ( tuple, [subset.dims, subset, subset.attrs], @@ -422,7 +468,22 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): else: data_vars.append([name, chunk_variable_task]) - # expected["shapes", "coords", "data_vars", "indexes"] are used to raise nice error messages in _wrapper + return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs) + + # iterate over all possible chunk combinations + for chunk_tuple in itertools.product(*ichunk.values()): + # mapping from dimension name to chunk index + chunk_index = dict(zip(ichunk.keys(), chunk_tuple)) + + blocked_args = [ + subset_dataset_to_block(graph, gname, arg, input_chunk_bounds, chunk_index) + if isxr + else arg + for isxr, arg in zip(is_xarray, npargs) + ] + + # expected["shapes", "coords", "data_vars", "indexes"] are used to + # raise nice error messages in _wrapper expected = {} # input chunk 0 along a dimension maps to output chunk 0 along the same dimension # even if length of dimension is changed by the applied function @@ -436,16 +497,8 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): for dim in indexes } - from_wrapper = (gname,) + v - graph[from_wrapper] = ( - _wrapper, - func, - (Dataset, (dict, data_vars), (dict, coords), dataset.attrs), - input_is_array, - args, - kwargs, - expected, - ) + from_wrapper = (gname,) + chunk_tuple + graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected) # mapping from variable name to dask graph key var_key_map: Dict[Hashable, str] = {} @@ -472,7 +525,11 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): # layer. new_layers[gname_l][key] = (operator.getitem, from_wrapper, name) - hlg = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset]) + hlg = HighLevelGraph.from_collections( + gname, + graph, + dependencies=[arg for arg in npargs if dask.is_dask_collection(arg)], + ) for gname_l, layer in new_layers.items(): # This adds in the getitems for each variable in the dataset. 
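Because every xarray argument is subset and wired into the task graph, the arguments do not need to share all of ``obj``'s dimensions; each one is subset only along the dimensions it actually has. Two cases adapted from the tests added in this patch::

    import operator

    import numpy as np
    import xarray as xr

    da1 = xr.DataArray(
        np.ones((10, 20)),
        dims=["x", "y"],
        coords={"x": np.arange(10), "y": np.arange(20)},
    ).chunk({"x": 5, "y": 4})

    # Only "y" is shared: each "y" block of da2 meets every "x" block of da1.
    da2 = (da1 + 1).isel(x=1, drop=True)
    xr.testing.assert_equal(xr.map_blocks(operator.add, da1, args=[da2]), da1 + da2)

    # Entirely different dimension names are fine as well.
    da3 = da2.rename({"y": "k"})
    xr.testing.assert_equal(xr.map_blocks(operator.add, da1, args=[da3]), da1 + da3)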
@@ -480,6 +537,10 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): hlg.layers[gname_l] = layer result = Dataset(coords=indexes, attrs=template.attrs) + for index in result.indexes: + result[index].attrs = template[index].attrs + result[index].encoding = template[index].encoding + for name, gname_l in var_key_map.items(): dims = template[name].dims var_chunks = [] @@ -496,6 +557,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected): hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype ) result[name] = (dims, data, template[name].attrs) + result[name].encoding = template[name].encoding result = result.set_coords(template._coord_names) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 75beb3757ca..eb06336d296 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -972,6 +972,7 @@ def make_da(): coords={"x": np.arange(10), "y": np.arange(100, 120)}, name="a", ).chunk({"x": 4, "y": 5}) + da.x.attrs["long_name"] = "x" da.attrs["test"] = "test" da.coords["c2"] = 0.5 da.coords["ndcoord"] = da.x * 2 @@ -995,6 +996,9 @@ def make_ds(): map_ds.attrs["test"] = "test" map_ds.coords["xx"] = map_ds["a"] * map_ds.y + map_ds.x.attrs["long_name"] = "x" + map_ds.y.attrs["long_name"] = "y" + return map_ds @@ -1066,9 +1070,6 @@ def really_bad_func(darray): with raises_regex(ValueError, "inconsistent chunks"): xr.map_blocks(bad_func, ds_copy) - with raises_regex(TypeError, "Cannot pass dask collections"): - xr.map_blocks(bad_func, map_da, args=[map_da.chunk()]) - with raises_regex(TypeError, "Cannot pass dask collections"): xr.map_blocks(bad_func, map_da, kwargs=dict(a=map_da.chunk())) @@ -1095,6 +1096,58 @@ def test_map_blocks_convert_args_to_list(obj): assert_identical(actual, expected) +def test_map_blocks_dask_args(): + da1 = xr.DataArray( + np.ones((10, 20)), + dims=["x", "y"], + coords={"x": np.arange(10), "y": np.arange(20)}, + ).chunk({"x": 5, "y": 4}) + + # check that block shapes are the same + def sumda(da1, da2): + assert da1.shape == da2.shape + return da1 + da2 + + da2 = da1 + 1 + with raise_if_dask_computes(): + mapped = xr.map_blocks(sumda, da1, args=[da2]) + xr.testing.assert_equal(da1 + da2, mapped) + + # one dimension in common + da2 = (da1 + 1).isel(x=1, drop=True) + with raise_if_dask_computes(): + mapped = xr.map_blocks(operator.add, da1, args=[da2]) + xr.testing.assert_equal(da1 + da2, mapped) + + # test that everything works when dimension names are different + da2 = (da1 + 1).isel(x=1, drop=True).rename({"y": "k"}) + with raise_if_dask_computes(): + mapped = xr.map_blocks(operator.add, da1, args=[da2]) + xr.testing.assert_equal(da1 + da2, mapped) + + with raises_regex(ValueError, "Chunk sizes along dimension 'x'"): + xr.map_blocks(operator.add, da1, args=[da1.chunk({"x": 1})]) + + with raises_regex(ValueError, "indexes along dimension 'x' are not equal"): + xr.map_blocks(operator.add, da1, args=[da1.reindex(x=np.arange(20))]) + + # reduction + da1 = da1.chunk({"x": -1}) + da2 = da1 + 1 + with raise_if_dask_computes(): + mapped = xr.map_blocks(lambda a, b: (a + b).sum("x"), da1, args=[da2]) + xr.testing.assert_equal((da1 + da2).sum("x"), mapped) + + # reduction with template + da1 = da1.chunk({"x": -1}) + da2 = da1 + 1 + with raise_if_dask_computes(): + mapped = xr.map_blocks( + lambda a, b: (a + b).sum("x"), da1, args=[da2], template=da1.sum("x") + ) + xr.testing.assert_equal((da1 + da2).sum("x"), mapped) + + @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_add_attrs(obj): def 
add_attrs(obj):