You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
When I run a simple `jax.lax.scan` loop with 64-bit mode enabled (`jax_enable_x64`) on a Windows 10 machine, it fails. Minimal reproducer:
import jax
import jax.numpy as jnp

# Turn on 64-bit mode through the top-level config object. The
# `from jax.config import config` import path used originally is deprecated
# in newer JAX releases. The JAX docs recommend setting this flag at startup,
# before any arrays are created.
jax.config.update('jax_enable_x64', True)

T = 1.0  # end of the time window
# NOTE(review): the original snippet referenced `num_steps` without defining
# it (NameError). 100 is an arbitrary value chosen for the reproducer.
num_steps = 100
init_x = 0.0  # initial scan carry
init_t = 0.0  # start of the time grid
dt = T / num_steps  # per-step increment applied inside the scan body
def scan_fun(x, t):
    """Perform one `jax.lax.scan` step: advance both the carry and the input by `dt`.

    Reads the module-level constant ``dt``.

    Args:
        x: The carry threaded through the scan.
        t: The slice of the scanned-over array given to this step.

    Returns:
        A ``(new_carry, output)`` pair. ``scan`` threads the first element
        into the next step and stacks the second into its second result.
    """
    x += dt
    t += dt
    return x, t
# Build a uniform time grid over [init_t, T]. Adding a trailing axis turns it
# into a column, so each scan step receives a shape-(1,) slice.
xs = jnp.linspace(init_t, T, num=num_steps, endpoint=True)
column_xs = xs[:, None]  # equivalent to xs[:, jnp.newaxis]
# NOTE(review): on the reporter's machine this call raised
# "TypeError: lt requires arguments to have the same dtypes, got int64, int32"
# from JAX's internal loop-bound comparison (`i < length`), i.e. inside the
# library rather than in this script.
final_x, stacked_ts = jax.lax.scan(scan_fun, init_x, column_xs)
It crashes with the following error:
Traceback (most recent call last):
File "C:\Program Files\JetBrains\PyCharm 2020.3.3\plugins\python\helpers\pydev\_pydevd_bundle\pydevd_exec2.py", line 3, in Exec
exec(exp, global_vars, local_vars)
File "<string>", line 13, in <module>
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\control_flow.py", line 1292, in scan
out = scan_p.bind(*itertools.chain(consts, in_flat),
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\control_flow.py", line 1880, in scan_bind
return core.Primitive.bind(scan_p, *args, **params)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\core.py", line 284, in bind
out = top_trace.process_primitive(self, tracers, params)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\core.py", line 622, in process_primitive
return primitive.impl(*tracers, **params)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\control_flow.py", line 1379, in _scan_impl
return _scan_impl_loop(
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\control_flow.py", line 1340, in _scan_impl_loop
_, *outs = while_loop(cond_fun, body_fun, init_val)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\control_flow.py", line 284, in while_loop
init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\control_flow.py", line 269, in _create_jaxpr
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\util.py", line 198, in wrapper
return cached(bool(config.x64_enabled), *args, **kwargs)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\util.py", line 191, in cached
return f(*args, **kwargs)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\control_flow.py", line 73, in _initial_style_jaxpr
jaxpr, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\util.py", line 198, in wrapper
return cached(bool(config.x64_enabled), *args, **kwargs)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\util.py", line 191, in cached
return f(*args, **kwargs)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\control_flow.py", line 68, in _initial_style_open_jaxpr
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\interpreters\partial_eval.py", line 1190, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\interpreters\partial_eval.py", line 1200, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\control_flow.py", line 1324, in cond_fun
return i < length
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\core.py", line 516, in __lt__
def __lt__(self, other): return self.aval._lt(self, other)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 5256, in deferring_binary_op
return binary_op(self, other)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 466, in fn
return lax_fn(x1, x2)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\lax.py", line 407, in lt
return lt_p.bind(x, y)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\core.py", line 284, in bind
out = top_trace.process_primitive(self, tracers, params)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\interpreters\partial_eval.py", line 1062, in process_primitive
out_avals = primitive.abstract_eval(*avals, **params)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\lax.py", line 1994, in standard_abstract_eval
return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\lax.py", line 2070, in naryop_dtype_rule
_check_same_dtypes(name, False, *aval_dtypes)
File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\lax.py", line 6158, in _check_same_dtypes
raise TypeError(msg.format(name, ", ".join(map(str, types))))
TypeError: lt requires arguments to have the same dtypes, got int64, int32.
The text was updated successfully, but these errors were encountered:
When I run a simple `jax.lax.scan` loop with 64-bit mode enabled (`jax_enable_x64`) on a Windows 10 machine, it fails. Minimal reproducer:
It crashes with the following error:
The text was updated successfully, but these errors were encountered: