map_blocks: Allow passing dask-backed objects in args #3818

Merged
merged 23 commits into from Jun 7, 2020
Changes from 17 commits
2 changes: 2 additions & 0 deletions doc/whats-new.rst
@@ -70,6 +70,8 @@ New Features
- :py:meth:`map_blocks` now accepts a ``template`` kwarg. This allows use in cases
where the result of a computation cannot be inferred automatically.
By `Deepak Cherian <https://github.com/dcherian>`_
- :py:meth:`map_blocks` can now handle dask-backed xarray objects in ``args``. (:pull:`3818`)
By `Deepak Cherian <https://github.com/dcherian>`_

Bug fixes
~~~~~~~~~
5 changes: 2 additions & 3 deletions xarray/core/dataarray.py
@@ -3270,9 +3270,8 @@ def map_blocks(

This function cannot add a new chunked dimension.
args: Sequence
Passed verbatim to func after unpacking, after the sliced DataArray. xarray
objects, if any, will not be split by chunks. Passing dask collections is
not allowed.
Passed verbatim to func after unpacking, after the sliced obj.
Any xarray objects will also be split by blocks and then passed on.
kwargs: Mapping
Passed verbatim to func after unpacking. xarray objects, if any, will not be
split by chunks. Passing dask collections is not allowed.
5 changes: 2 additions & 3 deletions xarray/core/dataset.py
@@ -5729,9 +5729,8 @@ def map_blocks(

This function cannot add a new chunked dimension.
args: Sequence
Passed verbatim to func after unpacking, after the sliced DataArray. xarray
objects, if any, will not be split by chunks. Passing dask collections is
not allowed.
Passed verbatim to func after unpacking, after the sliced obj.
Any xarray objects will also be split by blocks and then passed on.
kwargs: Mapping
Passed verbatim to func after unpacking. xarray objects, if any, will not be
split by chunks. Passing dask collections is not allowed.
174 changes: 122 additions & 52 deletions xarray/core/parallel.py
@@ -16,6 +16,8 @@
DefaultDict,
Dict,
Hashable,
Iterable,
List,
Mapping,
Sequence,
Tuple,
@@ -25,12 +27,29 @@

import numpy as np

from .alignment import align
from .dataarray import DataArray
from .dataset import Dataset

T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)


def to_object_array(iterable):
npargs = np.empty((len(iterable),), dtype=np.object)
for idx, item in enumerate(iterable):
npargs[idx] = item
return npargs


def assert_chunks_compatible(a: Dataset, b: Dataset):
a = a.unify_chunks()
b = b.unify_chunks()

for dim in set(a.chunks).intersection(set(b.chunks)):
if a.chunks[dim] != b.chunks[dim]:
raise ValueError(f"Chunk sizes along dimension {dim!r} are not equal.")
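As context, a minimal sketch (hypothetical data, not part of the diff) of what this guard catches:

import numpy as np
import xarray as xr

# Two datasets chunked differently along the shared dimension "x"
a = xr.Dataset({"u": ("x", np.arange(6))}).chunk({"x": 2})
b = xr.Dataset({"v": ("x", np.arange(6))}).chunk({"x": 3})
assert_chunks_compatible(a, b)  # raises ValueError: chunks along "x" differ, (2, 2, 2) vs (3, 3)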


def check_result_variables(
result: Union[DataArray, Dataset], expected: Mapping[str, Any], kind: str
):
@@ -67,6 +86,17 @@ def dataset_to_dataarray(obj: Dataset) -> DataArray:
return next(iter(obj.data_vars.values()))


def dataarray_to_dataset(obj: DataArray) -> Dataset:
# only using _to_temp_dataset would break
# func = lambda x: x.to_dataset()
# since that relies on preserving name.
if obj.name is None:
dataset = obj._to_temp_dataset()
else:
dataset = obj.to_dataset()
return dataset
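A quick illustration of the name-preservation concern noted in the comment above (illustrative only):

import xarray as xr

named = xr.DataArray([1, 2], dims="x", name="temp")
named.to_dataset()  # variable is stored as "temp"; the name must survive the round-trip

unnamed = xr.DataArray([1, 2], dims="x")
# unnamed.to_dataset() raises because there is no name;
# _to_temp_dataset() assigns an internal placeholder name instead.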


def make_meta(obj):
"""If obj is a DataArray or Dataset, return a new object of the same type and with
the same variables and dtypes, but where all variables have size 0 and numpy
@@ -161,8 +191,8 @@ def map_blocks(
obj: DataArray, Dataset
Passed to the function as its first argument, one dask chunk at a time.
args: Sequence
Passed verbatim to func after unpacking, after the sliced obj. xarray objects,
if any, will not be split by chunks. Passing dask collections is not allowed.
Passed verbatim to func after unpacking, after the sliced obj.
Any xarray objects will also be split by blocks and then passed on.
kwargs: Mapping
Passed verbatim to func after unpacking. xarray objects, if any, will not be
split by chunks. Passing dask collections is not allowed.
@@ -241,14 +271,27 @@ def map_blocks(
* time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00
"""

def _wrapper(func, obj, to_array, args, kwargs, expected):
check_shapes = dict(obj.dims)
def _wrapper(
func: Callable,
args: List,
kwargs: dict,
arg_is_array: Iterable[bool],
expected: dict,
):
"""
Wrapper function that receives datasets in args; converts to dataarrays when necessary;
passes these to the user function `func` and checks returned objects for expected shapes/sizes/etc.
"""

check_shapes = dict(args[0].dims)
check_shapes.update(expected["shapes"])

if to_array:
obj = dataset_to_dataarray(obj)
converted_args = [
dataset_to_dataarray(arg) if is_array else arg
for is_array, arg in zip(arg_is_array, args)
]

result = func(obj, *args, **kwargs)
result = func(*converted_args, **kwargs)

# check all dims are present
missing_dimensions = set(expected["shapes"]) - set(result.sizes)
@@ -289,52 +332,57 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
elif not isinstance(kwargs, Mapping):
raise TypeError("kwargs must be a mapping (for example, a dict)")

for value in list(args) + list(kwargs.values()):
for value in kwargs.values():
if dask.is_dask_collection(value):
raise TypeError(
"Cannot pass dask collections in args or kwargs yet. Please compute or "
"Cannot pass dask collections in kwargs yet. Please compute or "
"load values before passing to map_blocks."
)

if not dask.is_dask_collection(obj):
return func(obj, *args, **kwargs)

if isinstance(obj, DataArray):
# only using _to_temp_dataset would break
# func = lambda x: x.to_dataset()
# since that relies on preserving name.
if obj.name is None:
dataset = obj._to_temp_dataset()
else:
dataset = obj.to_dataset()
input_is_array = True
else:
dataset = obj
input_is_array = False
npargs = to_object_array([obj] + list(args))
@dcherian (Contributor, Author): converting to object array so that we can use boolean indexing to pull out xarray objects
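A small sketch of the trick being described (illustrative, assuming xarray is installed):

import numpy as np
import xarray as xr

items = [xr.Dataset({"a": ("x", [1, 2])}), 42, xr.DataArray([3, 4], dims="x")]
npargs = np.empty((len(items),), dtype=object)
for i, item in enumerate(items):
    npargs[i] = item

is_xarray = [isinstance(a, (xr.Dataset, xr.DataArray)) for a in npargs]
npargs[is_xarray]  # boolean indexing pulls out only the two xarray objects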

is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in npargs]
is_array = [isinstance(arg, DataArray) for arg in npargs]

# align all xarray objects
# TODO: should we allow join as a kwarg or force everything to be aligned to the first object?
aligned = align(*npargs[is_xarray], join="left")
# assigning to object arrays works better when RHS is object array
# https://stackoverflow.com/questions/43645135/boolean-indexing-assignment-of-a-numpy-array-to-a-numpy-array
npargs[is_xarray] = to_object_array(aligned)
@dcherian (Mar 28, 2020): Is there a better way to do this assignment? np.array(args) ends up computing things.
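A sketch of why np.array is avoided here (illustrative; the exact coercion depends on the numpy version):

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(4), dims="x").chunk(2)  # dask-backed

# np.array([da, da]) would call __array__ on each element, computing the
# dask arrays and coercing the result into a single numeric array.
# Preallocating an object array sidesteps element coercion entirely:
arr = np.empty((2,), dtype=object)
arr[0] = da
arr[1] = da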

npargs[is_array] = to_object_array(
[dataarray_to_dataset(da) for da in npargs[is_array]]
)

# check that chunk sizes are compatible
input_chunks = dict(npargs[0].chunks)
input_indexes = dict(npargs[0].indexes)
for arg in npargs[1:][is_xarray[1:]]:
assert_chunks_compatible(npargs[0], arg)
input_chunks.update(arg.chunks)
input_indexes.update(arg.indexes)

input_chunks = dataset.chunks
dataset_indexes = set(dataset.indexes)
if template is None:
# infer template by providing zero-shaped arrays
template = infer_template(func, obj, *args, **kwargs)
template = infer_template(func, aligned[0], *args, **kwargs)
template_indexes = set(template.indexes)
preserved_indexes = template_indexes & dataset_indexes
new_indexes = template_indexes - dataset_indexes
indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes}
preserved_indexes = template_indexes & set(input_indexes)
new_indexes = template_indexes - set(input_indexes)
indexes = {dim: input_indexes[dim] for dim in preserved_indexes}
indexes.update({k: template.indexes[k] for k in new_indexes})
output_chunks = {
dim: input_chunks[dim] for dim in template.dims if dim in input_chunks
}

else:
# template xarray object has been provided with proper sizes and chunk shapes
template_indexes = set(template.indexes)
indexes = {dim: dataset.indexes[dim] for dim in dataset_indexes}
indexes.update({k: template.indexes[k] for k in template_indexes})
indexes = dict(template.indexes)
if isinstance(template, DataArray):
output_chunks = dict(zip(template.dims, template.chunks)) # type: ignore
else:
output_chunks = template.chunks # type: ignore
output_chunks = dict(template.chunks)

for dim in output_chunks:
if dim in input_chunks and len(input_chunks[dim]) != len(output_chunks[dim]):
@@ -363,7 +411,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
graph: Dict[Any, Any] = {}
new_layers: DefaultDict[str, Dict[Any, Any]] = collections.defaultdict(dict)
gname = "{}-{}".format(
dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs)
dask.utils.funcname(func), dask.base.tokenize(npargs[0], args, kwargs)
)

# map dims to list of chunk indexes
@@ -376,17 +424,23 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items()
}

# iterate over all possible chunk combinations
for v in itertools.product(*ichunk.values()):
chunk_index = dict(zip(dataset.dims, v))
def subset_dataset_to_block(
graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index
):
"""
Creates a task that creates a subsets xarray dataset to a block determined by chunk_index;
dcherian marked this conversation as resolved.
Show resolved Hide resolved
whose extents are determined by input_chunk_bounds.
There are subtasks that create subsets of constituent variables.
"""

# this will become [[name1, variable1],
# [name2, variable2],
# ...]
# [name2, variable2],
# ...]
# which is passed to dict and then to Dataset
data_vars = []
coords = []

chunk_tuple = tuple(chunk_index.values())
for name, variable in dataset.variables.items():
# make a task that creates tuple of (dims, chunk)
if dask.is_dask_collection(variable.data):
@@ -395,13 +449,13 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
for dim in variable.dims:
chunk = chunk[chunk_index[dim]]

chunk_variable_task = (f"{gname}-{name}-{chunk[0]}",) + v
chunk_variable_task = (f"{gname}-{name}-{chunk[0]}",) + chunk_tuple
graph[chunk_variable_task] = (
tuple,
[variable.dims, chunk, variable.attrs],
)
else:
# non-dask array with possibly chunked dimensions
# non-dask array possibly with dimensions chunked on other variables
# index into variable appropriately
subsetter = {
dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds)
Expand All @@ -410,7 +464,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
subset = variable.isel(subsetter)
chunk_variable_task = (
"{}-{}".format(gname, dask.base.tokenize(subset)),
) + v
) + chunk_tuple
graph[chunk_variable_task] = (
tuple,
[subset.dims, subset, subset.attrs],
@@ -422,7 +476,22 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
else:
data_vars.append([name, chunk_variable_task])

# expected["shapes", "coords", "data_vars", "indexes"] are used to raise nice error messages in _wrapper
return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs)
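For readers less familiar with dask graph conventions, the returned tuple is a task spec, not a constructed Dataset: dask evaluates nested tuples of the form (callable, *args) lazily at compute time. A tiny standalone sketch:

import dask

# (dict, [["a", 1], ["b", 2]]) is evaluated by dask as dict([["a", 1], ["b", 2]])
graph = {"pair": (dict, [["a", 1], ["b", 2]])}
dask.get(graph, "pair")  # {'a': 1, 'b': 2}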

# iterate over all possible chunk combinations
for chunk_tuple in itertools.product(*ichunk.values()):
# mapping from dimension name to chunk index
chunk_index = dict(zip(ichunk.keys(), chunk_tuple))

blocked_args = [
subset_dataset_to_block(graph, gname, arg, input_chunk_bounds, chunk_index)
if isxr
else arg
for isxr, arg in zip(is_xarray, npargs)
]

# expected["shapes", "coords", "data_vars", "indexes"] are used to
# raise nice error messages in _wrapper
expected = {}
# input chunk 0 along a dimension maps to output chunk 0 along the same dimension
# even if length of dimension is changed by the applied function
@@ -436,16 +505,8 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
for dim in indexes
}

from_wrapper = (gname,) + v
graph[from_wrapper] = (
_wrapper,
func,
(Dataset, (dict, data_vars), (dict, coords), dataset.attrs),
input_is_array,
args,
kwargs,
expected,
)
from_wrapper = (gname,) + chunk_tuple
graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected)

# mapping from variable name to dask graph key
var_key_map: Dict[Hashable, str] = {}
@@ -472,14 +533,22 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
# layer.
new_layers[gname_l][key] = (operator.getitem, from_wrapper, name)

hlg = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset])
hlg = HighLevelGraph.from_collections(
gname,
graph,
dependencies=[arg for arg in npargs if dask.is_dask_collection(arg)],
)

for gname_l, layer in new_layers.items():
# This adds in the getitems for each variable in the dataset.
hlg.dependencies[gname_l] = {gname}
hlg.layers[gname_l] = layer

result = Dataset(coords=indexes, attrs=template.attrs)
for index in result.indexes:
result[index].attrs = template[index].attrs
result[index].encoding = template[index].encoding

for name, gname_l in var_key_map.items():
dims = template[name].dims
var_chunks = []
@@ -496,6 +565,7 @@ def _wrapper(func, obj, to_array, args, kwargs, expected):
hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype
)
result[name] = (dims, data, template[name].attrs)
result[name].encoding = template[name].encoding

result = result.set_coords(template._coord_names)
