Skip to content
forked from pydata/xarray

Commit

Permalink
Make tests functional by call compute before assert_equal
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Oct 3, 2019
1 parent f6dfb12 commit c73eda1
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 9 deletions.
18 changes: 15 additions & 3 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,13 @@ def _wrapper(func, obj, to_array, args, kwargs):
return func(obj, *args, **kwargs)

if isinstance(obj, DataArray):
dataset = obj._to_temp_dataset()
# only using _to_temp_dataset would break
# func = lambda x: x.to_dataset()
# since that relies on preserving name.
if obj.name is None:
dataset = obj._to_temp_dataset()
else:
dataset = obj.to_dataset()
input_is_array = True
else:
dataset = obj
Expand All @@ -205,7 +211,13 @@ def _wrapper(func, obj, to_array, args, kwargs):
"func output must be DataArray or Dataset; got %s" % type(template)
)

indexes = {dim: dataset.indexes[dim] for dim in template.dims}
template_indexes = set(template.indexes)
dataset_indexes = set(dataset.indexes)
preserved_indexes = template_indexes.intersection(dataset_indexes)
new_indexes = set(template_indexes) - set(dataset_indexes)
indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes}
indexes.update({k: template.indexes[k] for k in new_indexes})

graph = {} # type: Dict[Any, Any]
gname = "%s-%s" % (dask.utils.funcname(func), dask.base.tokenize(dataset))

Expand Down Expand Up @@ -287,7 +299,7 @@ def _wrapper(func, obj, to_array, args, kwargs):

graph[key] = (operator.getitem, from_wrapper, name)

graph = HighLevelGraph.from_collections(name, graph, dependencies=[dataset])
graph = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset])

result = Dataset(coords=indexes)
for name, gname_l in var_key_map.items():
Expand Down
17 changes: 11 additions & 6 deletions xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,7 +1022,7 @@ def func(obj):
actual = xr.map_blocks(func, obj)
expected = func(obj)
assert_chunks_equal(expected.chunk(), actual)
xr.testing.assert_equal(actual, expected)
xr.testing.assert_equal(actual.compute(), expected.compute())


@pytest.mark.parametrize("obj", [map_da, map_ds])
Expand All @@ -1031,7 +1031,7 @@ def test_map_blocks_convert_args_to_list(obj):
with raise_if_dask_computes():
actual = xr.map_blocks(operator.add, obj, [10])
assert_chunks_equal(expected.chunk(), actual)
xr.testing.assert_equal(actual, expected)
xr.testing.assert_equal(actual.compute(), expected.compute())


@pytest.mark.parametrize("obj", [map_da, map_ds])
Expand All @@ -1040,27 +1040,32 @@ def test_map_blocks_kwargs(obj):
with raise_if_dask_computes():
actual = xr.map_blocks(xr.full_like, obj, kwargs=dict(fill_value=np.nan))
assert_chunks_equal(expected.chunk(), actual)
xr.testing.assert_equal(actual, expected)
xr.testing.assert_equal(actual.compute(), expected.compute())


@pytest.mark.parametrize(
"func, obj",
[
[lambda x: x.to_dataset(), map_da],
[lambda x: x.to_array(), map_ds],
[lambda x: x.drop("cxy"), map_ds],
[lambda x: x.drop("a"), map_ds],
[lambda x: x.drop("x"), map_da],
[lambda x: x.drop("x"), map_ds],
[lambda x: x.expand_dims(k=[1, 2, 3]), map_ds],
[lambda x: x.expand_dims(k=[1, 2, 3]), map_da],
[lambda x: x.isel(x=1), map_ds],
[lambda x: x.isel(x=1).drop("x"), map_da],
# TODO: [lambda x: x.isel(x=1), map_ds],
# TODO: [lambda x: x.isel(x=1).drop("x"), map_da],
[lambda x: x.assign_coords(new_coord=("y", x.y * 2)), map_da],
[lambda x: x.astype(np.int32), map_da],
[lambda x: x.rename({"a": "new1", "b": "new2"}), map_ds],
],
)
def test_map_blocks_transformations(func, obj):
with raise_if_dask_computes():
assert_equal(xr.map_blocks(func, obj), func(obj))
actual = xr.map_blocks(func, obj)

assert_equal(actual.compute(), func(obj).compute())


@pytest.mark.parametrize("obj", [map_da, map_ds])
Expand Down

0 comments on commit c73eda1

Please sign in to comment.