np.cint/np.int32 type confusion #4903
OK, here is another issue:

```python
import numpy as np
import jax.numpy as jnp

xs = np.array([1000000000000, 2, 3, 4])
jxs = jnp.array([1000000000000, 2, 3, 4])

# Print each element's value, dtype, and Python scalar type.
for x in xs:
    print(x, x.dtype, type(x))

for x in jxs:
    print(x, x.dtype, type(x))
```

On Linux output:
On Windows output:
The JAX iterator yields a platform-dependent element type. I also observe
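A minimal sketch of a platform-independent way to check element types in a test like the one above, assuming only NumPy (no JAX): compare dtypes, which encode kind, width, and byte order, rather than Python scalar classes. `np.longlong` and `np.int64` are used here only as an illustrative pair of same-width types that may or may not be the same class depending on the platform.

```python
import numpy as np

# Request an explicit fixed-width dtype instead of relying on platform defaults.
xs = np.array([1000000000000, 2, 3, 4], dtype=np.int64)

for x in xs:
    # Robust: dtype equality compares kind, itemsize, and byte order.
    assert x.dtype == np.int64
    # Fragile: np.longlong and np.int64 have the same width on 64-bit platforms,
    # but whether they are the *same class* depends on the platform's C ABI,
    # so an identity check like this can print different results per OS.
    print(x, x.dtype, type(x).__name__, type(x) is np.longlong)
```

Comparing `x.dtype` keeps the check stable across Linux and Windows, while the `type(x) is ...` column is exactly the kind of output that can differ between them.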
I think this is an issue that other folks have hit before. There's a difference between "native" and "standard" sizes in NumPy that manifests only on Windows. My guess is we need to change some of the format descriptors returned here to have
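To illustrate the "native" vs "standard" size distinction, here is a small sketch using Python's `struct` module, whose format characters follow the same convention as buffer-protocol format descriptors. It assumes a typical 64-bit CPython with NumPy; the exact native sizes printed depend on the platform ABI.

```python
import ctypes
import struct

import numpy as np

# Native sizes (the default '@' mode): 'l' is a C long, so its width follows the
# platform ABI: 8 bytes on 64-bit Linux/macOS (LP64), 4 bytes on Windows (LLP64).
native_long = struct.calcsize("l")

# Standard sizes (any of the '=', '<', '>', '!' prefixes): 'l' is always 4 bytes
# and 'q' is always 8 bytes, regardless of platform.
standard_long = struct.calcsize("=l")
standard_longlong = struct.calcsize("=q")

# NumPy describes an int64 buffer with a native typecode; which character it
# uses ('l' or 'q') depends on whether C long is 64 bits wide on this platform.
fmt = memoryview(np.zeros(1, dtype=np.int64)).format
code = fmt.lstrip("@=<>!")  # drop any byte-order/size prefix

print("native 'l':", native_long, "bytes")
print("standard '=l':", standard_long, "bytes; standard '=q':", standard_longlong, "bytes")
print("int64 buffer format:", repr(fmt),
      "| native size:", struct.calcsize(fmt),
      "| size under standard sizes:", struct.calcsize("=" + code))
```

On a platform where the int64 descriptor is `'l'`, the last line shows the mismatch: 8 bytes natively but 4 bytes if the same character is read with standard sizes.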
This is not specific to Windows. I updated my previous comment.
…tead of platform-specific types. Will fix jax-ml/jax#4903 when incorporated into a jaxlib.

PiperOrigin-RevId: 342646750
Change-Id: I309f5e9762b1ef7ca049439705a84b4cb1c799d9
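The fix described in the commit message (fixed-width types instead of platform-specific ones) can be sketched as a mapping from dtype to a standard-size format descriptor. This is a hypothetical helper, `fixed_width_format`, written only to illustrate the idea; it is not jaxlib's actual code, and the supported-type table is an assumption kept to booleans, integers, and floats.

```python
import struct

import numpy as np

# Standard-size struct codes keyed by (dtype.kind, dtype.itemsize). Under a
# '=', '<' or '>' prefix these widths are fixed, independent of the C ABI.
_STANDARD_CODES = {
    ("b", 1): "?",
    ("i", 1): "b", ("i", 2): "h", ("i", 4): "i", ("i", 8): "q",
    ("u", 1): "B", ("u", 2): "H", ("u", 4): "I", ("u", 8): "Q",
    ("f", 2): "e", ("f", 4): "f", ("f", 8): "d",
}


def fixed_width_format(dtype) -> str:
    """Hypothetical helper: a format descriptor for `dtype` that uses standard
    sizes, so it means the same width on every platform."""
    dtype = np.dtype(dtype)
    code = _STANDARD_CODES.get((dtype.kind, dtype.itemsize))
    if code is None:
        raise TypeError(f"unsupported dtype: {dtype}")
    # Keep an explicit '<' or '>' byte order; otherwise ('=' native, '|' not
    # applicable) use '=' for native byte order with standard sizes.
    order = dtype.byteorder if dtype.byteorder in ("<", ">") else "="
    return order + code


# Usage: decode an int64 array's raw bytes with the platform-independent descriptor.
arr = np.array([1_000_000_000_000, 2, 3, 4], dtype=np.int64)
fmt = fixed_width_format(arr.dtype)
values = struct.unpack(f"{fmt[0]}{arr.size}{fmt[1:]}", arr.tobytes())
print(fmt, values)
```

Because every code is interpreted with standard sizes, `'q'` is 8 bytes on both Linux and Windows, which avoids the `'l'` ambiguity shown by native descriptors.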
Outputs: