Skip to content

Commit

Permalink
Lazy rolling_window (SciTools#5775)
Browse files Browse the repository at this point in the history
* Lazy rolling_window

* Add test and whatsnew entry
  • Loading branch information
bouweandela authored and pp-mo committed Mar 11, 2024
1 parent d9b0680 commit f398de9
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 21 deletions.
3 changes: 2 additions & 1 deletion docs/src/whatsnew/latest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ This document explains the changes made to Iris for this release
🚀 Performance Enhancements
===========================

#. N/A
#. `@bouweandela`_ made :func:`iris.util.rolling_window` work with lazy arrays.
(:pull:`5775`)


🔥 Deprecations
Expand Down
7 changes: 7 additions & 0 deletions lib/iris/tests/unit/util/test_rolling_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# importing anything else
import iris.tests as tests # isort:skip

import dask.array as da
import numpy as np
import numpy.ma as ma

Expand Down Expand Up @@ -35,6 +36,12 @@ def test_2d(self):
result = rolling_window(a, window=3, axis=1)
self.assertArrayEqual(result, expected_result)

def test_3d_lazy(self):
a = da.arange(2 * 3 * 4).reshape((2, 3, 4))
expected_result = np.arange(2 * 3 * 4).reshape((1, 2, 3, 4))
result = rolling_window(a, window=2, axis=0).compute()
self.assertArrayEqual(result, expected_result)

def test_1d_masked(self):
# 1-d masked array input
a = ma.array([0, 1, 2, 3, 4], mask=[0, 0, 1, 0, 0], dtype=np.int32)
Expand Down
45 changes: 25 additions & 20 deletions lib/iris/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This file is part of Iris and is released under the BSD license.
# See LICENSE in the root of the repository for full licensing details.
"""Miscellaneous utility functions."""
from __future__ import annotations

from abc import ABCMeta, abstractmethod
from collections.abc import Hashable, Iterable
Expand Down Expand Up @@ -282,7 +283,12 @@ def guess_coord_axis(coord):
return axis


def rolling_window(a, window=1, step=1, axis=-1):
def rolling_window(
a: np.ndarray | da.Array,
window: int = 1,
step: int = 1,
axis: int = -1,
) -> np.ndarray | da.Array:
"""Make an ndarray with a rolling window of the last dimension.
Parameters
Expand Down Expand Up @@ -323,34 +329,33 @@ def rolling_window(a, window=1, step=1, axis=-1):
See more at :doc:`/userguide/real_and_lazy_data`.
"""
# NOTE: The implementation of this function originates from
# https://github.com/numpy/numpy/pull/31#issuecomment-1304851 04/08/2011
if window < 1:
raise ValueError("`window` must be at least 1.")
if window > a.shape[axis]:
raise ValueError("`window` is too long.")
if step < 1:
raise ValueError("`step` must be at least 1.")
axis = axis % a.ndim
num_windows = (a.shape[axis] - window + step) // step
shape = a.shape[:axis] + (num_windows, window) + a.shape[axis + 1 :]
strides = (
a.strides[:axis]
+ (step * a.strides[axis], a.strides[axis])
+ a.strides[axis + 1 :]
array_module = da if isinstance(a, da.Array) else np
steps = tuple(
slice(None, None, step) if i == axis else slice(None) for i in range(a.ndim)
)
rw = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
if ma.isMaskedArray(a):
mask = ma.getmaskarray(a)
strides = (
mask.strides[:axis]
+ (step * mask.strides[axis], mask.strides[axis])
+ mask.strides[axis + 1 :]
)
rw = ma.array(
rw,
mask=np.lib.stride_tricks.as_strided(mask, shape=shape, strides=strides),

def _rolling_window(array):
return array_module.moveaxis(
array_module.lib.stride_tricks.sliding_window_view(
array,
window_shape=window,
axis=axis,
)[steps],
-1,
axis + 1,
)

rw = _rolling_window(a)
if isinstance(da.utils.meta_from_array(a), np.ma.MaskedArray):
mask = _rolling_window(array_module.ma.getmaskarray(a))
rw = array_module.ma.masked_array(rw, mask)
return rw


Expand Down

0 comments on commit f398de9

Please sign in to comment.