diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 5f3e855f8b1..2cb626cfd8c 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -189,6 +189,13 @@ Bug fixes
   ``rtol`` arguments when called on ``DataArray`` objects.
   By `Stephan Hoyer <https://github.com/shoyer>`_.
 
+- :py:func:`~xarray.concat` was computing variables that aren't in memory
+  (e.g. dask-based) multiple times; :py:func:`~xarray.open_mfdataset`
+  was loading them multiple times from disk. Now, both functions will instead
+  load them at most once and, if they do, store them in memory in the
+  concatenated array/dataset (:issue:`1521`).
+  By `Guido Imperiale <https://github.com/crusaderky>`_.
+
 - xarray ``quantile`` methods now properly raise a ``TypeError`` when applied
   to objects with data stored as ``dask`` arrays (:issue:`1529`).
   By `Joe Hamman <https://github.com/jhamman>`_.
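Not part of the patch: a minimal sketch of the user-facing behaviour described in the
changelog entry above, assuming the xarray version this patch targets. The dataset
contents, the variable name 't' and the new dimension name 'y' are made up for
illustration.

    # Illustrative sketch only; it mirrors the assertions in the new test below.
    import dask.array as da
    import numpy as np
    import xarray as xr

    # Two datasets whose only variable is dask-backed, i.e. not in memory.
    ds1 = xr.Dataset({'t': ('x', da.zeros(3, chunks=3))})
    ds2 = xr.Dataset({'t': ('x', da.zeros(3, chunks=3))})

    # data_vars='different' forces concat() to compare 't' across datasets.
    # Each lazy variable is now computed at most once, and the values loaded
    # during the comparison are kept in the result instead of being discarded.
    out = xr.concat([ds1, ds2], dim='y', data_vars='different')
    assert isinstance(out['t'].data, np.ndarray)  # loaded during the comparison
    assert isinstance(ds1['t'].data, da.Array)    # the inputs are left untouched
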
diff --git a/xarray/core/combine.py b/xarray/core/combine.py
index 04b46a6624b..007b9640e20 100644
--- a/xarray/core/combine.py
+++ b/xarray/core/combine.py
@@ -148,68 +148,85 @@ def _calc_concat_over(datasets, dim, data_vars, coords):
     Determine which dataset variables need to be concatenated in the result,
     and which can simply be taken from the first dataset.
     """
-    def process_subset_opt(opt, subset):
-        if subset == 'coords':
-            subset_long_name = 'coordinates'
-        else:
-            subset_long_name = 'data variables'
+    # Return values
+    concat_over = set()
+    equals = {}
+
+    if dim in datasets[0]:
+        concat_over.add(dim)
+    for ds in datasets:
+        concat_over.update(k for k, v in ds.variables.items()
+                           if dim in v.dims)
 
+    def process_subset_opt(opt, subset):
         if isinstance(opt, basestring):
             if opt == 'different':
-                def differs(vname):
-                    # simple helper function which compares a variable
-                    # across all datasets and indicates whether that
-                    # variable differs or not.
-                    v = datasets[0].variables[vname]
-                    return any(not ds.variables[vname].equals(v)
-                               for ds in datasets[1:])
                 # all nonindexes that are not the same in each dataset
-                concat_new = set(k for k in getattr(datasets[0], subset)
-                                 if k not in concat_over and differs(k))
+                for k in getattr(datasets[0], subset):
+                    if k not in concat_over:
+                        # Compare the variable of all datasets vs. the one
+                        # of the first dataset. Perform the minimum amount of
+                        # loads in order to avoid multiple loads from disk while
+                        # keeping the RAM footprint low.
+                        v_lhs = datasets[0].variables[k].load()
+                        # We'll need to know later on if variables are equal.
+                        computed = []
+                        for ds_rhs in datasets[1:]:
+                            v_rhs = ds_rhs.variables[k].compute()
+                            computed.append(v_rhs)
+                            if not v_lhs.equals(v_rhs):
+                                concat_over.add(k)
+                                equals[k] = False
+                                # computed variables are not to be re-computed
+                                # again in the future
+                                for ds, v in zip(datasets[1:], computed):
+                                    ds.variables[k].data = v.data
+                                break
+                        else:
+                            equals[k] = True
+
             elif opt == 'all':
-                concat_new = (set(getattr(datasets[0], subset)) -
-                              set(datasets[0].dims))
+                concat_over.update(set(getattr(datasets[0], subset)) -
+                                   set(datasets[0].dims))
             elif opt == 'minimal':
-                concat_new = set()
+                pass
             else:
-                raise ValueError("unexpected value for concat_%s: %s"
-                                 % (subset, opt))
+                raise ValueError("unexpected value for %s: %s" % (subset, opt))
         else:
             invalid_vars = [k for k in opt
                             if k not in getattr(datasets[0], subset)]
             if invalid_vars:
-                raise ValueError('some variables in %s are not '
-                                 '%s on the first dataset: %s'
-                                 % (subset, subset_long_name, invalid_vars))
-            concat_new = set(opt)
-        return concat_new
+                if subset == 'coords':
+                    raise ValueError(
+                        'some variables in coords are not coordinates on '
+                        'the first dataset: %s' % invalid_vars)
+                else:
+                    raise ValueError(
+                        'some variables in data_vars are not data variables on '
+                        'the first dataset: %s' % invalid_vars)
+            concat_over.update(opt)
 
-    concat_over = set()
-    for ds in datasets:
-        concat_over.update(k for k, v in ds.variables.items()
-                           if dim in v.dims)
-    concat_over.update(process_subset_opt(data_vars, 'data_vars'))
-    concat_over.update(process_subset_opt(coords, 'coords'))
-    if dim in datasets[0]:
-        concat_over.add(dim)
-    return concat_over
+    process_subset_opt(data_vars, 'data_vars')
+    process_subset_opt(coords, 'coords')
+    return concat_over, equals
 
 
 def _dataset_concat(datasets, dim, data_vars, coords, compat, positions):
     """
     Concatenate a sequence of datasets along a new or existing dimension
     """
-    from .dataset import Dataset, as_dataset
+    from .dataset import Dataset
 
     if compat not in ['equals', 'identical']:
         raise ValueError("compat=%r invalid: must be 'equals' "
                          "or 'identical'" % compat)
 
     dim, coord = _calc_concat_dim_coord(dim)
-    datasets = [as_dataset(ds) for ds in datasets]
+    # Make sure we're working on a copy (we'll be loading variables)
+    datasets = [ds.copy() for ds in datasets]
     datasets = align(*datasets, join='outer', copy=False, exclude=[dim])
 
-    concat_over = _calc_concat_over(datasets, dim, data_vars, coords)
+    concat_over, equals = _calc_concat_over(datasets, dim, data_vars, coords)
 
     def insert_result_variable(k, v):
         assert isinstance(v, Variable)
@@ -239,11 +256,25 @@ def insert_result_variable(k, v):
             elif (k in result_coord_names) != (k in ds.coords):
                 raise ValueError('%r is a coordinate in some datasets but not '
                                  'others' % k)
-            elif (k in result_vars and k != dim and
-                  not getattr(v, compat)(result_vars[k])):
-                verb = 'equal' if compat == 'equals' else compat
-                raise ValueError(
-                    'variable %r not %s across datasets' % (k, verb))
+            elif k in result_vars and k != dim:
+                # Don't use Variable.identical as it internally invokes
+                # Variable.equals, and we may already know the answer
+                if compat == 'identical' and not utils.dict_equiv(
+                        v.attrs, result_vars[k].attrs):
+                    raise ValueError(
+                        'variable %s not identical across datasets' % k)
+
+                # Proceed with equals()
+                try:
+                    # May be populated when using the "different" method
+                    is_equal = equals[k]
+                except KeyError:
+                    result_vars[k].load()
+                    is_equal = v.equals(result_vars[k])
+                if not is_equal:
+                    raise ValueError(
+                        'variable %s not equal across datasets' % k)
+
 
     # we've already verified everything is consistent; now, calculate
     # shared dimension sizes so we can expand the necessary variables
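The heart of the change above is the compare-at-most-once loop in _calc_concat_over.
Below is a standalone, simplified restatement of that strategy; differs_minimal_loads()
is a hypothetical helper, not an xarray API, and it operates on a plain list of
xarray.Variable objects rather than on the real datasets[i].variables[k] mapping.

    # Hypothetical, simplified restatement of the comparison strategy used in
    # _calc_concat_over above; differs_minimal_loads() is not an xarray API.
    import dask.array as da
    import numpy as np
    import xarray as xr

    def differs_minimal_loads(variables):
        # Return True if any variable differs from the first one, computing
        # each lazy variable at most once.
        lhs = variables[0].load()       # load the reference variable in place
        computed = []
        for rhs in variables[1:]:
            rhs_loaded = rhs.compute()  # compute() returns a loaded copy
            computed.append(rhs_loaded)
            if not lhs.equals(rhs_loaded):
                # A difference means the variables will be concatenated anyway,
                # so keep the already-computed data to avoid a second load,
                # then stop early: comparing the rest brings no benefit.
                for var, loaded in zip(variables[1:], computed):
                    var.data = loaded.data
                return True
        # In the real code this is the for/else branch: everything was equal.
        return False

    v1 = xr.Variable('x', da.zeros(3, chunks=3))
    v2 = xr.Variable('x', da.zeros(3, chunks=3))
    assert not differs_minimal_loads([v1, v2])
    assert isinstance(v1.data, np.ndarray)  # loaded at most once, in place
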
+        out.compute()
+        assert kernel_call_count == 24
+
+        # Finally, test that the originals are unaltered
+        assert ds1['d'].data is d1
+        assert ds1['c'].data is c1
+        assert ds2['d'].data is d2
+        assert ds2['c'].data is c2
+        assert ds3['d'].data is d3
+        assert ds3['c'].data is c3
+
     def test_groupby(self):
         if LooseVersion(dask.__version__) == LooseVersion('0.15.3'):
             pytest.xfail('upstream bug in dask: '
@@ -517,10 +588,11 @@ def test_dask_kwargs_dataset(method):
 
 
 kernel_call_count = 0
-def kernel():
+def kernel(name):
     """Dask kernel to test pickling/unpickling and __repr__.
     Must be global to make it pickleable.
     """
+    print("kernel(%s)" % name)
     global kernel_call_count
     kernel_call_count += 1
     return np.ones(1, dtype=np.int64)
@@ -530,5 +602,5 @@ def build_dask_array(name):
     global kernel_call_count
     kernel_call_count = 0
     return dask.array.Array(
-        dask={(name, 0): (kernel, )}, name=name,
+        dask={(name, 0): (kernel, name)}, name=name,
         chunks=((1,),), dtype=np.int64)
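A usage note on the test helpers at the end of the diff: the hand-built graph maps the
single chunk key (name, 0) to the task (kernel, name), i.e. "call kernel(name)", and
dask does not cache results between computes, so kernel_call_count measures how many
times data was actually computed. A hypothetical standalone demonstration outside the
test suite:

    # Hypothetical standalone demonstration, not part of the test suite.
    import dask.array
    import numpy as np

    kernel_call_count = 0

    def kernel(name):
        # Count how many times any hand-built array is actually computed.
        global kernel_call_count
        kernel_call_count += 1
        return np.ones(1, dtype=np.int64)

    def build_dask_array(name):
        # One-chunk array whose graph maps (name, 0) -> (kernel, name).
        return dask.array.Array(
            dask={(name, 0): (kernel, name)}, name=name,
            chunks=((1,),), dtype=np.int64)

    a = build_dask_array('a')
    a.compute()
    a.compute()
    assert kernel_call_count == 2  # one kernel call per compute()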