Load nonindex coords ahead of concat() #1551

Merged: 13 commits, Oct 9, 2017
10 changes: 8 additions & 2 deletions doc/whats-new.rst
@@ -116,7 +116,13 @@ Bug fixes
``rtol`` arguments when called on ``DataArray`` objects.
By `Stephan Hoyer <https://github.com/shoyer>`_.

- :py:func:`~xarray.concat` was computing non-index coordinates that are not
  in memory (e.g. dask-based) multiple times, and :py:func:`~xarray.open_mfdataset`
  was loading them from disk multiple times. Both functions now load such
  coordinates once and store them as numpy arrays (:issue:`1521`).
  By `Guido Imperiale <https://github.com/crusaderky>`_.

- xarray ``quantile`` methods now properly raise a ``TypeError`` when applied to
objects with data stored as ``dask`` arrays (:issue:`1529`).
By `Joe Hamman <https://github.com/jhamman>`_.

@@ -2032,4 +2038,4 @@ Miles.
v0.1 (2 May 2014)
-----------------

Initial release.
13 changes: 13 additions & 0 deletions xarray/core/combine.py
@@ -195,6 +195,17 @@ def differs(vname):
    return concat_over


def _load_coords(dataset):
    """Load into memory any non-index coords. Preserve original."""
    if all(coord._in_memory for coord in dataset.coords.values()):
        return dataset
    dataset = dataset.copy()
    for coord in dataset.coords.values():
        coord.load()
    return dataset


def _dataset_concat(datasets, dim, data_vars, coords, compat, positions):
    """
    Concatenate a sequence of datasets along a new or existing dimension
@@ -208,6 +219,8 @@ def _dataset_concat(datasets, dim, data_vars, coords, compat, positions):
    dim, coord = _calc_concat_dim_coord(dim)
    datasets = [as_dataset(ds) for ds in datasets]
    datasets = align(*datasets, join='outer', copy=False, exclude=[dim])
    # TODO: compute dask coords with a single invocation of dask.compute()
@shoyer (Member), Sep 4, 2017: Yes, this would be even better :).

@crusaderky (Contributor, PR author): I didn't bother, as it would be a considerable complication, and it would not help in the most common case, where the "dask" variable is just a lazy load from disk.
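The batching suggested in the TODO could look like this minimal sketch with plain dask (not this PR's code; the array names are illustrative):

```python
import dask
import dask.array as da

# Two lazy "coordinate" arrays that would otherwise be evaluated
# with two separate .compute() calls.
y1 = da.ones(3, chunks=3)
y2 = da.zeros(3, chunks=3)

# A single dask.compute() call merges both graphs into one scheduler
# pass, sharing any common intermediate results between them.
y1_np, y2_np = dask.compute(y1, y2)
```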

    datasets = [_load_coords(ds) for ds in datasets]
@shoyer (Member), Sep 4, 2017: We only need this if the argument coords='different' is set. Otherwise we don't need to load coords to determine whether they are equal.


    concat_over = _calc_concat_over(datasets, dim, data_vars, coords)

16 changes: 16 additions & 0 deletions xarray/tests/test_dask.py
@@ -228,6 +228,22 @@ def test_lazy_array(self):
        actual = xr.concat([v[:2], v[2:]], 'x')
        self.assertLazyAndAllClose(u, actual)

    def test_concat_loads_coords(self):
        # Test that concat() computes dask-based, non-index
        # coordinates exactly once and loads them in the output,
        # while leaving the input unaltered.
        y = build_dask_array()
        ds1 = Dataset(coords={'x': [1], 'y': ('x', y)})
        ds2 = Dataset(coords={'x': [1], 'y': ('x', [2.0])})
        assert kernel_call_count == 0
        ds3 = xr.concat([ds1, ds2], dim='z')
        # Note: #1532 fixed a bug where getattr('to_dataset')
        # caused non-index coords to be computed.
        assert kernel_call_count == 2
        assert ds1['y'].data is y
        assert isinstance(ds3['y'].data, np.ndarray)
        assert ds3['y'].values.tolist() == [[1.0], [2.0]]

    def test_groupby(self):
        u = self.eager_array
        v = self.lazy_array