diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4b5bb1e491f..2cd4284f3cb 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -70,8 +70,8 @@ New Features :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:issue:`60`, :pull:`3871`) By `Todd Jennings `_ - Support dask handling for :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`, - :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:pull:`3922`) - By `Kai Mühlbauer `_. + :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:pull:`3922`, :pull:`4135`) + By `Kai Mühlbauer `_ and `Pascal Bourgault `_. - More support for unit aware arrays with pint (:pull:`3643`, :pull:`3975`) By `Justus Magin `_. - Support overriding existing variables in ``to_zarr()`` with ``mode='a'`` even diff --git a/xarray/core/computation.py b/xarray/core/computation.py index cecd4fd8e70..4f4fd475c82 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1563,7 +1563,7 @@ def _calc_idxminmax( chunks = dict(zip(array.dims, array.chunks)) dask_coord = dask.array.from_array(array[dim].data, chunks=chunks[dim]) - res = indx.copy(data=dask_coord[(indx.data,)]) + res = indx.copy(data=dask_coord[indx.data.ravel()].reshape(indx.shape)) # we need to attach back the dim name res.name = dim else: diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 8fc37ac458d..d942667a4c7 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -5257,6 +5257,25 @@ def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): assert_identical(result7, expected7) +class TestReduceND(TestReduce): + @pytest.mark.parametrize("op", ["idxmin", "idxmax"]) + @pytest.mark.parametrize("ndim", [3, 5]) + def test_idxminmax_dask(self, op, ndim): + if not has_dask: + pytest.skip("requires dask") + + ar0_raw = xr.DataArray( + np.random.random_sample(size=[10] * ndim), + dims=[i for i in "abcdefghij"[: ndim - 1]] + ["x"], + coords={"x": np.arange(10)}, + attrs=self.attrs, + ) + + ar0_dsk = ar0_raw.chunk({}) + # Assert idx is the same with dask and without + assert_equal(getattr(ar0_dsk, op)(dim="x"), getattr(ar0_raw, op)(dim="x")) + + @pytest.fixture(params=[1]) def da(request): if request.param == 1: