jax.experimental.enable_x64 fails in JIT context #5982

Open
jakevdp opened this issue Mar 8, 2021 · 4 comments

jakevdp (Collaborator) commented Mar 8, 2021

Minimal repro:

from jax import jit, experimental

@jit
def f(a):
  with experimental.enable_x64():
    return 1 + a
f(1)

RuntimeError: Invalid argument: Binary op add with different element types: s64[] and s32[].

jakevdp added the bug label Mar 8, 2021
jakevdp self-assigned this Mar 8, 2021

jakevdp (Collaborator, Author) commented Mar 10, 2021

Happily, it looks like #6014 mostly fixes this! The issue came from the fact that convert_element_type was not being bound at trace time (because the constant is 64-bit in that context), though it was required at runtime (because the constant is 32-bit in that context).
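
To make the trace-time versus run-time mismatch concrete, here is a small pure-Python toy model. It does not use JAX and is not how JAX is implemented: `_x64_enabled`, `canonical_int_dtype`, `trace_one_plus_a`, and `run` are all hypothetical names invented for this sketch. It only illustrates the pattern described above: a conversion is skipped because the literal looks 64-bit while tracing, but the literal is materialized as 32-bit later.

```python
from contextlib import contextmanager

# Hypothetical stand-in for a global x64 flag; not the real JAX config object.
_x64_enabled = False


@contextmanager
def enable_x64():
    """Toy context manager that flips the flag and restores it on exit."""
    global _x64_enabled
    previous = _x64_enabled
    _x64_enabled = True
    try:
        yield
    finally:
        _x64_enabled = previous


def canonical_int_dtype():
    """Dtype a Python int literal gets under the current flag."""
    return "int64" if _x64_enabled else "int32"


def trace_one_plus_a():
    """Record `1 + a` while tracing; returns a toy equation."""
    op_dtype = canonical_int_dtype()       # dtype the add will operate on
    literal_dtype = canonical_int_dtype()  # what the literal 1 looks like right now
    # The literal already matches at trace time, so no conversion is recorded.
    return {"op": "add", "dtype": op_dtype,
            "convert_literal": literal_dtype != op_dtype}


def run(eqn):
    """Materialize the literal later, under whatever the flag is at that point."""
    literal_dtype = canonical_int_dtype()
    if eqn["convert_literal"]:
        literal_dtype = eqn["dtype"]
    if literal_dtype != eqn["dtype"]:
        raise RuntimeError(
            f"Binary op {eqn['op']} with different element types: "
            f"{eqn['dtype']} and {literal_dtype}")
    return eqn["dtype"]


with enable_x64():
    eqn = trace_one_plus_a()  # traced inside the context: the add is int64

try:
    run(eqn)                  # executed outside: the literal comes out as int32
    outcome = "ok"
except RuntimeError as err:
    outcome = str(err)
print(outcome)
```

In this toy, running the same equation while the context manager is still active succeeds, which mirrors the observation that the failure only appears once work is deferred past the end of the `with` block.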

jakevdp (Collaborator, Author) commented Mar 11, 2021

Here's a repro that's closer to the metal, fixed by #6018

import jax.numpy as jnp
from jax import experimental, jit, lax

# Create an int64 in X32 mode
with experimental.enable_x64():
  x = jnp.int64(1)

y = lax.convert_element_type(x, jnp.int32)
assert y.dtype == jnp.int32  # fails

z = jit(lax.convert_element_type, static_argnums=1)(x, jnp.int32)
assert z.dtype == jnp.int32  # fails

btanner commented Apr 8, 2021

Bumping this issue with another simple example:

import jax
import jax.experimental
import jax.numpy as jnp

def foo(x):
  with jax.experimental.enable_x64():
    clipped = x.astype(jnp.uint64)
    clipped = clipped % 2

    # Notably, the line below gives the identical error
    # clipped = clipped % jnp.asarray(2, jnp.uint64)

    return clipped

x = jnp.asarray(.5)

print(foo(x))
j_foo = jax.jit(foo)
print(j_foo(x))

The non-jitted call works fine.

Here is the error:

google3/third_party/py/jax/interpreters/xla.py in jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args)
    462 
    463     assert isinstance(ans, xe.XlaOp)
--> 464     c.get_shape(ans)  # force xla to do shape error checking
    465     if eqn.primitive.multiple_results or any(v.aval._num_buffers > 1 for v in eqn.outvars):
    466       out_nodes = xla_destructure(c, ans)

RuntimeError: Invalid argument: Binary op remainder with different element types: u64[] and u32[].

mattjj (Collaborator) commented Apr 8, 2021

Copying some comments from a chat with @btanner, for posterity:

i think the basic issue is that when using jit some things are delayed until after we've left the contextmanager's scope. in this case, it's the processing of the literal 2, which ultimately happens here: https://github.com/google/jax/blob/035c907d637bb6288740d1c2bc70178a5ab3bcd0/jax/interpreters/xla.py#L410

since we're out of the contextmanager's scope, we canonicalize the dtype from a Python int to an int32. but that's incompatible with the jaxpr computation, specifically the remainder op, that jit built while the context manager was in effect

i think to make this contextmanager work well, we need to avoid calling canonicalize_dtype in places like this (since it depends on the value of the x64_enabled global), and instead set the dtype according to whatever the jaxpr needs

that is, when we put arguments on devices, we need to just look at the dtype the jaxpr wants. similarly, when we move constants or literals to the device, we need to look at the dtype the jaxpr wants. in neither case should we look at the global. that'll give us the kind of control we want, since the contextmanager controls what gets staged out into the jaxpr
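
The proposal above can be sketched with a pure-Python toy, again without JAX. Everything here is an assumption made for illustration: `StagedLiteral`, `stage_literal`, `lower_using_global`, and `lower_using_recorded` are hypothetical names, and the real lowering code is more involved. The sketch contrasts consulting the global flag when a literal is moved to the device with using the dtype that was recorded when the computation was staged.

```python
from contextlib import contextmanager
from dataclasses import dataclass

# Hypothetical stand-in for the x64_enabled global; not the real JAX config.
_x64_enabled = False


@contextmanager
def enable_x64():
    """Toy context manager that flips the flag and restores it on exit."""
    global _x64_enabled
    previous = _x64_enabled
    _x64_enabled = True
    try:
        yield
    finally:
        _x64_enabled = previous


def canonicalize_int_dtype():
    """Toy canonicalization: the result depends on the global flag."""
    return "int64" if _x64_enabled else "int32"


@dataclass(frozen=True)
class StagedLiteral:
    """A literal as staged out, carrying the dtype chosen at trace time."""
    value: int
    dtype: str


def stage_literal(value):
    # Staging happens while the context manager is active,
    # so the recorded dtype reflects what the staged computation expects.
    return StagedLiteral(value, canonicalize_int_dtype())


def lower_using_global(literal):
    # Problematic pattern: re-derive the dtype from the global at lowering time.
    return canonicalize_int_dtype()


def lower_using_recorded(literal):
    # Proposed pattern: trust the dtype the staged computation asked for.
    return literal.dtype


with enable_x64():
    two = stage_literal(2)

# Lowering happens after the context manager has exited.
print(lower_using_global(two))    # int32
print(lower_using_recorded(two))  # int64
```

The design point is that the recorded dtype travels with the staged value, so the result no longer depends on when lowering happens relative to the `with` block.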
