Odd output from jax.ops.index_update when jitted #7461
Comments
I verified that with
the results match. Maybe something is wrong with our jit.
It also doesn't repro on GPU.
Filed a bug with XLA:CPU.
I think the issue I was about to file is very much related to this one. Starting from a large model with wrong JVP and VJP results, I was able to trim the code down to the following:

from jax import numpy as jnp
from jax import jit

def _debug_twolog_int(x):
    from jax.ops import index, index_update
    twolog = jnp.zeros((2 + x.shape[1], ))
    # Required as otherwise does not break
    twolog = index_update(twolog, ((0, 1), ), 0.)
    # assert jnp.all(twolog == 0.) # Always true if not compiled
    # Does not work! wtf?!!!!
    # JAX does not seem to like `twolog[1:-1]`, though slices w/o negative
    # indices seem to work
    twolog = index_update(twolog, index[2:], twolog[0:-2] + x[0])
    # twolog = index_update(twolog, index[2:], x[0]) # Works
    return twolog

r = jnp.repeat(jnp.arange(12).reshape((1, -1)), 2, axis=0)
print("W/O JIT:", _debug_twolog_int(r))
print("W JIT:", jit(_debug_twolog_int)(r))
print("DIFF 01:", jit(_debug_twolog_int)(r) - _debug_twolog_int(r))

Changing any line in the above resolves the issue. However, it is pretty easy to break it again, e.g. via

def _debug_yet_another_twolog_int(x):
    from jax.ops import index, index_update
    twolog = jnp.zeros((2 + x.shape[1], ))
    twolog = index_update(twolog, index[0], 0.)
    twolog = index_update(twolog, index[1], 0.)
    # assert jnp.all(twolog == 0.) # Always true if not compiled
    # twolog = index_update(twolog, index[2:], x[0]) # Works
    twolog = index_update(twolog, index[2:], x[1])
    # Does not work! wtf?!!!!
    # JAX does not seem to like `twolog[1:-1]`, though slices w/o negative
    # indices seem to work
    twolog = index_update(twolog, index[2:], twolog[1:-1] + x[0])
    return twolog

r = jnp.repeat(jnp.arange(12).reshape((1, -1)), 2, axis=0)
print(
    "DIFF 02:",
    jit(_debug_yet_another_twolog_int)(r) - _debug_yet_another_twolog_int(r)
)
Prior versions of JAX and jaxlib such as "jax[cpu]==0.2.13" and "jaxlib==0.1.65" seem to be affected as well. Note, I did not perform a proper git bisect but instead just looped over
Thanks for looking into this! (hi Matt!)
For any Googlers reading this, the internal P1 bug is b/195462810. We'll update this thread when progress is made on the bug. The best workaround I can think of for now is to use GPU or TPU backends.
I am terribly sorry for bumping this thread even though you said you would update it upon any progress. However, this is blocking part of the development on our side, the workaround is not an option for us (our problems do not fit on a GPU), and seemingly no progress has been made, so I would like to ask whether the priority of the XLA issue can be increased. The error seems quite severe to me and can silently produce wrong results.
Hi @Edenhofer - unfortunately there has not been any progress on this as of yet. It's turned out to be a difficult issue to pin down.
Debugging progress is happening. We suspect that this isn't a CPU-only bug; it's just much more likely to show up on CPU. Watch this space!
We have a candidate fix for this issue that hopefully should land soon.
… a risk of clobbering data

A problematic pattern is like the following, where the in-place update of parameter.1 can clobber the input of slice, so the slice and dynamic-update-slice shouldn't be fused together.

parameter.1 = f32[8] parameter(0)
slice.19 = f32[7] slice(parameter.1), slice={[0:7]}
constant.7 = f32[] constant(1)
broadcast.8 = f32[7] broadcast(constant.7), dimensions={}
add.9 = f32[7] add(slice.19, broadcast.8)
constant.10 = s32[] constant(1)
ROOT dynamic-update-slice.1 = f32[8] dynamic-update-slice(parameter.1, add.9, constant.10)

We explicitly allow patterns of slice or dynamic-slice feeding to dynamic-update-slice if we can prove that the two use the same indices.

jax-ml/jax#7461
PiperOrigin-RevId: 401589870
Change-Id: I8695bd706e702e1356dc4e1d9729def6293b0fb2
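The hazard described in the commit message can be illustrated outside of XLA. The following is a plain-Python model of the pattern, not XLA's actual code generation: the function names are hypothetical, and the ascending loop order in the fused variant is an assumption made only to show how clobbering could arise. Both functions aim to compute `out[1:8] = buf[0:7] + 1`, mirroring the HLO in the commit message.

```python
def update_unfused(buf):
    """Materialize slice + 1 first, then write it back at offset 1."""
    tmp = [v + 1.0 for v in buf[0:7]]  # models add.9, built from the original values
    out = list(buf)                    # separate output buffer
    out[1:8] = tmp                     # models dynamic-update-slice at index 1
    return out


def update_fused_in_place(buf):
    """Fuse slice, add and update into one loop that reuses buf as the output."""
    out = buf  # output aliases the input (in-place update)
    for i in range(7):
        # For i >= 1 this reads out[i], which the previous iteration
        # already overwrote, so the "slice" no longer sees the original data.
        out[i + 1] = out[i] + 1.0
    return out


print(update_unfused([0.0] * 8))         # [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
print(update_fused_in_place([0.0] * 8))  # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
```

The repro earlier in this thread has the same shape: `twolog[2:]` is overwritten with a value derived from `twolog[0:-2]`, so the read and write ranges overlap with different offsets.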
I can confirm that updating from
@Edenhofer Yes, that's right, jaxlib 0.1.72 contains the fix for this issue. Sorry this took a while to track down!
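For anyone who wants to check an installation, a minimal sketch is to compare the eager and jitted results of the first repro. It assumes a newer JAX release in which `jax.ops.index_update` is no longer available and the `x.at[idx].set(y)` syntax is used instead; the function name `twolog_at` is made up for this example, and whether this rewritten form would have triggered the original bug on an affected jaxlib has not been verified here.

```python
import jax.numpy as jnp
from jax import jit


def twolog_at(x):
    # Same computation as _debug_twolog_int, written with the .at[] syntax.
    twolog = jnp.zeros((2 + x.shape[1],))
    twolog = twolog.at[jnp.array([0, 1])].set(0.0)
    # Overlapping read (twolog[0:-2]) and write (twolog[2:]) ranges.
    twolog = twolog.at[2:].set(twolog[0:-2] + x[0])
    return twolog


r = jnp.repeat(jnp.arange(12).reshape((1, -1)), 2, axis=0)
eager = twolog_at(r)
compiled = jit(twolog_at)(r)

# On a jaxlib that contains the fix, the two results should agree.
max_abs_diff = float(jnp.max(jnp.abs(compiled - eager)))
print("max |jit - eager|:", max_abs_diff)
```

A nonzero difference would point to a miscompilation like the one discussed in this thread.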