Odd output from jax.ops.index_update when jitted #7461

Closed
yovadia opened this issue Aug 3, 2021 · 14 comments · Fixed by #8198

@yovadia

yovadia commented Aug 3, 2021

import jax
import jax.numpy as jnp

def demo(n=8):
  fn = lambda x: jax.ops.index_update(x, slice(1, None), 1 + x[:-1])
  y = jnp.zeros(n)
  print(fn(y))
  print(jax.jit(fn)(y))
demo()

# [0. 1. 1. 1. 1. 1. 1. 1.]
# [0. 1. 2. 3. 4. 5. 6. 7.]
@yovadia yovadia added the bug Something isn't working label Aug 3, 2021
@zhangqiaorjc
Collaborator

I verified that with

with jax.disable_jit():
  print(jax.jit(fn)(y))

the results match.

Maybe something is wrong with our jit.
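
For readers on current JAX releases, where `jax.ops.index_update` has since been removed in favour of the `x.at[idx].set(...)` syntax, the comparison above can be sketched roughly as follows. This is not the original reproducer: it assumes a recent `jax` install, uses the public `jax.disable_jit()` context manager, and the helper name `shift_and_increment` is made up for illustration. On a jaxlib that contains the fix, all three results should agree with the NumPy reference.

```python
import jax
import jax.numpy as jnp
import numpy as np


def shift_and_increment(x):
    # Same update as the reproducer: x[1:] <- 1 + x[:-1], written with .at[].set().
    return x.at[1:].set(1 + x[:-1])


y = jnp.zeros(8)

eager = shift_and_increment(y)             # op-by-op execution
jitted = jax.jit(shift_and_increment)(y)   # compiled by XLA
with jax.disable_jit():
    # Inside this context jax.jit does not compile; the function runs eagerly.
    unjitted = jax.jit(shift_and_increment)(y)

# Reference computed with NumPy from a separate source array, so no aliasing is possible.
expected = np.zeros(8)
expected[1:] = 1 + np.zeros(8)[:-1]

print(eager)
print(jitted)
print(unjitted)
```

Comparing against a NumPy reference rather than only against the eager JAX result keeps the check meaningful even if both JAX paths were wrong in the same way.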

@mattjj
Collaborator

mattjj commented Aug 3, 2021

(Hey Yaniv!)

I suspect this is an XLA:CPU bug. It doesn't reproduce on TPU:
[image: screenshot of the TPU run]

(It seems like a plausible bug too: inputs and outputs are incorrectly aliased so we're e.g. writing the new x[1] before reading it for the update to x[2].)
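
The aliasing hypothesis can be illustrated without XLA at all. The following is a plain-Python sketch, not what XLA actually emits: it assumes a loop that walks the buffer front to back, and the function names are made up for illustration.

```python
def update_with_copy(x):
    # Read from a snapshot of the input, write to a separate output buffer.
    src = list(x)
    out = list(x)
    for i in range(1, len(x)):
        out[i] = 1 + src[i - 1]
    return out


def update_aliased(x):
    # Input and output share one buffer: the write to buf[i] is seen by
    # the read of buf[i] in the next iteration.
    buf = list(x)
    for i in range(1, len(buf)):
        buf[i] = 1 + buf[i - 1]
    return buf


zeros = [0.0] * 8
print(update_with_copy(zeros))  # [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
print(update_aliased(zeros))    # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
```

The two printed lists match the non-jitted and jitted outputs from the original report, which is what makes the aliasing explanation plausible.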

@mattjj
Collaborator

mattjj commented Aug 3, 2021

It also doesn't repro on GPU.

@zhangqiaorjc
Collaborator

Filed bug with XLA:CPU

@Edenhofer
Contributor

Edenhofer commented Aug 4, 2021

I think the issue I was about to file is closely related to this one. Starting from a large model that produced wrong JVPs and VJPs, I was able to trim the code down to index_update calls that return wrong results:

from jax import numpy as jnp
from jax import jit


def _debug_twolog_int(x):
    from jax.ops import index, index_update

    twolog = jnp.zeros((2 + x.shape[1], ))

    # Required as otherwise does not break
    twolog = index_update(twolog, ((0, 1), ), 0.)
    # assert jnp.all(twolog == 0.)  # Always true if not compiled
    # Does not work! wtf?!!!!
    # JAX does not seem to like `twolog[0:-2]`, though slices w/o negative
    # indices seem to work
    twolog = index_update(twolog, index[2:], twolog[0:-2] + x[0])
    # twolog = index_update(twolog, index[2:], x[0])  # Works
    return twolog


r = jnp.repeat(jnp.arange(12).reshape((1, -1)), 2, axis=0)
print("W/O JIT:", _debug_twolog_int(r))
print("W   JIT:", jit(_debug_twolog_int)(r))
print("DIFF 01:", jit(_debug_twolog_int)(r) - _debug_twolog_int(r))

Changing any line in the above resolves the issue. However, it is pretty easy to break it again, e.g. via

def _debug_yet_another_twolog_int(x):
    from jax.ops import index, index_update

    twolog = jnp.zeros((2 + x.shape[1], ))

    twolog = index_update(twolog, index[0], 0.)
    twolog = index_update(twolog, index[1], 0.)
    # assert jnp.all(twolog == 0.)  # Always true if not compiled
    # twolog = index_update(twolog, index[2:], x[0])  # Works
    twolog = index_update(twolog, index[2:], x[1])
    # Does not work! wtf?!!!!
    # JAX does not seem to like `twolog[1:-1]`, though slices w/o negative
    # indices seem to work
    twolog = index_update(twolog, index[2:], twolog[1:-1] + x[0])
    return twolog


r = jnp.repeat(jnp.arange(12).reshape((1, -1)), 2, axis=0)
print(
    "DIFF 02:",
    jit(_debug_yet_another_twolog_int)(r) - _debug_yet_another_twolog_int(r)
)

@Edenhofer
Contributor

Edenhofer commented Aug 4, 2021

Prior versions of JAX and jaxlib such as "jax[cpu]==0.2.13" and "jaxlib==0.1.65" seem to be affected as well. Note that I did not perform a proper git bisect; I just looped over pip install calls and ran the above lines in a script for each version.

@yovadia
Author

yovadia commented Aug 4, 2021

Thanks for looking into this! (hi Matt!)

@mattjj
Collaborator

mattjj commented Aug 6, 2021

For any googlers reading this, the internal P1 bug is b/195462810. We'll update this thread when progress is made on the bug.

The best workaround I can think of for now is to use GPU or TPU backends.

@Edenhofer
Contributor

I am terribly sorry for bumping this thread even though you said you would update it upon any progress.

However, this is blocking part of the development on our side, the workaround is not an option for us (our problems do not fit on a GPU), and seemingly no progress has been made. I would therefore like to ask whether the priority of the XLA issue can be increased. The error seems quite severe to me, since it can silently produce wrong results.

@jakevdp
Collaborator

jakevdp commented Sep 28, 2021

Hi @Edenhofer - unfortunately there has not been any progress on this yet. It has turned out to be a difficult issue to pin down.

@hawkinsp
Collaborator

Debugging progress is happening. We suspect that this isn't a CPU-only bug; it's just much more likely to show up on CPU. Watch this space!

@hawkinsp
Collaborator

hawkinsp commented Oct 7, 2021

We have a candidate fix for this issue that hopefully should land soon.

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Oct 7, 2021
… a risk of clobbering data

A problematic pattern is like the following, where the in-place update of
parameter.1 can clobber the input of slice, so the slice and
dynamic-update-slice shouldn't be fused together.

parameter.1 = f32[8] parameter(0)
slice.19 = f32[7] slice(parameter.1), slice={[0:7]}
constant.7 = f32[] constant(1)
broadcast.8 = f32[7] broadcast(constant.7), dimensions={}
add.9 = f32[7] add(slice.19, broadcast.8)
constant.10 = s32[] constant(1)
ROOT dynamic-update-slice.1 = f32[8] dynamic-update-slice(parameter.1, add.9, constant.10)

We explicitly allow patterns of slice or dynamic-slice feeding to
dynamic-update-slice if we can prove that the two use the same indices.

jax-ml/jax#7461

PiperOrigin-RevId: 401589870
Change-Id: I8695bd706e702e1356dc4e1d9729def6293b0fb2
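
The rule in the commit message can be sketched in plain Python. This is a toy model, not XLA's implementation: `unfused_reference`, `fused_in_place`, and `same_indices` are hypothetical names, and the fused kernel is assumed to run front to back over a single shared buffer.

```python
def unfused_reference(buf, read_start, write_start, length, fn):
    # Materialise the slice first, then apply the dynamic-update-slice.
    sliced = buf[read_start:read_start + length]
    out = list(buf)
    out[write_start:write_start + length] = [fn(v) for v in sliced]
    return out


def fused_in_place(buf, read_start, write_start, length, fn):
    # Toy fused kernel: read and write element by element in one shared buffer.
    out = list(buf)  # stands in for the aliased parameter buffer
    for i in range(length):
        out[write_start + i] = fn(out[read_start + i])
    return out


def same_indices(read_start, write_start):
    # Condition under which this toy fusion is provably safe.
    return read_start == write_start


def add_one(v):
    return v + 1.0


x = [0.0] * 8

# Same indices: every element is read before it is overwritten.
safe = fused_in_place(x, 1, 1, 7, add_one)
# Shifted indices (the pattern from this issue): earlier writes feed later reads.
clobbered = fused_in_place(x, 0, 1, 7, add_one)

print(same_indices(1, 1), safe == unfused_reference(x, 1, 1, 7, add_one))       # True True
print(same_indices(0, 1), clobbered == unfused_reference(x, 0, 1, 7, add_one))  # False False
```

In this model the equality check is conservative: some shifted patterns would also happen to be safe for a given loop order, but matching indices is the case that can be proven safe without reasoning about iteration order.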
@Edenhofer
Contributor

Edenhofer commented Oct 13, 2021

I can confirm that updating from jaxlib==0.1.71 to jaxlib==0.1.72 resolves the issue I had with index_update.

@hawkinsp
Collaborator

@Edenhofer Yes, that's right, jaxlib 0.1.72 contains the fix for this issue. Sorry this took a while to track down!
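
Since the fix ships in jaxlib, a CPU-only project that cannot rule out older installs might want to guard against them at startup. A minimal sketch, assuming plain dotted numeric version strings; `parse_version`, `has_index_update_fix`, and `installed_jaxlib_version` are hypothetical helper names, not part of JAX:

```python
from importlib import metadata

FIXED_IN = (0, 1, 72)  # first jaxlib release reported in this thread to contain the fix


def parse_version(version):
    # Keep only the leading numeric components, e.g. "0.1.72" -> (0, 1, 72).
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def has_index_update_fix(version):
    # True when the given jaxlib version is at or after the fixed release.
    return parse_version(version) >= FIXED_IN


def installed_jaxlib_version():
    # Returns None when jaxlib is not installed in this environment.
    try:
        return metadata.version("jaxlib")
    except metadata.PackageNotFoundError:
        return None


print(has_index_update_fix("0.1.71"))  # False
print(has_index_update_fix("0.1.72"))  # True
```

A caller could combine the two helpers, for example by raising an error when `installed_jaxlib_version()` returns a version for which `has_index_update_fix` is false.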
