Skip to content

Commit

Permalink
nd-rolling (#4219)
Browse files Browse the repository at this point in the history
* nd-rolling

* remove unnecessary print

* black

* finding a bug...

* make tests for ndrolling pass

* make center and window_dim a dict

* A cleanup.

* Revert test_units

* make test pass

* More tests.

* more docs

* mypy

* improve whatsnew

* Improve doc

* Support nd-rolling in dask correctly

* Cleanup according to max's comment

* flake8

* black

* stop using either_dict_or_kwargs

* Better tests.

* typo

* mypy

* typo2
  • Loading branch information
fujiisoup authored Aug 8, 2020
1 parent e04e21d commit 1d3dee0
Show file tree
Hide file tree
Showing 12 changed files with 362 additions and 133 deletions.
9 changes: 8 additions & 1 deletion doc/computation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,16 @@ a value when aggregating:
r = arr.rolling(y=3, center=True, min_periods=2)
r.mean()
From version 0.17, xarray supports multidimensional rolling,

.. ipython:: python
r = arr.rolling(x=2, y=3, min_periods=2)
r.mean()
.. tip::

Note that rolling window aggregations are faster and use less memory when bottleneck_ is installed. This only applies to numpy-backed xarray objects.
Note that rolling window aggregations are faster and use less memory when bottleneck_ is installed. This only applies to numpy-backed xarray objects with 1d-rolling.

.. _bottleneck: https://github.com/pydata/bottleneck/

Expand Down
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ Breaking changes

New Features
~~~~~~~~~~~~
- :py:meth:`~xarray.DataArray.rolling` and :py:meth:`~xarray.Dataset.rolling`
now accept more than 1 dimension.(:pull:`4219`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
- Build :py:meth:`CFTimeIndex.__repr__` explicitly as :py:class:`pandas.Index`. Add ``calendar`` as a new
property for :py:class:`CFTimeIndex` and show ``calendar`` and ``length`` in
:py:meth:`CFTimeIndex.__repr__` (:issue:`2416`, :pull:`4092`)
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ def rolling(
self,
dim: Mapping[Hashable, int] = None,
min_periods: int = None,
center: bool = False,
center: Union[bool, Mapping[Hashable, bool]] = False,
keep_attrs: bool = None,
**window_kwargs: int,
):
Expand All @@ -802,7 +802,7 @@ def rolling(
Minimum number of observations in window required to have a value
(otherwise result is NA). The default, None, is equivalent to
setting min_periods equal to the size of the window.
center : boolean, default False
center : boolean, or a mapping, default False
Set the labels at the center of the window.
keep_attrs : bool, optional
If True, the object's attributes (`attrs`) will be copied from
Expand Down
113 changes: 62 additions & 51 deletions xarray/core/dask_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,69 +32,80 @@ def rolling_window(a, axis, window, center, fill_value):
"""
import dask.array as da

if not hasattr(axis, "__len__"):
axis = [axis]
window = [window]
center = [center]

orig_shape = a.shape
if axis < 0:
axis = a.ndim + axis
depth = {d: 0 for d in range(a.ndim)}
depth[axis] = int(window / 2)
# For evenly sized window, we need to crop the first point of each block.
offset = 1 if window % 2 == 0 else 0

if depth[axis] > min(a.chunks[axis]):
raise ValueError(
"For window size %d, every chunk should be larger than %d, "
"but the smallest chunk size is %d. Rechunk your array\n"
"with a larger chunk size or a chunk size that\n"
"more evenly divides the shape of your array."
% (window, depth[axis], min(a.chunks[axis]))
)

# Although da.overlap pads values to boundaries of the array,
# the size of the generated array is smaller than what we want
# if center == False.
if center:
start = int(window / 2) # 10 -> 5, 9 -> 4
end = window - 1 - start
else:
start, end = window - 1, 0
pad_size = max(start, end) + offset - depth[axis]
drop_size = 0
# pad_size becomes more than 0 when the overlapped array is smaller than
# needed. In this case, we need to enlarge the original array by padding
# before overlapping.
if pad_size > 0:
if pad_size < depth[axis]:
# overlapping requires each chunk larger than depth. If pad_size is
# smaller than the depth, we enlarge this and truncate it later.
drop_size = depth[axis] - pad_size
pad_size = depth[axis]
shape = list(a.shape)
shape[axis] = pad_size
chunks = list(a.chunks)
chunks[axis] = (pad_size,)
fill_array = da.full(shape, fill_value, dtype=a.dtype, chunks=chunks)
a = da.concatenate([fill_array, a], axis=axis)

offset = [0] * a.ndim
drop_size = [0] * a.ndim
pad_size = [0] * a.ndim
for ax, win, cent in zip(axis, window, center):
if ax < 0:
ax = a.ndim + ax
depth[ax] = int(win / 2)
# For evenly sized window, we need to crop the first point of each block.
offset[ax] = 1 if win % 2 == 0 else 0

if depth[ax] > min(a.chunks[ax]):
raise ValueError(
"For window size %d, every chunk should be larger than %d, "
"but the smallest chunk size is %d. Rechunk your array\n"
"with a larger chunk size or a chunk size that\n"
"more evenly divides the shape of your array."
% (win, depth[ax], min(a.chunks[ax]))
)

# Although da.overlap pads values to boundaries of the array,
# the size of the generated array is smaller than what we want
# if center == False.
if cent:
start = int(win / 2) # 10 -> 5, 9 -> 4
end = win - 1 - start
else:
start, end = win - 1, 0
pad_size[ax] = max(start, end) + offset[ax] - depth[ax]
drop_size[ax] = 0
# pad_size becomes more than 0 when the overlapped array is smaller than
# needed. In this case, we need to enlarge the original array by padding
# before overlapping.
if pad_size[ax] > 0:
if pad_size[ax] < depth[ax]:
# overlapping requires each chunk larger than depth. If pad_size is
# smaller than the depth, we enlarge this and truncate it later.
drop_size[ax] = depth[ax] - pad_size[ax]
pad_size[ax] = depth[ax]

# TODO maybe following two lines can be summarized.
a = da.pad(
a, [(p, 0) for p in pad_size], mode="constant", constant_values=fill_value
)
boundary = {d: fill_value for d in range(a.ndim)}

# create overlap arrays
ag = da.overlap.overlap(a, depth=depth, boundary=boundary)

# apply rolling func
def func(x, window, axis=-1):
def func(x, window, axis):
x = np.asarray(x)
rolling = nputils._rolling_window(x, window, axis)
return rolling[(slice(None),) * axis + (slice(offset, None),)]

chunks = list(a.chunks)
chunks.append(window)
index = [slice(None)] * x.ndim
for ax, win in zip(axis, window):
x = nputils._rolling_window(x, win, ax)
index[ax] = slice(offset[ax], None)
return x[tuple(index)]

chunks = list(a.chunks) + window
new_axis = [a.ndim + i for i in range(len(axis))]
out = ag.map_blocks(
func, dtype=a.dtype, new_axis=a.ndim, chunks=chunks, window=window, axis=axis
func, dtype=a.dtype, new_axis=new_axis, chunks=chunks, window=window, axis=axis
)

# crop boundary.
index = (slice(None),) * axis + (slice(drop_size, drop_size + orig_shape[axis]),)
return out[index]
index = [slice(None)] * a.ndim
for ax in axis:
index[ax] = slice(drop_size[ax], drop_size[ax] + orig_shape[ax])
return out[tuple(index)]


def least_squares(lhs, rhs, rcond=None, skipna=False):
Expand Down
2 changes: 1 addition & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5968,7 +5968,7 @@ def polyfit(
skipna_da = np.any(da.isnull())

dims_to_stack = [dimname for dimname in da.dims if dimname != dim]
stacked_coords = {}
stacked_coords: Dict[Hashable, DataArray] = {}
if dims_to_stack:
stacked_dim = utils.get_temp_dimname(dims_to_stack, "stacked")
rhs = da.transpose(dim, *dims_to_stack).stack(
Expand Down
22 changes: 15 additions & 7 deletions xarray/core/nputils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,22 @@ def __setitem__(self, key, value):
def rolling_window(a, axis, window, center, fill_value):
""" rolling window with padding. """
pads = [(0, 0) for s in a.shape]
if center:
start = int(window / 2) # 10 -> 5, 9 -> 4
end = window - 1 - start
pads[axis] = (start, end)
else:
pads[axis] = (window - 1, 0)
if not hasattr(axis, "__len__"):
axis = [axis]
window = [window]
center = [center]

for ax, win, cent in zip(axis, window, center):
if cent:
start = int(win / 2) # 10 -> 5, 9 -> 4
end = win - 1 - start
pads[ax] = (start, end)
else:
pads[ax] = (win - 1, 0)
a = np.pad(a, pads, mode="constant", constant_values=fill_value)
return _rolling_window(a, window, axis)
for ax, win in zip(axis, window):
a = _rolling_window(a, win, ax)
return a


def _rolling_window(a, window, axis=-1):
Expand Down
Loading

0 comments on commit 1d3dee0

Please sign in to comment.