RFC: static vs. dynamic shapes and JAX's `.at` for simulating in-place ops
#609
Comments
For completeness, let me copy the comparison between syntaxes from https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html: …
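The linked page pairs each functional `x = x.at[idx].<op>(y)` form with the NumPy in-place statement it stands in for (for example, `x = x.at[idx].set(y)` for `x[idx] = y`). To make that correspondence concrete without requiring JAX, here is a minimal NumPy sketch. `at`, `_At` and `_AtIndexer` are hypothetical names, and the copy-then-update strategy only illustrates the semantics. It is not how JAX implements `.at`: JAX's documentation notes that inside `jit`-compiled code such updates can happen in place when the input is not reused.

```python
import numpy as np


class _AtIndexer:
    """Hypothetical NumPy stand-in for the object returned by ``x.at[idx]``.

    Every method returns a new array and leaves the original untouched,
    mirroring the functional semantics of JAX's ``.at``.
    """

    def __init__(self, arr, idx):
        self._arr = arr
        self._idx = idx

    def set(self, values):
        # In-place counterpart: x[idx] = values
        out = self._arr.copy()  # eager O(n) copy; a JIT could elide it
        out[self._idx] = values
        return out

    def add(self, values):
        # In-place counterpart: x[idx] += values, except that repeated
        # indices accumulate here (np.add.at is unbuffered).
        out = self._arr.copy()
        np.add.at(out, self._idx, values)
        return out


class _At:
    def __init__(self, arr):
        self._arr = arr

    def __getitem__(self, idx):
        return _AtIndexer(self._arr, idx)


def at(arr):
    """Hypothetical helper: at(x)[idx].set(y) plays the role of x.at[idx].set(y)."""
    return _At(arr)


# Usage: compare functional updates with NumPy's in-place statements.
x = np.zeros(5)
idx = np.array([0, 2, 2])

y = at(x)[idx].add(1.0)  # functional: x itself is not modified
z = x.copy()
z[idx] += 1.0            # in-place, buffered: index 2 is incremented once

print(x.tolist())  # [0.0, 0.0, 0.0, 0.0, 0.0]
print(y.tolist())  # [1.0, 0.0, 2.0, 0.0, 0.0]
print(z.tolist())  # [1.0, 0.0, 1.0, 0.0, 0.0]

w = at(x)[1:3].set(7.0)
print(w.tolist())  # [0.0, 7.0, 7.0, 0.0, 0.0]
```

The `add` case shows one semantic detail: JAX's documentation states that, unlike NumPy's `x[idx] += y`, `.at[idx].add(y)` applies every update when indices repeat, which is why the sketch uses `np.add.at` rather than `out[idx] += values`.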
I'm curious whether the JAX developers have thought about what would be needed for Python's syntax to allow the more readable …
For what it's worth, since JAX arrays do not have mutable view semantics, this is not at all problematic.
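For readers less familiar with the view semantics referred to here, a short sketch in NumPy (not JAX, since JAX arrays are immutable and have nothing to alias). It shows that an in-place write through a basic slice is visible in the original array, while boolean-mask indexing produces an independent copy whose length depends on the data:

```python
import numpy as np

# Basic slicing returns a view that shares memory with the original array.
x = np.arange(6)
v = x[1:4]
v[0] = 100                      # writing through the view...
print(x.tolist())               # ...changes x too: [0, 100, 2, 3, 4, 5]
print(np.shares_memory(x, v))   # True

# Boolean-mask (advanced) indexing returns a copy instead.
y = np.arange(6)
c = y[y > 2]                    # a new array; its length depends on the data
c[0] = 100
print(y.tolist())               # unchanged: [0, 1, 2, 3, 4, 5]

# Assigning through the indexing expression does update y in place.
y[y > 2] = 100
print(y.tolist())               # [0, 1, 2, 100, 100, 100]
```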
Seems like the tricky case would be …
This seems to be the only unsolved problem for testing JAX arrays inside SciPy over at scipy/scipy#20085. Many of these cases already occur in the small portion of the code base that has been ported to array API compatibility.
Out of curiosity, I tried to run the existing array API test suite of scikit-learn with JAX, and many of the tests failed because of JAX's in-place assignment limitation, which makes this mostly useless in its current state: …
For anyone following this issue but not my SciPy JAX PR, scipy/scipy#20085 (comment) is (I think) as far as we got with this.
This is a continuation of a discussion that started a few weeks ago in gh-597 (Cc @soraros). It is closely related to gh-84 (boolean indexing) and gh-24 (mutability and copy/views).
I'll copy the content of @soraros's comment here in full:
Start of comment
I also think the problem is more fundamental than that. JAX is essentially a front end for XLA, and the primitives provided by XLA (for now) require static shapes. So the line that actually goes wrong is …
Note that this code does work in JAX, though it is not jittable, because we don't know its output shape. Let's pretend for a moment that `x[ix_bool] += 1` is syntax sugar for `x = x + where(ix_bool, 1, 0)` (which works in JAX). The same problem appears when we want `x[ix_bool] += [1, 3, 5]`. Again, we somehow need to know the shape of the right-hand side, which is equivalent to knowing the shape of `xs[ix_bool]` as in the last example.

So what we really have to work around is the static shape requirement (recall the need for a `size` parameter in `nonzero`), which is not exclusive to JAX.

Now, for something a bit off-topic:
I think the JAX-style functional syntax `a = a.at[...].set(...)` for in-place operations looks (and arguably works) better than NumPy's, and I'd really like to have it in the array API. Some pros: … Numba as well.

Some of my thoughts: … `boundscheck` keyword to `@njit`). `x = x.at[...]` and NumPy et al.'s in-place support are completely equivalent when you have a JIT, and NumPy's version is more efficient if you don't, as long as you can guarantee that you are not modifying a view. The syntax is also arguably nicer: more concise and more familiar. So, from that perspective, `.at` isn't ideal.

The last point is important. Writing generic code is difficult now when you need to, e.g., update values with a mask. Doing that only the JAX way seems like a nonstarter, because it's far too inefficient for NumPy et al. The question, though, is whether there's something that would work for JAX, TF and Dask. Dask also struggles to some extent with dynamic shapes, although most of it now works (xref dask/dask#2000 and dask/dask#7393). @jakirkham, any thoughts on whether you need anything more (possibly JAX-like) for dynamic shape support in Dask?
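As a rough illustration of the masked-update problem, here is a NumPy sketch contrasting boolean indexing, whose output shape depends on the data, with the shape-stable `x + where(mask, value, 0)` rewrite from @soraros's comment, plus a fixed-size `nonzero` in the spirit of the `size` parameter mentioned there. `masked_add` and `nonzero_fixed_size` are hypothetical helpers for illustration, not proposed array API functions, and only a scalar `value` is handled.

```python
import numpy as np


def masked_add(x, mask, value):
    """Shape-stable masked update (hypothetical helper).

    Computes the same values as ``x[mask] += value`` for a scalar ``value``,
    but the output shape never depends on how many entries are selected.
    Returns a new array instead of mutating ``x``.
    """
    return x + np.where(mask, value, 0)


def nonzero_fixed_size(mask, size, fill_value=-1):
    """Hypothetical fixed-size ``nonzero`` for a 1-D boolean mask.

    The output length is chosen by the caller rather than by the data:
    surplus hits are dropped and unused slots are padded with ``fill_value``.
    """
    (idx,) = np.nonzero(mask)
    out = np.full(size, fill_value, dtype=idx.dtype)
    n = min(size, idx.shape[0])
    out[:n] = idx[:n]
    return out


x = np.arange(6)
mask = x % 2 == 1

# Boolean indexing: the result shape depends on the values in mask.
print(x[mask].shape)                  # (3,)

# where-based update: the result always has the shape of x.
print(masked_add(x, mask, 10).shape)  # (6,)

# Same values as the NumPy in-place version.
inplace = x.copy()
inplace[mask] += 10
print(np.array_equal(masked_add(x, mask, 10), inplace))  # True

print(nonzero_fixed_size(mask, size=4).tolist())  # [1, 3, 5, -1]
```

The `where` form always allocates and computes a full-size result, even when only a few entries change, which is part of why a JAX-only style is costly for eager libraries such as NumPy.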