
Equivalent of np.lib.stride_tricks.as_strided #3171

Closed
Joshuaalbert opened this issue May 21, 2020 · 4 comments
Assignees
Labels
question Questions for the JAX team

Comments

@Joshuaalbert
Contributor

Does it already exist, or would it be possible to have np.lib.stride_tricks.as_strided in JAX?
I use it to build views of arrays with a rolling-window last axis, as below.
This then enables quick calculation of any rolling statistic.

import numpy as np


def rolling_window(a, window, padding='same'):
    """
    Produces a rolling window view of array.

    Args:
        a: ndarray
            Array to produce a rolling view over. The rolling view is over the last axis.
        window: int
            Size of rolling window
        padding: str
            Type of padding; if 'same', reflect padding is used so that the rolling axis keeps its length.

    Returns: ndarray
        If `a` is shape [..., T] this returns an array of shape [..., T, window] if padding == 'same'
        otherwise a shape [..., T - window + 1, window]

    Examples:
        >>> a = np.arange(5) #0,1,2,3,4
        >>> print(rolling_window(a,3, padding='same'))
        [[1 0 1]
         [0 1 2]
         [1 2 3]
         [2 3 4]
         [3 4 3]]
        >>> print(rolling_window(a,3,padding=None))
        [[0 1 2]
         [1 2 3]
         [2 3 4]]
        # Using it to perform rolling mean
        >>> b = rolling_window(a,3,padding='same')
        >>> print(np.mean(b, axis=-1))
        [0.66666667 1.         2.         3.         3.33333333]

    """
    if padding.lower() == 'same':
        pad_start = np.zeros(len(a.shape), dtype=np.int32)
        pad_start[-1] = window // 2
        pad_end = np.zeros(len(a.shape), dtype=np.int32)
        pad_end[-1] = (window - 1) - pad_start[-1]
        pad = list(zip(pad_start, pad_end))
        a = np.pad(a, pad, mode='reflect')
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
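(Aside for anyone landing here later: since NumPy 1.20 there is also np.lib.stride_tricks.sliding_window_view, which builds the same kind of read-only windowed view without the manual stride arithmetic — NumPy-side only, not JAX:)

```python
import numpy as np

a = np.arange(5)
# Read-only windowed view over the last axis; no copy is made.
windows = np.lib.stride_tricks.sliding_window_view(a, 3)
print(windows)
# [[0 1 2]
#  [1 2 3]
#  [2 3 4]]
```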
@mattjj mattjj added the question Questions for the JAX team label May 21, 2020
@mattjj mattjj self-assigned this May 21, 2020
@mattjj
Collaborator

mattjj commented May 21, 2020

Thanks for the question! I've used as_strided a lot in NumPy as well (like here and here).

There are two reasons (at least) you might want to use as_strided: first for efficiency (e.g. don't allocate a buffer for the rolling window array, make it just a view of the existing buffer, and moreover get better cache locality that way), and second for a convenient way of expressing the rolling window array (whether it formed a copy or not).

Under a jit, i.e. in an XLA computation, the notion of "buffer" and hence "view into the same buffer" doesn't exist; instead, there are only values. That's quite different from NumPy's buffer semantics: in NumPy we know when new buffers are allocated (and how to avoid it, e.g. with stride_tricks) and rely on that knowledge to understand how in-place updates of the values in a buffer affect its aliases. In contrast, under XLA's value semantics, which values in our computation are backed by buffers is entirely up to the compiler; for example, when XLA decides to fuse operations, none of the temporary values among those operations are ever backed by a buffer. A corollary is that not only do "buffer" and "view into the same buffer" not exist under a jit, we also don't need them for efficiency. (Another optimization XLA can do underneath its buffer-free programming model is layout optimization.)

Outside of a jit, i.e. doing op-by-op computation in Python, it's a different story. Each DeviceArray is backed by a device buffer. While we don't allow in-place updating of the values in the buffer (we treat buffers as immutable), still there are some operations that avoid performing a copy, like reshapes and transposes, and instead just update metadata (see #1668). That's something like forming "views" of the underlying buffer, except for the fact that this sublanguage isn't rich enough to express slicing or striding like NumPy's. We considered growing the language to NumPy-style views, but decided that's probably not worth the complexity, as jit is our main performance model anyway.

All that is to say that for efficiency as_strided would only make sense in op-by-op (eager) mode, not under a jit, and also that we have chosen not to support it in op-by-op mode for now.

For providing a convenient way of expressing rolling windows, we could consider adding some kind of "strided view" helper function (which would always perform a copy in op-by-op evaluation), but on the other hand it might be easy enough to express these kinds of values via alternative means. One way to do it would be just to do some padding, reshaping, and slicing. Another way might be to use a gather operation like this (I didn't put in the padding logic):

import jax.numpy as jnp


def rolling_window(a: jnp.ndarray, window: int):
  idx = jnp.arange(len(a) - window + 1)[:, None] + jnp.arange(window)[None, :]
  return a[idx]

a = jnp.arange(5)
print(rolling_window(a, 3))
[[0 1 2]
 [1 2 3]
 [2 3 4]]
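(A sketch of the padding-reshaping-slicing route mentioned above, not from the original thread: stack `window` shifted slices along a new trailing axis, with the same reflect padding as the NumPy version at the top of the thread.)

```python
import jax.numpy as jnp

def rolling_window_slices(a, window: int, padding: str = 'same'):
    if padding == 'same':
        # Reflect-pad the last axis so the output keeps its length.
        pad = [(0, 0)] * (a.ndim - 1) + [(window // 2, (window - 1) - window // 2)]
        a = jnp.pad(a, pad, mode='reflect')
    n = a.shape[-1] - window + 1
    # Stack `window` shifted slices along a new trailing axis.
    return jnp.stack([a[..., i:i + n] for i in range(window)], axis=-1)

a = jnp.arange(5)
print(rolling_window_slices(a, 3, padding=None))
# [[0 1 2]
#  [1 2 3]
#  [2 3 4]]
```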

WDYT?

@Joshuaalbert
Contributor Author

Thanks for the in-depth reply! I also think choosing to develop jit features over buffer features is the way to go. Hmm, not sure you need to worry about helpers for things like strided views unless other people start looking for them. Slicing and indexing is a pretty good alternative for rolling windows.

@erdmann

erdmann commented May 28, 2022

In case it's helpful to anyone coming here and looking for a jitted moving window into a 1D array, something like this works (also without padding logic):

from functools import partial
import jax
import jax.numpy as jnp
from jax import jit, vmap

@partial(jit, static_argnums=(1,))
def moving_window(a, size: int):
    starts = jnp.arange(len(a) - size + 1)
    return vmap(lambda start: jax.lax.dynamic_slice(a, (start,), (size,)))(starts)

a = jnp.arange(10)
print(moving_window(a, 4))
[[0 1 2 3]
 [1 2 3 4]
 [2 3 4 5]
 [3 4 5 6]
 [4 5 6 7]
 [5 6 7 8]
 [6 7 8 9]]
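Since the original motivation was rolling statistics, a reduction over the window axis drops straight in; for example, a rolling mean built on the same moving_window (restated here so the snippet stands alone):

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax import jit, vmap

@partial(jit, static_argnums=(1,))
def moving_window(a, size: int):
    starts = jnp.arange(len(a) - size + 1)
    return vmap(lambda s: jax.lax.dynamic_slice(a, (s,), (size,)))(starts)

a = jnp.arange(10)
# One mean per window of 4 consecutive elements.
print(jnp.mean(moving_window(a, 4), axis=1))
# [1.5 2.5 3.5 4.5 5.5 6.5 7.5]
```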

@Gintasz

Gintasz commented Mar 26, 2023

Modified erdmann's code a bit for 2D arrays.

from functools import partial
import jax
import jax.numpy as jnp
from jax import jit, vmap

@partial(jit, static_argnums=(1,))
def moving_window(matrix, window_shape):
    matrix_width = matrix.shape[1]
    matrix_height = matrix.shape[0]

    window_height = window_shape[0]
    window_width = window_shape[1]

    startsx = jnp.arange(matrix_width - window_width + 1)
    startsy = jnp.arange(matrix_height - window_height + 1)
    starts_xy = jnp.dstack(jnp.meshgrid(startsx, startsy)).reshape(-1, 2) # cartesian product => [[x,y], [x,y], ...]

    return vmap(lambda start: jax.lax.dynamic_slice(matrix, (start[1], start[0]), (window_height, window_width)))(starts_xy)

matrix = jnp.asarray([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12]
])
print(moving_window(matrix, (2, 3))) # window height = 2, window width = 3
[[[ 1  2  3]
  [ 5  6  7]]

 [[ 2  3  4]
  [ 6  7  8]]

 [[ 5  6  7]
  [ 9 10 11]]

 [[ 6  7  8]
  [10 11 12]]]
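(One more variant, not from the thread: nesting vmap instead of flattening the start coordinates keeps the window grid as a 2D structure of shape (ny, nx, h, w), which can be handier for image-style data. A sketch, reading window_shape as (rows, cols):)

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax import jit, vmap

@partial(jit, static_argnums=(1,))
def moving_window_2d(matrix, window_shape):
    h, w = window_shape  # window rows, window cols
    starts_y = jnp.arange(matrix.shape[0] - h + 1)
    starts_x = jnp.arange(matrix.shape[1] - w + 1)
    slice_at = lambda y, x: jax.lax.dynamic_slice(matrix, (y, x), (h, w))
    # Inner vmap sweeps x for a fixed y; outer vmap sweeps y -> (ny, nx, h, w).
    return vmap(lambda y: vmap(lambda x: slice_at(y, x))(starts_x))(starts_y)

matrix = jnp.arange(1, 13).reshape(3, 4)
print(moving_window_2d(matrix, (2, 3)).shape)  # (2, 2, 2, 3)
```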


4 participants