diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 75a273bfdb4..57d29292f72 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -135,6 +135,8 @@ Bug fixes
 - Fix error that arises when using open_mfdataset on a series of netcdf files having
   differing values for a variable attribute of type list. (:issue:`3034`)
   By `Hasan Ahmad `_.
+- Prevent :py:meth:`~xarray.DataArray.argmax` and :py:meth:`~xarray.DataArray.argmin` from calling
+  dask compute (:issue:`3237`). By `Ulrich Herter `_.
 
 .. _whats-new.0.12.3:
 
diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py
index 9ba4eae29ae..17240faf007 100644
--- a/xarray/core/nanops.py
+++ b/xarray/core/nanops.py
@@ -88,38 +88,21 @@ def nanmax(a, axis=None, out=None):
 
 
 def nanargmin(a, axis=None):
-    fill_value = dtypes.get_pos_infinity(a.dtype)
     if a.dtype.kind == "O":
+        fill_value = dtypes.get_pos_infinity(a.dtype)
         return _nan_argminmax_object("argmin", fill_value, a, axis=axis)
 
-    a, mask = _replace_nan(a, fill_value)
-    if isinstance(a, dask_array_type):
-        res = dask_array.argmin(a, axis=axis)
-    else:
-        res = np.argmin(a, axis=axis)
-    if mask is not None:
-        mask = mask.all(axis=axis)
-        if mask.any():
-            raise ValueError("All-NaN slice encountered")
-    return res
+    module = dask_array if isinstance(a, dask_array_type) else nputils
+    return module.nanargmin(a, axis=axis)
 
 
 def nanargmax(a, axis=None):
-    fill_value = dtypes.get_neg_infinity(a.dtype)
     if a.dtype.kind == "O":
+        fill_value = dtypes.get_neg_infinity(a.dtype)
         return _nan_argminmax_object("argmax", fill_value, a, axis=axis)
 
-    a, mask = _replace_nan(a, fill_value)
-    if isinstance(a, dask_array_type):
-        res = dask_array.argmax(a, axis=axis)
-    else:
-        res = np.argmax(a, axis=axis)
-
-    if mask is not None:
-        mask = mask.all(axis=axis)
-        if mask.any():
-            raise ValueError("All-NaN slice encountered")
-    return res
+    module = dask_array if isinstance(a, dask_array_type) else nputils
+    return module.nanargmax(a, axis=axis)
 
 
 def nansum(a, axis=None, dtype=None, out=None, min_count=None):
diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py
index 769af03fe6a..df36c98f94c 100644
--- a/xarray/core/nputils.py
+++ b/xarray/core/nputils.py
@@ -237,3 +237,5 @@ def f(values, axis=None, **kwargs):
 nanprod = _create_bottleneck_method("nanprod")
 nancumsum = _create_bottleneck_method("nancumsum")
 nancumprod = _create_bottleneck_method("nancumprod")
+nanargmin = _create_bottleneck_method("nanargmin")
+nanargmax = _create_bottleneck_method("nanargmax")
diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py
index e3fc6f65e0f..d105765481e 100644
--- a/xarray/tests/test_dask.py
+++ b/xarray/tests/test_dask.py
@@ -27,14 +27,49 @@
 dd = pytest.importorskip("dask.dataframe")
 
 
+class CountingScheduler:
+    """ Simple dask scheduler counting the number of computes.
+
+    Reference: https://stackoverflow.com/questions/53289286/ """
+
+    def __init__(self, max_computes=0):
+        self.total_computes = 0
+        self.max_computes = max_computes
+
+    def __call__(self, dsk, keys, **kwargs):
+        self.total_computes += 1
+        if self.total_computes > self.max_computes:
+            raise RuntimeError(
+                "Too many computes. Total: %d > max: %d."
+                % (self.total_computes, self.max_computes)
+            )
+        return dask.get(dsk, keys, **kwargs)
+
+
+def _set_dask_scheduler(scheduler=dask.get):
+    """ Backwards compatible way of setting scheduler.
+    """
+    if LooseVersion(dask.__version__) >= LooseVersion("0.18.0"):
+        return dask.config.set(scheduler=scheduler)
+    return dask.set_options(get=scheduler)
+
+
+def raise_if_dask_computes(max_computes=0):
+    scheduler = CountingScheduler(max_computes)
+    return _set_dask_scheduler(scheduler)
+
+
+def test_raise_if_dask_computes():
+    data = da.from_array(np.random.RandomState(0).randn(4, 6), chunks=(2, 2))
+    with raises_regex(RuntimeError, "Too many computes"):
+        with raise_if_dask_computes():
+            data.compute()
+
+
 class DaskTestCase:
     def assertLazyAnd(self, expected, actual, test):
-
-        with (
-            dask.config.set(scheduler="single-threaded")
-            if LooseVersion(dask.__version__) >= LooseVersion("0.18.0")
-            else dask.set_options(get=dask.get)
-        ):
+        with _set_dask_scheduler(dask.get):
+            # dask.get is the synchronous scheduler, which also gets set by
+            # dask.config.set(scheduler="synchronous") in current versions.
             test(actual, expected)
 
         if isinstance(actual, Dataset):
@@ -174,7 +209,12 @@ def test_reduce(self):
         v = self.lazy_var
         self.assertLazyAndAllClose(u.mean(), v.mean())
         self.assertLazyAndAllClose(u.std(), v.std())
-        self.assertLazyAndAllClose(u.argmax(dim="x"), v.argmax(dim="x"))
+        with raise_if_dask_computes():
+            actual = v.argmax(dim="x")
+        self.assertLazyAndAllClose(u.argmax(dim="x"), actual)
+        with raise_if_dask_computes():
+            actual = v.argmin(dim="x")
+        self.assertLazyAndAllClose(u.argmin(dim="x"), actual)
         self.assertLazyAndAllClose((u > 1).any(), (v > 1).any())
         self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x"))
         with raises_regex(NotImplementedError, "dask"):