
jax.lax.scan error in 64bit mode #6059

Closed
adam-hartshorne opened this issue Mar 13, 2021 · 1 comment
Labels: bug (Something isn't working)

Comments

adam-hartshorne commented Mar 13, 2021

When I run a simple jax.lax.scan loop with JAX in 64-bit mode (jax_enable_x64) on a Windows 10 machine, I get the following error.

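import jax
import jax.numpy as jnp

# Enable 64-bit mode; this must run before any JAX arrays are created.
jax.config.update('jax_enable_x64', True)

num_steps = 100  # hypothetical value: the original snippet never defines num_steps
T = 1.0
init_x = 0.0
init_t = 0.0
dt = T / num_steps

def scan_fun(x, t):
    # x is the carry; t is the current row of xs (shape (1,)).
    x += dt
    t += dt
    return x, t

xs = jnp.linspace(init_t, T, num_steps, endpoint=True)
jax.lax.scan(scan_fun, init_x, xs[:, jnp.newaxis])[0]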
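For reference, lax.scan returns a pair: the final carry and the per-step outputs stacked along a new leading axis. The sketch below shows what the snippet is expected to produce on a JAX version without this bug. It assumes a current JAX release and a hypothetical num_steps = 10, since the report does not give a value.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode before any arrays are created, as in the report.
jax.config.update("jax_enable_x64", True)

T = 1.0
num_steps = 10  # hypothetical value; the report does not specify one
dt = T / num_steps

def scan_fun(x, t):
    # x: carry threaded from step to step; t: current row of xs, shape (1,)
    return x + dt, t + dt

xs = jnp.linspace(0.0, T, num_steps, endpoint=True)
final_x, ys = jax.lax.scan(scan_fun, 0.0, xs[:, jnp.newaxis])

print(final_x.dtype)  # float64: the carry stays 64-bit in x64 mode
print(ys.shape)       # (10, 1): one output row per step, stacked on axis 0
```

Because the carry is a Python float added to a Python float, it ends up as a 64-bit scalar close to T after num_steps additions, not exactly equal to it.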

It crashes with the following error:

Traceback (most recent call last):
  File "C:\Program Files\JetBrains\PyCharm 2020.3.3\plugins\python\helpers\pydev\_pydevd_bundle\pydevd_exec2.py", line 3, in Exec
    exec(exp, global_vars, local_vars)
  File "<string>", line 13, in <module>
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\control_flow.py", line 1292, in scan
    out = scan_p.bind(*itertools.chain(consts, in_flat),
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\control_flow.py", line 1880, in scan_bind
    return core.Primitive.bind(scan_p, *args, **params)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\core.py", line 284, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\core.py", line 622, in process_primitive
    return primitive.impl(*tracers, **params)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\control_flow.py", line 1379, in _scan_impl
    return _scan_impl_loop(
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\control_flow.py", line 1340, in _scan_impl_loop
    _, *outs = while_loop(cond_fun, body_fun, init_val)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\control_flow.py", line 284, in while_loop
    init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\control_flow.py", line 269, in _create_jaxpr
    cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\util.py", line 198, in wrapper
    return cached(bool(config.x64_enabled), *args, **kwargs)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\util.py", line 191, in cached
    return f(*args, **kwargs)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\control_flow.py", line 73, in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\util.py", line 198, in wrapper
    return cached(bool(config.x64_enabled), *args, **kwargs)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\util.py", line 191, in cached
    return f(*args, **kwargs)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\control_flow.py", line 68, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\interpreters\partial_eval.py", line 1190, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\interpreters\partial_eval.py", line 1200, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\control_flow.py", line 1324, in cond_fun
    return i < length
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\core.py", line 516, in __lt__
    def __lt__(self, other): return self.aval._lt(self, other)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 5256, in deferring_binary_op
    return binary_op(self, other)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 466, in fn
    return lax_fn(x1, x2)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\lax.py", line 407, in lt
    return lt_p.bind(x, y)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\core.py", line 284, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\interpreters\partial_eval.py", line 1062, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\lax.py", line 1994, in standard_abstract_eval
    return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\lax.py", line 2070, in naryop_dtype_rule
    _check_same_dtypes(name, False, *aval_dtypes)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\lax.py", line 6158, in _check_same_dtypes
    raise TypeError(msg.format(name, ", ".join(map(str, types))))
TypeError: lt requires arguments to have the same dtypes, got int64, int32.
adam-hartshorne added the bug label on Mar 13, 2021
jakevdp (Collaborator) commented Mar 14, 2021

Thanks for the report – I'm going to close this as a duplicate of #6058, because it's the same underlying issue.

jakevdp closed this as completed on Mar 14, 2021