From f398de92bbae41a4c4a4c7095b63f70ed1e2247c Mon Sep 17 00:00:00 2001 From: Bouwe Andela Date: Fri, 1 Mar 2024 11:26:37 +0100 Subject: [PATCH] Lazy rolling_window (#5775) * Lazy rolling_window * Add test and whatsnew entry --- docs/src/whatsnew/latest.rst | 3 +- .../tests/unit/util/test_rolling_window.py | 7 +++ lib/iris/util.py | 45 ++++++++++--------- 3 files changed, 34 insertions(+), 21 deletions(-) diff --git a/docs/src/whatsnew/latest.rst b/docs/src/whatsnew/latest.rst index 044778d426f..76745f86676 100644 --- a/docs/src/whatsnew/latest.rst +++ b/docs/src/whatsnew/latest.rst @@ -48,7 +48,8 @@ This document explains the changes made to Iris for this release 🚀 Performance Enhancements =========================== -#. N/A +#. `@bouweandela`_ made :func:`iris.util.rolling_window` work with lazy arrays. + (:pull:`5775`) 🔥 Deprecations diff --git a/lib/iris/tests/unit/util/test_rolling_window.py b/lib/iris/tests/unit/util/test_rolling_window.py index 8a017e4e08f..d70b398ed5a 100644 --- a/lib/iris/tests/unit/util/test_rolling_window.py +++ b/lib/iris/tests/unit/util/test_rolling_window.py @@ -8,6 +8,7 @@ # importing anything else import iris.tests as tests # isort:skip +import dask.array as da import numpy as np import numpy.ma as ma @@ -35,6 +36,12 @@ def test_2d(self): result = rolling_window(a, window=3, axis=1) self.assertArrayEqual(result, expected_result) + def test_3d_lazy(self): + a = da.arange(2 * 3 * 4).reshape((2, 3, 4)) + expected_result = np.arange(2 * 3 * 4).reshape((1, 2, 3, 4)) + result = rolling_window(a, window=2, axis=0).compute() + self.assertArrayEqual(result, expected_result) + def test_1d_masked(self): # 1-d masked array input a = ma.array([0, 1, 2, 3, 4], mask=[0, 0, 1, 0, 0], dtype=np.int32) diff --git a/lib/iris/util.py b/lib/iris/util.py index b98db8090fd..87837f6111f 100644 --- a/lib/iris/util.py +++ b/lib/iris/util.py @@ -3,6 +3,7 @@ # This file is part of Iris and is released under the BSD license. # See LICENSE in the root of the repository for full licensing details. """Miscellaneous utility functions.""" +from __future__ import annotations from abc import ABCMeta, abstractmethod from collections.abc import Hashable, Iterable @@ -282,7 +283,12 @@ def guess_coord_axis(coord): return axis -def rolling_window(a, window=1, step=1, axis=-1): +def rolling_window( + a: np.ndarray | da.Array, + window: int = 1, + step: int = 1, + axis: int = -1, +) -> np.ndarray | da.Array: """Make an ndarray with a rolling window of the last dimension. Parameters @@ -323,8 +329,6 @@ def rolling_window(a, window=1, step=1, axis=-1): See more at :doc:`/userguide/real_and_lazy_data`. """ - # NOTE: The implementation of this function originates from - # https://github.com/numpy/numpy/pull/31#issuecomment-1304851 04/08/2011 if window < 1: raise ValueError("`window` must be at least 1.") if window > a.shape[axis]: @@ -332,25 +336,26 @@ def rolling_window(a, window=1, step=1, axis=-1): if step < 1: raise ValueError("`step` must be at least 1.") axis = axis % a.ndim - num_windows = (a.shape[axis] - window + step) // step - shape = a.shape[:axis] + (num_windows, window) + a.shape[axis + 1 :] - strides = ( - a.strides[:axis] - + (step * a.strides[axis], a.strides[axis]) - + a.strides[axis + 1 :] + array_module = da if isinstance(a, da.Array) else np + steps = tuple( + slice(None, None, step) if i == axis else slice(None) for i in range(a.ndim) ) - rw = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides) - if ma.isMaskedArray(a): - mask = ma.getmaskarray(a) - strides = ( - mask.strides[:axis] - + (step * mask.strides[axis], mask.strides[axis]) - + mask.strides[axis + 1 :] - ) - rw = ma.array( - rw, - mask=np.lib.stride_tricks.as_strided(mask, shape=shape, strides=strides), + + def _rolling_window(array): + return array_module.moveaxis( + array_module.lib.stride_tricks.sliding_window_view( + array, + window_shape=window, + axis=axis, + )[steps], + -1, + axis + 1, ) + + rw = _rolling_window(a) + if isinstance(da.utils.meta_from_array(a), np.ma.MaskedArray): + mask = _rolling_window(array_module.ma.getmaskarray(a)) + rw = array_module.ma.masked_array(rw, mask) return rw