[Pallas] When mixing basic indexing and integer array indexing, the axis corresponding to integer array indexing is unnecessarily moved to the front #22783
Labels: bug
Comments
Better repro (without strided indexing):

```python
import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
import numpy as np

x_shape = (16, 3)
x = jnp.arange(np.prod(x_shape), dtype=jnp.float32).reshape(x_shape)
a = jnp.array([1, 1, 1, 1, 1], dtype=jnp.int32)
y = x[:, a]  # NumPy semantics: shape (16, 5)

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct(y.shape, jnp.float32),
    interpret=True,
)
def kernel(x_ref, o_ref):
    # Same mixed basic/integer-array indexing inside the kernel;
    # this should also produce shape (16, 5).
    o_ref[...] = x_ref[:, a]

y_ = kernel(x)
np.testing.assert_array_equal(y_, y)
```

Error:
copybara-service bot pushed a commit that referenced this issue on Sep 19, 2024: Fixes #22783 PiperOrigin-RevId: 676368116
Description
I am testing in interpret mode.
Repro:
Expected behavior:
The line `y_ = kernel(x)` should run successfully and yield the same value as `y`.
Actual behavior:
Explanation:
The correct shape of the resulting array is (4, 5), but Pallas incorrectly assumes it is (5, 4), which results in the error.
I have tested various indexing combinations and observed a pattern: when there is only one integer array index, the axis corresponding to it is always moved to the front, even though NumPy semantics keep it in place. For example, in the case above, the axis of size 5 is moved to the front, so Pallas assumes the shape is (5, 4) instead of (4, 5).
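For reference, NumPy decides where the dimensions produced by integer-array indices go based on whether those indices are adjacent. The NumPy-only sketch below (no Pallas involved) shows that a single integer-array index keeps its position, so `x[:, a]` has shape (16, 5); the broadcast dimensions move to the front only when array indices are separated by a slice. The strided `x[::4, a]` line is an assumption about what the original repro looked like, chosen only because it yields (4, 5) for an `x` of shape (16, 3).

```python
import numpy as np

# Same array as in the "better repro": shape (16, 3).
x = np.arange(16 * 3, dtype=np.float32).reshape(16, 3)
a = np.array([1, 1, 1, 1, 1], dtype=np.int32)

# One integer-array index: its broadcast dimension (5,) stays where the
# indexed axis was, so the sliced axis keeps its leading position.
print(x[:, a].shape)    # (16, 5)
print(x[::4, a].shape)  # (4, 5)  (assumed form of the strided repro)

# Integer-array index on the first axis: its dimension is already leading.
print(x[a, :].shape)    # (5, 3)

# Only when array indices are separated by a slice does NumPy move the
# broadcast dimension to the front of the result.
z = np.zeros((2, 16, 3))
b = np.array([0, 1, 1, 0, 1])
print(z[b, :, a].shape)  # (5, 16)
```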
This may have to do with https://github.com/google/jax/blob/5c9bb612a775ca23d311eef1aeac03dfe0828a62/jax/_src/state/indexing.py#L256-L257.
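The linked lines are not reproduced here, so the following is only a sketch of the placement rule an indexer would need to follow, written against NumPy's documented semantics rather than Pallas internals. `numpy_index_shape` is a hypothetical helper, not code from `jax/_src/state/indexing.py`; it handles only slices and integer arrays, with exactly one index per axis.

```python
import numpy as np


def numpy_index_shape(shape, indices):
    """Hypothetical helper: shape of x[indices] under NumPy's advanced-indexing rule.

    Only slices and integer arrays are handled (no ints, None, or Ellipsis),
    and exactly one index must be given per axis.
    """
    if len(indices) != len(shape):
        raise ValueError("expected one index per axis")
    array_pos = [i for i, idx in enumerate(indices) if not isinstance(idx, slice)]
    # Broadcast shape of all integer-array indices taken together.
    bcast = np.broadcast_shapes(*(np.shape(indices[i]) for i in array_pos)) if array_pos else ()
    # Array indices are "adjacent" if no slice sits between them.
    adjacent = array_pos == list(range(array_pos[0], array_pos[-1] + 1)) if array_pos else True

    out = []
    if array_pos and not adjacent:
        out.extend(bcast)  # separated array indices: broadcast dims go first
    for i, (idx, dim) in enumerate(zip(indices, shape)):
        if isinstance(idx, slice):
            out.append(len(range(*idx.indices(dim))))  # sliced axis keeps its place
        elif adjacent and i == array_pos[0]:
            out.extend(bcast)  # adjacent array indices: replaced in place
    return tuple(out)
```

Applied to the better repro, `numpy_index_shape((16, 3), (slice(None), a))` returns (16, 5), the shape `x_ref[:, a]` would be expected to have; the single array index is never hoisted to the front.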
System info (python version, jaxlib version, accelerator, etc.)