Equivalent of np.lib.stride_tricks.as_strided #3171
Thanks for the question! I've used … There are two reasons (at least) you might want to use … Under a … Outside of a … All that is to say that for efficiency …

To provide a convenient way of expressing rolling windows, we could consider adding some kind of "strided view" helper function (which would always perform a copy in op-by-op evaluation). On the other hand, it might be easy enough to express these kinds of values by other means. One way would be to do some padding, reshaping, and slicing. Another might be to use a gather operation like this (I didn't put in the padding logic):

```python
import jax.numpy as jnp


def rolling_window(a: jnp.ndarray, window: int):
    # idx[i, j] = i + j, so row i of the result gathers a[i : i + window].
    idx = jnp.arange(len(a) - window + 1)[:, None] + jnp.arange(window)[None, :]
    return a[idx]


a = jnp.arange(5)
print(rolling_window(a, 3))
```

WDYT?
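The padding-and-slicing alternative mentioned above isn't shown in the thread. Here is a minimal sketch, assuming you want one window per input element, with the first `window - 1` rows left-padded by a constant. It stacks static slices instead of reshaping. `rolling_window_padded` and its `fill_value` parameter are hypothetical names, not a JAX API:

```python
import jax.numpy as jnp


def rolling_window_padded(a: jnp.ndarray, window: int, fill_value=0) -> jnp.ndarray:
    """Hypothetical helper: shape (len(a), window), where row i ends at a[i].

    The first window - 1 rows are left-padded with fill_value.
    """
    n = a.shape[0]
    # Left-pad so that every output row has a full window.
    padded = jnp.pad(a, (window - 1, 0), constant_values=fill_value)
    # Column j of the result is the padded array shifted by j (a static slice).
    cols = [padded[j : j + n] for j in range(window)]
    return jnp.stack(cols, axis=1)


a = jnp.arange(5)
print(rolling_window_padded(a, 3))
```

All slice bounds depend only on `window` and the input shape. The same code should therefore trace under `jit` as long as `window` is marked static.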
Thanks for the in-depth reply! I also think choosing to develop jit features over buffer features is the way to go. Hmm, I'm not sure you need to worry about helpers for things like strided views unless it becomes something other people are looking for. Using slicing and indexing is a pretty good alternative for rolling windows.
In case it's helpful to anyone coming here and looking for a jitted moving window into a 1D array, something like this works (also without padding logic):

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import jit, vmap


# `size` must be static so the output shape is known at trace time.
@partial(jit, static_argnums=(1,))
def moving_window(a, size: int):
    starts = jnp.arange(len(a) - size + 1)
    # For each start index, extract a[start : start + size].
    return vmap(lambda start: jax.lax.dynamic_slice(a, (start,), (size,)))(starts)


a = jnp.arange(10)
print(moving_window(a, 4))
```
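`as_strided` can also put a step between consecutive windows. Here is a minimal sketch, assuming a fixed (static) step along a 1D array. It adds a hypothetical `step` argument to the `moving_window` approach above. Unlike `as_strided`, it returns a new array rather than a view into `a`:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import jit, vmap


@partial(jit, static_argnums=(1, 2))
def moving_window_strided(a, size: int, step: int):
    """Hypothetical helper: windows of length `size` whose starts are `step` apart."""
    # Start indices depend only on static values, so their count is fixed at trace time.
    starts = jnp.arange(0, a.shape[0] - size + 1, step)
    # dynamic_slice extracts a[start : start + size] for each start index.
    return vmap(lambda start: jax.lax.dynamic_slice(a, (start,), (size,)))(starts)


a = jnp.arange(10)
print(moving_window_strided(a, 4, 3))
```

With `step=1` this reduces to the `moving_window` function above.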
I modified erdmann's code a bit for 2D arrays.
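The modified 2D code itself isn't preserved in this thread. Here is a rough sketch only, assuming windows should slide along the last axis (as in the original question). The gather-based index pattern from earlier in the thread extends to any leading axes by indexing only the last one. `moving_window_last_axis` is a hypothetical name:

```python
from functools import partial

import jax.numpy as jnp
from jax import jit


@partial(jit, static_argnums=(1,))
def moving_window_last_axis(a, size: int):
    """Hypothetical helper: for a of shape (rows, n), return shape (rows, n - size + 1, size)."""
    n = a.shape[-1]
    # idx[i, j] = i + j: the same index pattern as the 1D gather version.
    idx = jnp.arange(n - size + 1)[:, None] + jnp.arange(size)[None, :]
    # Indexing only the last axis keeps every leading (row) axis intact.
    return a[..., idx]


a = jnp.arange(12).reshape(2, 6)
print(moving_window_last_axis(a, 3).shape)  # (2, 4, 3)
```

Because of the `...`, the same function should also accept 3D and higher-rank inputs.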
Does it already exist, or would it be possible to have `np.lib.stride_tricks.as_strided` in JAX? I use it to build views of arrays with a rolling-windowed last axis as below. This then enables quick calculation of any rolling statistics.
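For the rolling-statistics use case, a window array built with the gather approach from this thread can be reduced along its last axis. Here is a minimal sketch, assuming rolling mean and standard deviation over the last axis are what's needed (`rolling_mean_std` is a hypothetical helper). It is cross-checked against NumPy's `sliding_window_view` (NumPy 1.20 or later):

```python
import numpy as np
import jax.numpy as jnp


def rolling_window(a, window: int):
    # Gather-based construction, as earlier in the thread, applied to the last axis.
    idx = jnp.arange(a.shape[-1] - window + 1)[:, None] + jnp.arange(window)[None, :]
    return a[..., idx]


def rolling_mean_std(a, window: int):
    """Hypothetical helper: rolling mean and standard deviation over the last axis."""
    w = rolling_window(a, window)  # shape (..., n - window + 1, window)
    return w.mean(axis=-1), w.std(axis=-1)


x = jnp.array([1.0, 2.0, 4.0, 7.0, 11.0])
mean, std = rolling_mean_std(x, 3)

# Cross-check against NumPy's view-based helper.
ref = np.lib.stride_tricks.sliding_window_view(np.asarray(x), 3)
np.testing.assert_allclose(mean, ref.mean(axis=-1), rtol=1e-5)
np.testing.assert_allclose(std, ref.std(axis=-1), rtol=1e-5)
```

Any other reduction (`max`, `median`, a weighted sum) can replace `mean`/`std` in the same way. Note that the window array is materialized, so memory grows with `window`.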