jax.experimental.enable_x64 fails in JIT context #5982
Comments
Happily, it looks like #6014 mostly fixes this! The issue came from the fact that convert_element_type was not being bound at trace time (because the constant is 64-bit in that context), though it was required at runtime (because the constant is 32-bit in that context).
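To make the trace-time versus runtime mismatch described above concrete, here is a minimal toy model in plain Python. It is not JAX code: `X64_ENABLED`, `enable_x64`, `constant_dtype`, `trace_cast` and `run` are all hypothetical stand-ins, and dtypes are just strings. It only sketches how a cast that looks redundant while tracing can be skipped and then be missing when the computation runs.

```python
from contextlib import contextmanager

# Hypothetical stand-in for the global x64 flag (not JAX's real state).
X64_ENABLED = False


@contextmanager
def enable_x64():
    """Toy context manager that turns the flag on for the enclosed block."""
    global X64_ENABLED
    prev, X64_ENABLED = X64_ENABLED, True
    try:
        yield
    finally:
        X64_ENABLED = prev


def constant_dtype():
    """Dtype a Python int constant receives under the current flag."""
    return "int64" if X64_ENABLED else "int32"


def trace_cast(target):
    """Record the ops needed to cast a constant to `target`.

    A cast that looks like a no-op at trace time is not recorded.
    """
    ops = []
    if constant_dtype() != target:
        ops.append(("convert_element_type", target))
    return ops


def run(ops):
    """Replay recorded ops on a constant whose dtype is decided at run time."""
    dtype = constant_dtype()
    for name, target in ops:
        if name == "convert_element_type":
            dtype = target
    return dtype


with enable_x64():
    # The constant is int64 here, so the cast to int64 is skipped.
    ops = trace_cast("int64")

# Outside the context the constant is int32, and no recorded op converts it.
result_dtype = run(ops)
```

In this sketch the traced program asked for int64, but the replay yields int32 because the conversion was never recorded.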
Here's a repro that's closer to the metal, fixed by #6018:

import jax.numpy as jnp
from jax import experimental, jit, lax

# Create an int64 in X32 mode
with experimental.enable_x64():
    x = jnp.int64(1)

y = lax.convert_element_type(x, jnp.int32)
assert y.dtype == jnp.int32  # fails

z = jit(lax.convert_element_type, static_argnums=1)(x, jnp.int32)
assert z.dtype == jnp.int32  # fails
Bumping this issue with another simple example:
The non-jitted call works fine. Here is the error:
Copying some comments from a chat with @btanner, for posterity:

i think the basic issue is that when using jit some things are delayed until after we've left the contextmanager's scope. in this case, it's the processing of the literal 2, which ultimately happens here: https://github.com/google/jax/blob/035c907d637bb6288740d1c2bc70178a5ab3bcd0/jax/interpreters/xla.py#L410

since we're out of the contextmanager's scope, we canonicalize the dtype from a Python int to an int32. but that's incompatible with the jaxpr computation, specifically the remainder op, that jit built while the context manager was in effect

i think to make this contextmanager work well, we need not to call canonicalize_dtype in places like this (which depends on the value of the x64_enabled global), and instead to set the dtype according to whatever the jaxpr needs

that is, when we put arguments on devices, we need to just look at the dtype the jaxpr wants. similarly, when we move constants or literals to the device, we need to look at the dtype the jaxpr wants. in neither case should we look at the global. that'll give us the kind of control we want, since the contextmanager controls what gets staged out into the jaxpr
minimal repro: