Cross Product Error in 64bit mode #6058
This does not reproduce on OSX or Linux. I suspect it is related to signed integers defaulting to 32-bit on Windows, combined with the issue in #6051.
OK, here's the issue. For weak-typed integers:

    >>> x = lax.convert_element_type(1, 'int64', weak_type=True)
    >>> jnp.result_type(x)
    int

On Linux and OSX, NumPy integers default to 64-bit:

    >>> np.array(1, dtype=int).dtype
    numpy.int64

On some Windows machines, they default to 32-bit:

    >>> np.array(1, dtype=int).dtype
    numpy.int32

In that case, operations like the following fail:

    >>> x < 1
    TypeError: lt requires arguments to have the same dtypes, got int64, int32.

I think the solution is essentially the fix for #6051.
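A minimal sketch of a user-side workaround, assuming a recent JAX release with 64-bit mode enabled: give both operands of the comparison the same explicit dtype, so neither side depends on NumPy's platform-default integer. This is an illustration, not the upstream fix from #6068. It uses a strongly typed `int64` value in place of the weak-typed `x` above, because the `weak_type` argument may not be accepted by the public `lax.convert_element_type` in recent JAX versions.

```python
import jax

# 64-bit mode must be switched on before any arrays are created.
jax.config.update("jax_enable_x64", True)

import numpy as np
import jax.numpy as jnp

# NumPy's default integer (what dtype=int resolves to) is platform-dependent;
# it was int32 on Windows before NumPy 2.0.
platform_int = np.dtype(int)

# Strongly typed int64 value standing in for the weak-typed x above.
x = jnp.asarray(1, dtype=jnp.int64)

# Workaround sketch (an assumption, not the upstream fix): build the other
# operand with x's dtype explicitly instead of relying on the platform default.
y = jnp.asarray(1, dtype=x.dtype)
result = x < y  # 1 < 1, so this holds False

print(platform_int, x.dtype, result)
```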
Hi @oracle3001 – with #6068 merged, I believe this problem should be fixed. Can you install jax from the master branch and check whether that is the case? Thanks!
I have v0.1.65 installed (Windows) and this works fine for me now. There was another, similar issue when indexing into arrays, such as `jnp_a[0]`, in 64-bit mode on Windows, which caused the same error as above. That is also resolved now.
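A quick check for the indexing case, assuming 64-bit mode is enabled through `jax.config`. The name `jnp_a` follows the comment; its contents here are hypothetical. It indexes with a plain Python integer and with an explicit `int64` index.

```python
import jax

# Enable 64-bit mode before creating any arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Hypothetical array; in 64-bit mode arange(3.0) produces float64.
jnp_a = jnp.arange(3.0)

# Index with a Python int (the case from the comment) ...
first = jnp_a[0]

# ... and with an explicit int64 index, which should not trigger a
# dtype mismatch on a JAX build that includes the fix.
idx = jnp.asarray(0, dtype=jnp.int64)
first_again = jnp_a[idx]

print(jnp_a.dtype, first, first_again)
```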
When I try a simple cross product calculation using JAX in 64-bit mode on a Windows 10 machine, I get the following error.
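The report does not include the failing code, so the following is a hypothetical minimal reproduction, assuming `jnp.cross` on small float and integer vectors in 64-bit mode. The vectors are made up for illustration. On a JAX build that includes the fix, it should run without a dtype error.

```python
import jax

# Enable 64-bit mode before creating any arrays.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Float case: e1 x e2 = e3.
a = jnp.array([1.0, 0.0, 0.0])
b = jnp.array([0.0, 1.0, 0.0])
c = jnp.cross(a, b)

# Integer case (int64 in 64-bit mode), closer to the integer-dtype issue:
# (1, 2, 3) x (4, 5, 6) = (2*6 - 3*5, 3*4 - 1*6, 1*5 - 2*4).
u = jnp.array([1, 2, 3])
v = jnp.array([4, 5, 6])
w = jnp.cross(u, v)

print(c, w)
```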