np.cint/np.int32 type confusion #4903

Closed
cloudhan opened this issue Nov 16, 2020 · 4 comments · Fixed by #4910

@cloudhan
Contributor

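```python
import numpy as np
import jax.numpy as jnp

# NumPy infers a 64-bit integer dtype, so the large value survives.
xs = np.array([1000000000000, 2, 3, 4])
# With 64-bit mode disabled (JAX's default), this JAX version stores the
# array as int32, so 1000000000000 silently wraps around.
jxs = jnp.array([1000000000000, 2, 3, 4])

print(xs)
print(jxs)
```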

Output:

```
[1000000000000             2             3             4]
[-727379968          2          3          4]
```
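The value -727379968 is exactly what two's-complement truncation of 1000000000000 to 32 bits produces. A minimal plain-Python sketch of that arithmetic, assuming the array was stored as int32 (the helper name `wrap_to_int32` is hypothetical, not a JAX or NumPy API):

```python
def wrap_to_int32(value: int) -> int:
    """Reduce a Python int modulo 2**32 and reinterpret it as a signed 32-bit integer."""
    # Keep only the low 32 bits (range 0 .. 2**32 - 1).
    low_bits = value % 2**32
    # In two's complement, a set sign bit (>= 2**31) means a negative number.
    return low_bits - 2**32 if low_bits >= 2**31 else low_bits


print([wrap_to_int32(v) for v in [1000000000000, 2, 3, 4]])
# [-727379968, 2, 3, 4]
```

Since 1000000000000 = 232 * 2**32 + 3567587328, and 3567587328 is at least 2**31, the stored bits read back as 3567587328 - 2**32 = -727379968.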
@cloudhan
Contributor Author

OK, running `export JAX_ENABLE_X64=1` before starting Python solves this, but no warning is raised when the value silently overflows int32.
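Besides the environment variable, 64-bit mode can be switched on in code with `jax.config.update("jax_enable_x64", True)` at startup, before any arrays are created. If x64 mode is not an option, a user-side check can at least make the truncation visible. The helper below is a hypothetical sketch, not part of JAX; it only inspects Python ints before they are handed to `jnp.array`:

```python
import warnings

INT32_MIN, INT32_MAX = -2**31, 2**31 - 1


def warn_if_not_int32(values):
    """Warn about Python ints that cannot be represented as int32; return them."""
    out_of_range = [
        v for v in values
        if isinstance(v, int) and not INT32_MIN <= v <= INT32_MAX
    ]
    if out_of_range:
        warnings.warn(
            f"{len(out_of_range)} value(s) do not fit in int32 and may be "
            f"truncated unless 64-bit mode is enabled: {out_of_range!r}",
            RuntimeWarning,
            stacklevel=2,
        )
    return out_of_range


# Usage: call before jnp.array(...) on data that might exceed the int32 range.
warn_if_not_int32([1000000000000, 2, 3, 4])  # emits a RuntimeWarning
```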

@cloudhan
Contributor Author

cloudhan commented Nov 16, 2020

Here is another issue:

```python
import numpy as np
import jax.numpy as jnp

xs = np.array([1000000000000, 2, 3, 4])
jxs = jnp.array([1000000000000, 2, 3, 4])

for x in xs:
    print(x, x.dtype, type(x))

for x in jxs:
    print(x, x.dtype, type(x))
```

Output on Linux:

```
1000000000000 int64 <class 'numpy.int64'>
2 int64 <class 'numpy.int64'>
3 int64 <class 'numpy.int64'>
4 int64 <class 'numpy.int64'>
1000000000000 int64 <class 'numpy.longlong'>
2 int64 <class 'numpy.longlong'>
3 int64 <class 'numpy.longlong'>
4 int64 <class 'numpy.longlong'>
```

Output on Windows:

```
1000000000000 int64 <class 'numpy.int64'>
2 int64 <class 'numpy.int64'>
3 int64 <class 'numpy.int64'>
4 int64 <class 'numpy.int64'>
1000000000000 int64 <class 'numpy.int64'>
2 int64 <class 'numpy.int64'>
3 int64 <class 'numpy.int64'>
4 int64 <class 'numpy.int64'>
```

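Iterating over a JAX array yields elements whose Python type is platform dependent.
In this case, int64 elements became numpy.longlong on Linux but stayed numpy.int64 on Windows, even though x.dtype prints int64 on both.

I also observe that int32 elements became numpy.cint on Windows, but not on Linux.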
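A likely reason for the platform split is that C integer sizes differ: `long` is 8 bytes on 64-bit Linux but 4 bytes on 64-bit Windows, while `int` is 4 bytes and `long long` is 8 bytes on both. So on Linux two distinct C types (`long`, `long long`) are 64 bits wide, and on Windows two (`int`, `long`) are 32 bits wide; NumPy keeps a separate scalar class for each C type, which is consistent with seeing numpy.longlong vs numpy.int64 on Linux and a C-int class vs numpy.int32 on Windows. A minimal standard-library sketch that reports the native sizes on the current machine:

```python
import struct
import sys

# Plain struct format codes use *native* sizes from the platform's C ABI:
# "i" = C int, "l" = C long, "q" = C long long.
native_sizes = {code: struct.calcsize(code) for code in ("i", "l", "q")}

print(sys.platform, native_sizes)
# 64-bit Linux typically reports  {'i': 4, 'l': 8, 'q': 8}
# 64-bit Windows typically reports {'i': 4, 'l': 4, 'q': 8}
```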

@hawkinsp
Collaborator

I think this is an issue that other folks have hit before. There's a difference between "native" and "standard" sizes in the buffer format descriptors that NumPy interprets, and it manifests only on Windows:
https://docs.python.org/3/library/struct.html#byte-order-size-and-alignment
pybind/pybind11#1806
pybind/pybind11#1908

My guess is we need to change some of the format descriptors returned here:
https://github.com/tensorflow/tensorflow/blob/516f4a121c5c1f9bb712b774109c6ad5283d27ad/tensorflow/compiler/xla/python/types.cc#L104

to have "=" prefixes, so that they describe standard rather than native sizes.
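With a "=" prefix, struct format codes use standard sizes, which are identical on every platform (for example, "=l" is always 4 bytes and "=q" always 8). A minimal sketch of a fixed-size lookup table in that spirit, assuming XLA-style element type names; the table names and this exact mapping are illustrative only, not the actual contents of types.cc:

```python
import struct

# Hypothetical mapping from XLA-style element type names to buffer format
# strings. The "=" prefix selects standard (platform-independent) sizes.
STANDARD_FORMATS = {
    "PRED": "=?",
    "S8": "=b", "S16": "=h", "S32": "=i", "S64": "=q",
    "U8": "=B", "U16": "=H", "U32": "=I", "U64": "=Q",
    "F16": "=e", "F32": "=f", "F64": "=d",
}

# Byte widths implied by each type name.
EXPECTED_BYTES = {
    "PRED": 1,
    "S8": 1, "S16": 2, "S32": 4, "S64": 8,
    "U8": 1, "U16": 2, "U32": 4, "U64": 8,
    "F16": 2, "F32": 4, "F64": 8,
}

for name, fmt in STANDARD_FORMATS.items():
    size = struct.calcsize(fmt)
    # Standard sizes never depend on the platform's C ABI.
    assert size == EXPECTED_BYTES[name], (name, fmt, size)
    print(f"{name:>4} -> {fmt!r}: {size} bytes")

# Unlike native "l" (8 bytes on 64-bit Linux, 4 on Windows), "=l" is always 4.
assert struct.calcsize("=l") == 4
```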

hawkinsp changed the title from "Wrong type captured" to "np.cint/np.int32 type confusion on Windows" Nov 16, 2020
@cloudhan
Contributor Author

cloudhan commented Nov 16, 2020

This is not specific to Windows. I have updated my previous comment.

hawkinsp changed the title from "np.cint/np.int32 type confusion on Windows" to "np.cint/np.int32 type confusion" Nov 16, 2020
hawkinsp self-assigned this Nov 16, 2020
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this issue Nov 16, 2020
…tead of platform-specific types.

Will fix jax-ml/jax#4903 when incorporated into a jaxlib.

PiperOrigin-RevId: 342646750
Change-Id: I309f5e9762b1ef7ca049439705a84b4cb1c799d9
hawkinsp mentioned this issue Nov 16, 2020