Skip to content
forked from pydata/xarray

Commit

Permalink
Add nanmedian for dask arrays (pydata#3604)
Browse files Browse the repository at this point in the history
* Add nanmedian for dask arrays

Close pydata#2999

* Fix tests.

* fix import

* Make sure that we don't rechunk the entire variable to one chunk

by reducing over all dimensions. Dask raises an error when axis=None
but not when axis=range(a.ndim).

* fix tests.

* Update whats-new.rst
  • Loading branch information
dcherian authored Dec 30, 2019
1 parent cc22f41 commit b3d3b44
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 8 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ Breaking changes

New Features
~~~~~~~~~~~~
- Implement :py:func:`median` and :py:func:`nanmedian` for dask arrays. This works by rechunking
to a single chunk along all reduction axes. (:issue:`2999`).
By `Deepak Cherian <https://github.com/dcherian>`_.
- :py:func:`xarray.concat` now preserves attributes from the first Variable.
(:issue:`2575`, :issue:`2060`, :issue:`1614`)
By `Deepak Cherian <https://github.com/dcherian>`_.
Expand Down
83 changes: 81 additions & 2 deletions xarray/core/dask_array_compat.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from distutils.version import LooseVersion
from typing import Iterable

import dask.array as da
import numpy as np
from dask import __version__ as dask_version

try:
import dask.array as da
from dask import __version__ as dask_version
except ImportError:
dask_version = "0.0.0"
da = None

if LooseVersion(dask_version) >= LooseVersion("2.0.0"):
meta_from_array = da.utils.meta_from_array
Expand Down Expand Up @@ -89,3 +95,76 @@ def meta_from_array(x, ndim=None, dtype=None):
meta = meta.astype(dtype)

return meta


if LooseVersion(dask_version) >= LooseVersion("2.8.1"):
median = da.median
else:
# Copied from dask v2.8.1
# Used under the terms of Dask's license, see licenses/DASK_LICENSE.
def median(a, axis=None, keepdims=False):
"""
This works by automatically chunking the reduced axes to a single chunk
and then calling ``numpy.median`` function across the remaining dimensions
"""

if axis is None:
raise NotImplementedError(
"The da.median function only works along an axis. "
"The full algorithm is difficult to do in parallel"
)

if not isinstance(axis, Iterable):
axis = (axis,)

axis = [ax + a.ndim if ax < 0 else ax for ax in axis]

a = a.rechunk({ax: -1 if ax in axis else "auto" for ax in range(a.ndim)})

result = a.map_blocks(
np.median,
axis=axis,
keepdims=keepdims,
drop_axis=axis if not keepdims else None,
chunks=[1 if ax in axis else c for ax, c in enumerate(a.chunks)]
if keepdims
else None,
)

return result


if LooseVersion(dask_version) > LooseVersion("2.9.0"):
nanmedian = da.nanmedian
else:

def nanmedian(a, axis=None, keepdims=False):
"""
This works by automatically chunking the reduced axes to a single chunk
and then calling ``numpy.nanmedian`` function across the remaining dimensions
"""

if axis is None:
raise NotImplementedError(
"The da.nanmedian function only works along an axis. "
"The full algorithm is difficult to do in parallel"
)

if not isinstance(axis, Iterable):
axis = (axis,)

axis = [ax + a.ndim if ax < 0 else ax for ax in axis]

a = a.rechunk({ax: -1 if ax in axis else "auto" for ax in range(a.ndim)})

result = a.map_blocks(
np.nanmedian,
axis=axis,
keepdims=keepdims,
drop_axis=axis if not keepdims else None,
chunks=[1 if ax in axis else c for ax, c in enumerate(a.chunks)]
if keepdims
else None,
)

return result
8 changes: 4 additions & 4 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
import pandas as pd

from . import dask_array_ops, dtypes, npcompat, nputils
from . import dask_array_ops, dask_array_compat, dtypes, npcompat, nputils
from .nputils import nanfirst, nanlast
from .pycompat import dask_array_type

Expand Down Expand Up @@ -284,7 +284,7 @@ def _ignore_warnings_if(condition):
yield


def _create_nan_agg_method(name, coerce_strings=False):
def _create_nan_agg_method(name, dask_module=dask_array, coerce_strings=False):
from . import nanops

def f(values, axis=None, skipna=None, **kwargs):
Expand All @@ -301,7 +301,7 @@ def f(values, axis=None, skipna=None, **kwargs):
nanname = "nan" + name
func = getattr(nanops, nanname)
else:
func = _dask_or_eager_func(name)
func = _dask_or_eager_func(name, dask_module=dask_module)

try:
return func(values, axis=axis, **kwargs)
Expand Down Expand Up @@ -337,7 +337,7 @@ def f(values, axis=None, skipna=None, **kwargs):
std.numeric_only = True
var = _create_nan_agg_method("var")
var.numeric_only = True
median = _create_nan_agg_method("median")
median = _create_nan_agg_method("median", dask_module=dask_array_compat)
median.numeric_only = True
prod = _create_nan_agg_method("prod")
prod.numeric_only = True
Expand Down
12 changes: 11 additions & 1 deletion xarray/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

try:
import dask.array as dask_array
from . import dask_array_compat
except ImportError:
dask_array = None
dask_array_compat = None # type: ignore


def _replace_nan(a, val):
Expand Down Expand Up @@ -141,7 +143,15 @@ def nanmean(a, axis=None, dtype=None, out=None):


def nanmedian(a, axis=None, out=None):
return _dask_or_eager_func("nanmedian", eager_module=nputils)(a, axis=axis)
# The dask algorithm works by rechunking to one chunk along axis
# Make sure we trigger the dask error when passing all dimensions
# so that we don't rechunk the entire array to one chunk and
# possibly blow memory
if axis is not None and len(np.atleast_1d(axis)) == a.ndim:
axis = None
return _dask_or_eager_func(
"nanmedian", dask_module=dask_array_compat, eager_module=nputils
)(a, axis=axis)


def _nanvar_object(value, axis=None, ddof=0, keepdims=False, **kwargs):
Expand Down
4 changes: 3 additions & 1 deletion xarray/tests/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,10 @@ def test_reduce(self):
self.assertLazyAndAllClose(u.argmin(dim="x"), actual)
self.assertLazyAndAllClose((u > 1).any(), (v > 1).any())
self.assertLazyAndAllClose((u < 1).all("x"), (v < 1).all("x"))
with raises_regex(NotImplementedError, "dask"):
with raises_regex(NotImplementedError, "only works along an axis"):
v.median()
with raises_regex(NotImplementedError, "only works along an axis"):
v.median(v.dims)
with raise_if_dask_computes():
v.reduce(duck_array_ops.mean)

Expand Down

0 comments on commit b3d3b44

Please sign in to comment.