no api: numpy.lib.stride_tricks.as_strided #11354

Open
fmscole opened this issue Jul 2, 2022 · 2 comments
Labels
enhancement New feature or request

Comments

@fmscole

fmscole commented Jul 2, 2022

Please support the API numpy.lib.stride_tricks.as_strided.
as_strided() is powerful: it returns a view of the array rather than a copy.
Combined with einsum(), it can express almost everything, such as convolution.
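For readers less familiar with the NumPy function being requested, here is a minimal sketch of what `as_strided` does: it reinterprets an existing buffer with a new shape and new strides (given in bytes) without copying anything. The array and window size are illustrative; `writeable=False` is passed because writing through overlapping windows is error-prone.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

x = np.arange(6)
step = x.strides[0]  # bytes between consecutive elements of x

# Four overlapping length-3 windows: moving to the next window and moving
# within a window both advance by one element.
win = as_strided(x, shape=(4, 3), strides=(step, step), writeable=False)

print(win)
# [[0 1 2]
#  [1 2 3]
#  [2 3 4]
#  [3 4 5]]
print(np.shares_memory(win, x))  # True: no data was copied
```

For the common sliding-window case, NumPy (1.20 and later) also offers `numpy.lib.stride_tricks.sliding_window_view`, which computes these shapes and strides for you.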

@fmscole fmscole added the enhancement New feature or request label Jul 2, 2022
@mattjj
Collaborator

mattjj commented Jul 3, 2022

Thanks for the request!

I love as_strided in NumPy. I used it a ton in grad school. And I take your point about convolutions: in fact, if you look at the file lax_reference.py, which is a NumPy-based implementation of the basic JAX primitives (used for testing against), you can see the convolutions are implemented using stride_tricks composed with einsum. However, in the conv case, the computation that technique generates isn't as efficient as calling a dedicated convolutional kernel. (If it were, folks wouldn't spend so long hand-optimizing or compiler-engineering for convolutions!)

So while you can express conv using as_strided composed with einsum, it may not be very broadly useful for performance reasons.
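To make the "conv as as_strided composed with einsum" idea concrete, here is a minimal 1D sketch in NumPy. It is written in the spirit of the technique described above and is not the code from lax_reference.py; `conv1d_valid` is a hypothetical helper name, and only "valid"-mode 1D convolution is handled.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided

def conv1d_valid(x, k):
    """'Valid'-mode 1D convolution built from a strided view plus einsum."""
    x = np.ascontiguousarray(x, dtype=np.float64)
    k = np.asarray(k, dtype=np.float64)
    n, m = x.shape[0], k.shape[0]
    s = x.strides[0]
    # Row i of `windows` is x[i:i+m], exposed as a view (no copy).
    windows = as_strided(x, shape=(n - m + 1, m), strides=(s, s), writeable=False)
    # Convolution flips the kernel; without the flip this is cross-correlation.
    return np.einsum("ij,j->i", windows, k[::-1])

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
k = np.array([1.0, 0.0, -1.0])
print(conv1d_valid(x, k))               # [2. 2. 2.]
print(np.convolve(x, k, mode="valid"))  # same result
```

This is correct but generic: einsum sees an ordinary tensor contraction, so it gets none of the tiling and data-reuse tricks a dedicated convolution kernel would use, which is the performance point made above.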

More generally, as_strided isn't as relevant in JAX as it is in, say, NumPy, because while NumPy provides an operational semantics in terms of memory buffers (and hence e.g. views), JAX's (and XLA's) functional semantics are based on values. That is, even setting aside fancy strides, a simple expression like x[:2] produces a new value whose ultimate physical representation is not guaranteed to be related to (e.g. a view of) the physical representation of x. Choices of what to materialize in memory, how to lay it out, etc. are all left to the compiler.
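A minimal sketch of the value-versus-view difference described above, assuming `jax` is installed. In NumPy, writing through a slice changes the original array; in JAX, arrays are immutable, so in-place assignment is rejected and updates go through the functional `.at[...]` interface, which returns a new array.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: basic slicing returns a view, so writing through it changes xn.
xn = np.arange(5)
yn = xn[:2]
yn[0] = 100          # xn[0] is now 100 as well

# JAX: x[:2] is a new value; no aliasing with x's buffer is promised.
x = jnp.arange(5)
y = x[:2]
try:
    y[0] = 100       # JAX arrays are immutable
except TypeError:
    print("JAX arrays do not support item assignment")

# Functional update: returns a new array and leaves x unchanged.
z = x.at[0].set(100)
```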

For example, in classic NumPy the programmer is in charge of execution details, so if we're evaluating something like (x[:10].max(), x[:10].min()), we might like to know that the x[:10] subexpressions aren't going to cause any copies. Instead of copies, we get a view on the original array x and just loop over it twice (once to compute the max and once to compute the min).
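The NumPy behaviour just described can be checked directly: a basic slice shares memory with its parent, and the two reductions are two separate passes over those elements. The array size and random seed below are arbitrary.

```python
import numpy as np

x = np.random.default_rng(0).normal(size=1_000)

head = x[:10]                      # basic slicing: a view, not a copy
print(np.shares_memory(head, x))   # True
print(head.base is x)              # True: head borrows x's buffer

# Two separate passes over the same 10 elements: one for max, one for min.
stats = (head.max(), head.min())
```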

But with JAX/XLA, under a jit, not only would we avoid creating copies or intermediate arrays, but the two loops would likely be fused together so that we compute both the max and the min in a single pass over the first 10 elements of x. Those kinds of operational optimizations are only possible because we're not telling the computer exactly what to do, and instead giving the compiler freedom to choose how things are computed.
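Here is the same computation under `jax.jit`, as a minimal sketch assuming `jax` is installed. Nothing in the Python code says whether `head` is materialized or whether the two reductions are fused; that is left to XLA. The compiled program can be inspected to see what XLA actually chose on a given backend, and the result may differ across backends and versions.

```python
import jax
import jax.numpy as jnp

@jax.jit
def head_max_min(x):
    head = x[:10]
    # XLA decides whether `head` is ever materialized and whether the two
    # reductions share a single pass; this code does not force either outcome.
    return head.max(), head.min()

x = jnp.arange(100.0)
mx, mn = head_max_min(x)

# Optimized HLO text for this input shape, useful for checking for fusions.
hlo = head_max_min.lower(x).compile().as_text()
```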

All that is to say that while we could provide an as_strided API, it wouldn't have the same "how this computation is performed" guarantees that NumPy has. But it still may be useful from an expressiveness point of view (e.g. to build 'windowed views' of data, in a way that's familiar to NumPy experts).
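To illustrate the "expressiveness without operational guarantees" idea, here is a minimal sketch of an as_strided-like helper written with ordinary `jax.numpy` indexing. `as_strided_values` is a hypothetical name, not a JAX API. Because JAX exposes no byte layout, the strides here are assumed to be in elements of the C-order flattened input, and only non-negative strides are supported. The result is a gathered value, not a view, so the compiler is free to decide what gets materialized.

```python
import jax
import jax.numpy as jnp

def as_strided_values(x, shape, strides):
    """Hypothetical helper (not part of JAX): returns, as a new value, the
    elements numpy.lib.stride_tricks.as_strided would expose.

    `strides` are in elements of jnp.ravel(x) (C order), not in bytes.
    """
    if len(shape) != len(strides) or any(s < 0 for s in strides):
        raise ValueError("need one non-negative stride per output axis")
    # Largest flat offset touched; checked in Python since shape/strides are static.
    max_offset = sum((n - 1) * s for n, s in zip(shape, strides))
    if max_offset >= x.size:
        raise ValueError("strided window reads past the end of the input")
    flat = jnp.ravel(x)
    # Flat offset of each output position: sum over axes of idx_k * strides[k].
    offsets = jnp.zeros(shape, dtype=jnp.int32)
    for axis, (n, s) in enumerate(zip(shape, strides)):
        bshape = [1] * len(shape)
        bshape[axis] = n
        offsets = offsets + (jnp.arange(n, dtype=jnp.int32) * s).reshape(bshape)
    return flat[offsets]

windows = as_strided_values(jnp.arange(6), (4, 3), (1, 1))
# shape and strides are Python tuples, so mark them static under jit.
fast = jax.jit(as_strided_values, static_argnums=(1, 2))
```

Unlike NumPy's version, this cannot alias memory, so writing "through" it is impossible by construction; the trade-off is that it is defined as a gather, and performance depends entirely on what XLA does with it.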

After writing all that, I just noticed that #3171 is essentially the same feature request, and I left a very similar answer there (but more than 2 years ago). So take a look at that for a little more info!

But instead of closing the request like #3171, I think we should consider adding an as_strided API, with clear warnings about the lack of operational semantics guarantees. It may be a lower priority item though...

What do you think?

@fmscole
Author

fmscole commented Jul 3, 2022

Thank you so much for your reply, @mattjj! You are so patient. Nice!

2 participants