Skip to content

Commit

Permalink
Reimplement quantile with apply_ufunc (#3559)
Browse files Browse the repository at this point in the history
* Reimplement quantile with apply_ufunc

* Update xarray/core/variable.py

Co-Authored-By: Stephan Hoyer <shoyer@google.com>

* Update doc/whats-new.rst
  • Loading branch information
dcherian authored Nov 25, 2019
1 parent 8a148b6 commit 7dfdfca
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 69 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ Breaking changes

New Features
~~~~~~~~~~~~
- :py:meth:`Dataset.quantile`, :py:meth:`DataArray.quantile` and ``GroupBy.quantile``
now work with dask Variables.
By `Deepak Cherian <https://github.com/dcherian>`_.


Bug fixes
Expand Down
6 changes: 1 addition & 5 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5166,11 +5166,7 @@ def quantile(
new = self._replace_with_new_dims(
variables, coord_names=coord_names, attrs=attrs, indexes=indexes
)
if "quantile" in new.dims:
new.coords["quantile"] = Variable("quantile", q)
else:
new.coords["quantile"] = q
return new
return new.assign_coords(quantile=q)

def rank(self, dim, pct=False, keep_attrs=None):
"""Ranks the data.
Expand Down
63 changes: 34 additions & 29 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1716,40 +1716,45 @@ def quantile(self, q, dim=None, interpolation="linear", keep_attrs=None):
numpy.nanpercentile, pandas.Series.quantile, Dataset.quantile,
DataArray.quantile
"""
if isinstance(self.data, dask_array_type):
raise TypeError(
"quantile does not work for arrays stored as dask "
"arrays. Load the data via .compute() or .load() "
"prior to calling this method."
)

q = np.asarray(q, dtype=np.float64)

new_dims = list(self.dims)
if dim is not None:
axis = self.get_axis_num(dim)
if utils.is_scalar(dim):
new_dims.remove(dim)
else:
for d in dim:
new_dims.remove(d)
else:
axis = None
new_dims = []

# Only add the quantile dimension if q is array-like
if q.ndim != 0:
new_dims = ["quantile"] + new_dims

qs = np.nanpercentile(
self.data, q * 100.0, axis=axis, interpolation=interpolation
)
from .computation import apply_ufunc

if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)
attrs = self._attrs if keep_attrs else None

return Variable(new_dims, qs, attrs)
scalar = utils.is_scalar(q)
q = np.atleast_1d(np.asarray(q, dtype=np.float64))

if dim is None:
dim = self.dims

if utils.is_scalar(dim):
dim = [dim]

def _wrapper(npa, **kwargs):
# move quantile axis to end. required for apply_ufunc
return np.moveaxis(np.nanpercentile(npa, **kwargs), 0, -1)

axis = np.arange(-1, -1 * len(dim) - 1, -1)
result = apply_ufunc(
_wrapper,
self,
input_core_dims=[dim],
exclude_dims=set(dim),
output_core_dims=[["quantile"]],
output_dtypes=[np.float64],
output_sizes={"quantile": len(q)},
dask="parallelized",
kwargs={"q": q * 100, "axis": axis, "interpolation": interpolation},
)

# for backward compatibility
result = result.transpose("quantile", ...)
if scalar:
result = result.squeeze("quantile")
if keep_attrs:
result.attrs = self._attrs
return result

def rank(self, dim, pct=False):
"""Ranks the data.
Expand Down
27 changes: 16 additions & 11 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from xarray.core import dtypes
from xarray.core.common import full_like
from xarray.core.indexes import propagate_indexes
from xarray.core.utils import is_scalar

from xarray.tests import (
LooseVersion,
ReturnItem,
Expand Down Expand Up @@ -2330,17 +2332,20 @@ def test_reduce_out(self):
with pytest.raises(TypeError):
orig.mean(out=np.ones(orig.shape))

def test_quantile(self):
for q in [0.25, [0.50], [0.25, 0.75]]:
for axis, dim in zip(
[None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]]
):
actual = DataArray(self.va).quantile(q, dim=dim, keep_attrs=True)
expected = np.nanpercentile(
self.dv.values, np.array(q) * 100, axis=axis
)
np.testing.assert_allclose(actual.values, expected)
assert actual.attrs == self.attrs
@pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
@pytest.mark.parametrize(
"axis, dim", zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]])
)
def test_quantile(self, q, axis, dim):
actual = DataArray(self.va).quantile(q, dim=dim, keep_attrs=True)
expected = np.nanpercentile(self.dv.values, np.array(q) * 100, axis=axis)
np.testing.assert_allclose(actual.values, expected)
if is_scalar(q):
assert "quantile" not in actual.dims
else:
assert "quantile" in actual.dims

assert actual.attrs == self.attrs

def test_reduce_keep_attrs(self):
# Test dropped attrs
Expand Down
28 changes: 16 additions & 12 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from xarray.core.common import duck_array_ops, full_like
from xarray.core.npcompat import IS_NEP18_ACTIVE
from xarray.core.pycompat import integer_types
from xarray.core.utils import is_scalar

from . import (
InaccessibleArray,
Expand Down Expand Up @@ -4575,21 +4576,24 @@ def test_reduce_keepdims(self):
)
assert_identical(expected, actual)

def test_quantile(self):

@pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
def test_quantile(self, q):
ds = create_test_data(seed=123)

for q in [0.25, [0.50], [0.25, 0.75]]:
for dim in [None, "dim1", ["dim1"]]:
ds_quantile = ds.quantile(q, dim=dim)
assert "quantile" in ds_quantile
for var, dar in ds.data_vars.items():
assert var in ds_quantile
assert_identical(ds_quantile[var], dar.quantile(q, dim=dim))
dim = ["dim1", "dim2"]
for dim in [None, "dim1", ["dim1"]]:
ds_quantile = ds.quantile(q, dim=dim)
assert "dim3" in ds_quantile.dims
assert all(d not in ds_quantile.dims for d in dim)
if is_scalar(q):
assert "quantile" not in ds_quantile.dims
else:
assert "quantile" in ds_quantile.dims

for var, dar in ds.data_vars.items():
assert var in ds_quantile
assert_identical(ds_quantile[var], dar.quantile(q, dim=dim))
dim = ["dim1", "dim2"]
ds_quantile = ds.quantile(q, dim=dim)
assert "dim3" in ds_quantile.dims
assert all(d not in ds_quantile.dims for d in dim)

@requires_bottleneck
def test_rank(self):
Expand Down
33 changes: 21 additions & 12 deletions xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
PandasIndexAdapter,
VectorizedIndexer,
)
from xarray.core.pycompat import dask_array_type
from xarray.core.utils import NDArrayMixin
from xarray.core.variable import as_compatible_data, as_variable
from xarray.tests import requires_bottleneck
Expand Down Expand Up @@ -1492,23 +1493,31 @@ def test_reduce(self):
with pytest.warns(DeprecationWarning, match="allow_lazy is deprecated"):
v.mean(dim="x", allow_lazy=False)

def test_quantile(self):
@pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
@pytest.mark.parametrize(
"axis, dim", zip([None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]])
)
def test_quantile(self, q, axis, dim):
v = Variable(["x", "y"], self.d)
for q in [0.25, [0.50], [0.25, 0.75]]:
for axis, dim in zip(
[None, 0, [0], [0, 1]], [None, "x", ["x"], ["x", "y"]]
):
actual = v.quantile(q, dim=dim)
actual = v.quantile(q, dim=dim)
expected = np.nanpercentile(self.d, np.array(q) * 100, axis=axis)
np.testing.assert_allclose(actual.values, expected)

expected = np.nanpercentile(self.d, np.array(q) * 100, axis=axis)
np.testing.assert_allclose(actual.values, expected)
@requires_dask
@pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
@pytest.mark.parametrize("axis, dim", [[1, "y"], [[1], ["y"]]])
def test_quantile_dask(self, q, axis, dim):
v = Variable(["x", "y"], self.d).chunk({"x": 2})
actual = v.quantile(q, dim=dim)
assert isinstance(actual.data, dask_array_type)
expected = np.nanpercentile(self.d, np.array(q) * 100, axis=axis)
np.testing.assert_allclose(actual.values, expected)

@requires_dask
def test_quantile_dask_raises(self):
# regression for GH1524
v = Variable(["x", "y"], self.d).chunk(2)
def test_quantile_chunked_dim_error(self):
v = Variable(["x", "y"], self.d).chunk({"x": 2})

with raises_regex(TypeError, "arrays stored as dask"):
with raises_regex(ValueError, "dimension 'x'"):
v.quantile(0.5, dim="x")

@requires_dask
Expand Down

0 comments on commit 7dfdfca

Please sign in to comment.