Skip to content

Commit

Permalink
Lazy rolling_window
Browse files Browse the repository at this point in the history
  • Loading branch information
bouweandela committed Feb 23, 2024
1 parent 2b024aa commit bb2ec46
Showing 1 changed file with 25 additions and 20 deletions.
45 changes: 25 additions & 20 deletions lib/iris/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This file is part of Iris and is released under the BSD license.
# See LICENSE in the root of the repository for full licensing details.
"""Miscellaneous utility functions."""
from __future__ import annotations

from abc import ABCMeta, abstractmethod
from collections.abc import Hashable, Iterable
Expand Down Expand Up @@ -281,7 +282,12 @@ def guess_coord_axis(coord):
return axis


def rolling_window(a, window=1, step=1, axis=-1):
def rolling_window(
a: np.ndarray | da.Array,
window: int = 1,
step: int = 1,
axis: int = -1,
) -> np.ndarray | da.Array:
"""Make an ndarray with a rolling window of the last dimension.
Parameters
Expand Down Expand Up @@ -322,34 +328,33 @@ def rolling_window(a, window=1, step=1, axis=-1):
See more at :doc:`/userguide/real_and_lazy_data`.
"""
# NOTE: The implementation of this function originates from
# https://github.com/numpy/numpy/pull/31#issuecomment-1304851 04/08/2011
if window < 1:
raise ValueError("`window` must be at least 1.")
if window > a.shape[axis]:
raise ValueError("`window` is too long.")
if step < 1:
raise ValueError("`step` must be at least 1.")
axis = axis % a.ndim
num_windows = (a.shape[axis] - window + step) // step
shape = a.shape[:axis] + (num_windows, window) + a.shape[axis + 1 :]
strides = (
a.strides[:axis]
+ (step * a.strides[axis], a.strides[axis])
+ a.strides[axis + 1 :]
array_module = da if isinstance(a, da.Array) else np
steps = tuple(
slice(None, None, step) if i == axis else slice(None) for i in range(a.ndim)
)
rw = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
if ma.isMaskedArray(a):
mask = ma.getmaskarray(a)
strides = (
mask.strides[:axis]
+ (step * mask.strides[axis], mask.strides[axis])
+ mask.strides[axis + 1 :]
)
rw = ma.array(
rw,
mask=np.lib.stride_tricks.as_strided(mask, shape=shape, strides=strides),

def _rolling_window(array):
return array_module.moveaxis(
array_module.lib.stride_tricks.sliding_window_view(
array,
window_shape=window,
axis=axis,
)[steps],
-1,
axis + 1,
)

rw = _rolling_window(a)
if isinstance(da.utils.meta_from_array(a), np.ma.MaskedArray):
mask = _rolling_window(array_module.ma.getmaskarray(a))
rw = array_module.ma.masked_array(rw, mask)
return rw


Expand Down

0 comments on commit bb2ec46

Please sign in to comment.