Cross Product Error in 64bit mode #6058

Closed
adam-hartshorne opened this issue Mar 13, 2021 · 5 comments
Labels
bug Something isn't working

Comments

@adam-hartshorne

When I try a simple cross-product calculation using JAX in 64-bit mode on a Windows 10 machine, I get the following error:

import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)  # enable 64-bit (x64) mode

x = jnp.array((1.0, 1.0, 1.0))
y = jnp.array((0.0, 0.0, 0.0))
z = jnp.cross(x, y)  # raises the TypeError below on Windows 10

Traceback (most recent call last):
  File "C:\Program Files\JetBrains\PyCharm 2020.3.3\plugins\python\helpers\pydev\_pydevd_bundle\pydevd_exec2.py", line 3, in Exec
    exec(exp, global_vars, local_vars)
  File "<input>", line 1, in <module>
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 3938, in cross
    return _cross(a, b, axisa, axisb, axisc)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\api.py", line 416, in f_jitted
    return cpp_jitted_f(context, *args, **kwargs)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\api.py", line 297, in cache_miss
    out_flat = xla.xla_call(
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\core.py", line 1394, in bind
    return call_bind(self, fun, *args, **params)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\core.py", line 1385, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\core.py", line 1397, in process
    return trace.process_call(self, fun, tracers, params)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\core.py", line 625, in process_call
    return primitive.impl(f, *tracers, **params)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\interpreters\xla.py", line 586, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\linear_util.py", line 260, in memoized_fun
    ans = call(fun, *args)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\interpreters\xla.py", line 662, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\interpreters\partial_eval.py", line 1220, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\interpreters\partial_eval.py", line 1200, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 3922, in _cross
    a0 = a[..., 0]
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\core.py", line 552, in __getitem__
    def __getitem__(self, idx): return self.aval._getitem(self, idx)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 4384, in _rewriting_take
    return _gather(arr, treedef, static_idx, dynamic_idx)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 4391, in _gather
    indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 4571, in _index_to_gather
    i = _normalize_index(i, x_shape[x_axis]) if normalize_indices else i
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 4224, in _normalize_index
    lax.lt(index, _constant_like(index, 0)),
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\lax.py", line 407, in lt
    return lt_p.bind(x, y)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\core.py", line 284, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\interpreters\partial_eval.py", line 1062, in process_primitive
    out_avals = primitive.abstract_eval(*avals, **params)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\lax.py", line 1994, in standard_abstract_eval
    return ShapedArray(shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\lax.py", line 2070, in naryop_dtype_rule
    _check_same_dtypes(name, False, *aval_dtypes)
  File "C:\Users\Adam\anaconda3\envs\tensorflow\lib\site-packages\jax\_src\lax\lax.py", line 6158, in _check_same_dtypes
    raise TypeError(msg.format(name, ", ".join(map(str, types))))
TypeError: lt requires arguments to have the same dtypes, got int64, int32.
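For comparison, NumPy returns the zero vector for these inputs. The sketch below writes the 3-vector cross product out component by component; cross_last_axis is a hypothetical helper, and its a[..., 0]-style indexing only mirrors the line visible in the traceback (a0 = a[..., 0]), not JAX's actual _cross source.

```python
import numpy as np

def cross_last_axis(a, b):
    """Cross product of 3-vectors along the last axis, written component-wise.

    Illustrative sketch only: it mirrors the a[..., 0] indexing seen in the
    traceback, not JAX's real implementation.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    # Integer indexing like a[..., 0] is the step that failed under x64 on Windows.
    a0, a1, a2 = a[..., 0], a[..., 1], a[..., 2]
    b0, b1, b2 = b[..., 0], b[..., 1], b[..., 2]
    return np.stack([a1 * b2 - a2 * b1,
                     a2 * b0 - a0 * b2,
                     a0 * b1 - a1 * b0], axis=-1)

x = np.array((1.0, 1.0, 1.0))
y = np.array((0.0, 0.0, 0.0))
z = cross_last_axis(x, y)
print(z)  # [0. 0. 0.]
```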
adam-hartshorne added the bug label Mar 13, 2021
adam-hartshorne changed the title from "Cross Product in 64bit mode." to "Cross Product Error in 64bit mode" Mar 13, 2021
jakevdp self-assigned this Mar 14, 2021
@jakevdp (Collaborator) commented Mar 14, 2021

This does not reproduce on OS X or Linux. I suspect it is related to signed integers defaulting to 32 bits on Windows, combined with the issue in #6051.
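One way to check which integer width a given machine uses, as a sketch that only inspects NumPy's mapping of the Python int type (the behavior suspected above); it does not exercise JAX itself:

```python
import numpy as np

# Python's built-in int maps to NumPy's default integer dtype. Under NumPy 1.x this
# is the C long type, which is 32-bit on Windows and 64-bit on Linux and OS X;
# NumPy 2.x switched 64-bit Windows to a 64-bit default as well.
default_int = np.dtype(int)
bits = default_int.itemsize * 8
print(f"dtype=int resolves to {default_int} ({bits}-bit) on this platform")
```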

@jakevdp (Collaborator) commented Mar 15, 2021

OK, here's the issue: for weak-typed integers, jnp.result_type returns Python's built-in int:

>>> x = lax.convert_element_type(1, 'int64', weak_type=True)
>>> jnp.result_type(x)
int

On Linux and OS X, NumPy integers default to 64 bits:

>>> np.array(1, dtype=int).dtype
dtype('int64')

On some Windows machines, they default to 32 bits:

>>> np.array(1, dtype=int).dtype
dtype('int32')

In this case, operations like the following fail:

>>> x < 1
TypeError: lt requires arguments to have the same dtypes, got int64, int32.

I think the solution is essentially the fix for #6051.
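A NumPy-only sketch of the failure mode described in this comment. strict_lt and normalize_index are hypothetical stand-ins, loosely modeled on the lax.lt and _normalize_index frames in the traceback; they are not JAX internals, and this is not necessarily how #6051 or #6068 addressed the problem. The point is only that a comparison constant whose dtype comes from the platform's default int can disagree with an int64 index, while one built from the index's own dtype cannot.

```python
import numpy as np

def strict_lt(x, y):
    """Hypothetical stand-in for lax.lt: rejects mixed dtypes instead of promoting."""
    x, y = np.asarray(x), np.asarray(y)
    if x.dtype != y.dtype:
        raise TypeError(
            f"lt requires arguments to have the same dtypes, got {x.dtype}, {y.dtype}.")
    return np.less(x, y)

def normalize_index(index, axis_size, zero_dtype=None):
    """Map a negative index into [0, axis_size).

    zero_dtype=None builds the comparison constant from the index's own dtype.
    Passing a fixed dtype imitates building it from Python's int, whose width
    depends on the platform.
    """
    index = np.asarray(index)
    zero = np.array(0, dtype=index.dtype if zero_dtype is None else zero_dtype)
    size = np.array(axis_size, dtype=index.dtype)
    return np.where(strict_lt(index, zero), index + size, index)

idx = np.int64(-1)  # an int64 index, as produced in x64 mode
print(normalize_index(idx, 3))  # 2
try:
    normalize_index(idx, 3, zero_dtype=np.int32)  # constant from a 32-bit default int
except TypeError as err:
    print(err)  # lt requires arguments to have the same dtypes, got int64, int32.
```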

@jakevdp (Collaborator) commented Mar 16, 2021

Hi @oracle3001 – with #6068 merged, I believe this problem should be fixed. Can you install jax from the master branch and check if that is the case? Thanks!

@adam-hartshorne (Author)
Hi @jakevdp, unless issue #5985 has also been fixed, building jaxlib after v0.1.61 on Windows 10 is currently broken due to changes in GPU/TPU handling, so I can't check it.

@totomobile43 commented May 4, 2021

I have v0.1.65 installed (Windows), and this works fine for me now. There was another, similar issue when indexing into arrays (for example, jnp_a[0]) in 64-bit mode on Windows, which caused the same error as above. That is also resolved now.

jakevdp closed this as completed Jun 25, 2021

3 participants