
[Pallas] When mixing basic indexing and integer array indexing, the axis corresponding to integer array indexing is unnecessarily moved to the front #22783

Open
ayaka14732 opened this issue Jul 31, 2024 · 1 comment · May be fixed by #23758
Labels
bug Something isn't working

Comments

@ayaka14732
Member

Description

I am testing in interpret mode.

Repro:

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
import numpy as np

x_shape = (16, 3)

x = jnp.arange(np.prod(x_shape), dtype=jnp.float32).reshape(x_shape)

a = jnp.array([1, 1, 1, 1, 1], dtype=jnp.int32)
y = x[::4, a]

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct(y.shape, jnp.float32),
    interpret=True,
)
def kernel(x_ref, o_ref):
    o_ref[...] = x_ref[::4, a]

y_ = kernel(x)
np.testing.assert_array_equal(y_, y)

Expected behavior:

The line y_ = kernel(x) should run successfully and yield the same value as y.

Actual behavior:

Traceback (most recent call last):
  File "/home/ayx/development/jax/test.py", line 23, in <module>
    y_ = kernel(x)
         ^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/pallas/pallas_call.py", line 1085, in wrapped
    grid_mapping, jaxpr, consts = _trace_kernel_to_jaxpr(
                                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/pallas/pallas_call.py", line 857, in _trace_kernel_to_jaxpr
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/test.py", line 21, in kernel
    o_ref[...] = x_ref[::4, a]
    ~~~~~^^^^^
  File "/home/ayx/development/jax/jax/_src/numpy/array_methods.py", line 747, in op
    return getattr(self.aval, f"_{name}")(self, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/state/types.py", line 187, in _setitem
    return ref_set(tracer, idx, value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/state/primitives.py", line 124, in ref_set
    ref_swap(ref_or_view, idx, value, _function_name="ref_set")
  File "/home/ayx/development/jax/jax/_src/state/primitives.py", line 120, in ref_swap
    return swap_p.bind(ref, value, *flat_indexers, tree=tree)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/state/primitives.py", line 188, in _swap_abstract_eval
    raise ValueError("Invalid shape for `swap`. "
ValueError: Invalid shape for `swap`. Ref shape: (4, 5). Expected shape: (4, 5). Value shape: (5, 4). Indices: (NDIndexer(indices=(Slice(start=0, size=4, stride=1), Slice(start=0, size=5, stride=1)), shape=(4, 5), int_indexer_shape=(), validate=False),). 
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Explanation:

The correct shape of x_ref[::4, a] is (4, 5), the same as y, but Pallas incorrectly computes it as (5, 4), which results in the error above.

I have tested various indexing patterns and observed that, when there is only one integer array index, the axis corresponding to it is always unnecessarily moved to the front. NumPy moves such dimensions to the front only when integer array indices are separated by a slice; a single integer array index keeps its position. For example, in the case above, the axis of size 5 is moved to the front, so Pallas assumes the shape to be (5, 4) instead of (4, 5).
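
The expected shapes can be confirmed with plain NumPy, whose advanced-indexing rules jax.numpy follows (this is how y gets shape (4, 5) in the repro). The 3-D array `z` and index `b` below are illustrative additions, not part of the original report; they contrast the case where dimensions really are moved to the front.

```python
import numpy as np

x = np.arange(16 * 3, dtype=np.float32).reshape(16, 3)
a = np.array([1, 1, 1, 1, 1], dtype=np.int32)

# A single integer-array index: its dimension stays at the position of the indexed axis.
print(x[::4, a].shape)  # (4, 5)
print(x[:, a].shape)    # (16, 5)

# Illustrative 3-D case: integer arrays separated by a slice move to the front...
z = np.zeros((5, 3, 4))
b = np.array([0, 1])
print(z[b, :, b].shape)  # (2, 3)
# ...but adjacent integer arrays stay in place.
print(z[:, b, b].shape)  # (5, 2)
```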

This may have to do with https://github.com/google/jax/blob/5c9bb612a775ca23d311eef1aeac03dfe0828a62/jax/_src/state/indexing.py#L256-L257.
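
To make the expected rule concrete, here is a minimal sketch of the shape computation in plain Python. `mixed_index_shape` is a hypothetical helper written for illustration, not the code in jax/_src/state/indexing.py, and it handles only slices, Python ints, and integer NumPy arrays (no Ellipsis, None, or boolean masks). The key point is the last two branches: the broadcast shape of the advanced indices replaces them in place when they are adjacent, and moves to the front only when a slice separates them.

```python
import numpy as np


def mixed_index_shape(shape, indices):
    """Hypothetical helper: the shape of x[indices] under NumPy's indexing rules.

    Supports slices, Python ints, and integer np.ndarray indices only
    (no Ellipsis, None, or boolean masks).
    """
    # Pad missing trailing indices with full slices, as NumPy does.
    indices = tuple(indices) + (slice(None),) * (len(shape) - len(indices))
    has_array = any(isinstance(i, np.ndarray) for i in indices)
    # Once any integer array is present, plain ints also count as advanced indices.
    advanced = [k for k, i in enumerate(indices)
                if isinstance(i, np.ndarray) or (has_array and isinstance(i, int))]
    # Dimensions produced by slices, recorded with the axis they came from.
    sliced = [(k, len(range(*i.indices(dim))))
              for k, (dim, i) in enumerate(zip(shape, indices))
              if isinstance(i, slice)]
    if not advanced:
        # Basic indexing only: ints drop their axis, slices keep theirs in order.
        return tuple(size for _, size in sliced)
    # All advanced indices broadcast together into one block of dimensions.
    adv_shape = tuple(np.broadcast_shapes(*(np.shape(indices[k]) for k in advanced)))
    if advanced == list(range(advanced[0], advanced[-1] + 1)):
        # Adjacent advanced indices: the block replaces them in place.
        before = tuple(s for k, s in sliced if k < advanced[0])
        after = tuple(s for k, s in sliced if k > advanced[-1])
        return before + adv_shape + after
    # Advanced indices separated by a slice: the block moves to the front.
    return adv_shape + tuple(s for _, s in sliced)


a = np.array([1, 1, 1, 1, 1])
print(mixed_index_shape((16, 3), (slice(None, None, 4), a)))  # (4, 5)
print(mixed_index_shape((16, 3), (slice(None), a)))           # (16, 5)
```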

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.31.dev20240729+6a7822a73
jaxlib: 0.4.30
numpy:  1.26.4
python: 3.12.4 (main, Jun 12 2024, 19:06:53) [GCC 13.2.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='ayx1', release='6.6.15-2rodete2-amd64', version='#1 SMP PREEMPT_DYNAMIC Debian 6.6.15-2rodete2 (2024-03-19)', machine='x86_64')
ayaka14732 added the bug label Jul 31, 2024
ayaka14732 self-assigned this Jul 31, 2024
@ayaka14732
Member Author

ayaka14732 commented Aug 14, 2024

Better repro (without strided indexing):

import functools

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
import numpy as np

x_shape = (16, 3)

x = jnp.arange(np.prod(x_shape), dtype=jnp.float32).reshape(x_shape)

a = jnp.array([1, 1, 1, 1, 1], dtype=jnp.int32)
y = x[:, a]

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct(y.shape, jnp.float32),
    interpret=True,
)
def kernel(x_ref, o_ref):
    o_ref[...] = x_ref[:, a]

y_ = kernel(x)
np.testing.assert_array_equal(y_, y)

Error:

Traceback (most recent call last):
  File "/home/ayx/development/jax/4.py", line 23, in <module>
    y_ = kernel(x)
         ^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/pallas/pallas_call.py", line 1129, in wrapped
    jaxpr = _trace_kernel_to_jaxpr(
            ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/pallas/pallas_call.py", line 901, in _trace_kernel_to_jaxpr
    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_kernel_fun,
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/4.py", line 21, in kernel
    o_ref[...] = x_ref[:, a]
    ~~~~~^^^^^
  File "/home/ayx/development/jax/jax/_src/numpy/array_methods.py", line 749, in op
    return getattr(self.aval, f"_{name}")(self, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/state/types.py", line 187, in _setitem
    return ref_set(tracer, idx, value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/state/primitives.py", line 114, in ref_set
    ref_swap(ref_or_view, idx, value, _function_name="ref_set")
  File "/home/ayx/development/jax/jax/_src/state/primitives.py", line 110, in ref_swap
    return swap_p.bind(ref, value, *flat_indexers, tree=tree)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ayx/development/jax/jax/_src/state/primitives.py", line 178, in _swap_abstract_eval
    raise ValueError("Invalid shape for `swap`. "
ValueError: Invalid shape for `swap`. Ref shape: (16, 5). Expected shape: (16, 5). Value shape: (5, 16). Indices: (NDIndexer(indices=(Slice(start=0, size=16, stride=1), Slice(start=0, size=5, stride=1)), shape=(16, 5), int_indexer_shape=(), validate=False),). 
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
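
One possible workaround, sketched here for the second repro and not verified against the versions listed above: load the full block with x_ref[...] and apply the integer array index to the loaded value, so that jax.numpy indexing semantics determine the shape rather than the Ref indexer. Passing `a` to the kernel as an input instead of closing over it is a precaution, since some Pallas versions may reject kernels that capture array constants. This assumes interpret=True; whether a compiled backend can lower the resulting gather is not established here.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

x = jnp.arange(16 * 3, dtype=jnp.float32).reshape(16, 3)
a = jnp.array([1, 1, 1, 1, 1], dtype=jnp.int32)
y = x[:, a]  # reference result computed with jax.numpy, shape (16, 5)

@functools.partial(
    pl.pallas_call,
    out_shape=jax.ShapeDtypeStruct(y.shape, jnp.float32),
    interpret=True,
)
def kernel(x_ref, a_ref, o_ref):
    # Load the whole block first, then index the loaded array (not the Ref)
    # with the integer indices loaded from a_ref.
    o_ref[...] = x_ref[...][:, a_ref[...]]

y_ = kernel(x, a)
np.testing.assert_array_equal(y_, y)
```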

copybara-service bot pushed a commit that referenced this issue Sep 19, 2024
Fixes #22783

PiperOrigin-RevId: 676368116
copybara-service bot linked a pull request Sep 19, 2024 that will close this issue
copybara-service bot pushed further commits with the same message (Fixes #22783, PiperOrigin-RevId: 676368116) that referenced this issue on Sep 19, 2024 (five more), Sep 20, 2024 (two), Oct 7, 2024, Oct 8, 2024 (two), Oct 23, 2024, and Oct 29, 2024.