The JAX array API I'm developing in jax-ml/jax#16099 hits a number of failures because the test suite calls APIs with Python integers larger than the maximum `int64` value. For example:
```python
In [1]: import jax

In [2]: jax.config.update('jax_enable_x64', True)

In [3]: val = 2**63  # one of the values generated by hypothesis

In [4]: jax.numpy.less(val, 0)
---------------------------------------------------------------------------
OverflowError: An overflow was encountered while parsing an argument to a jitted computation, whose argument path is x1.
```
By contrast, this isn't a problem in numpy:
```python
In [5]: import numpy as np

In [6]: np.less(val, 0)
Out[6]: False
```
The reason for this discrepancy is that numpy does value-dependent casting of Python ints:
```python
In [7]: np.array(val - 1).dtype
Out[7]: dtype('int64')

In [8]: np.array(val).dtype
Out[8]: dtype('uint64')
```
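To make the distinction concrete, here is a minimal pure-Python sketch of the rule numpy appears to apply to Python ints in the two cases above, assuming a platform whose default integer type is `int64`. The helper `value_dependent_dtype` is hypothetical (it is not part of numpy, JAX, or the Array API standard), and for values that fit in neither `int64` nor `uint64` it simply raises, which is a simplification rather than a description of what numpy does there.

```python
# Hypothetical helper, not a numpy/JAX/array-API function.
# Assumes a platform whose default integer type is int64.
INT64_MIN, INT64_MAX = -(2**63), 2**63 - 1
UINT64_MAX = 2**64 - 1


def value_dependent_dtype(value: int) -> str:
    """Pick a dtype name for a Python int based on its value."""
    if INT64_MIN <= value <= INT64_MAX:
        # Fits in the default signed type, e.g. val - 1 == 2**63 - 1.
        return "int64"
    if 0 <= value <= UINT64_MAX:
        # Too large for int64 but representable as uint64, e.g. val == 2**63.
        return "uint64"
    # Simplification: no 64-bit integer dtype can hold this value.
    raise OverflowError(f"Python int {value} does not fit in int64 or uint64")


print(value_dependent_dtype(2**63 - 1))  # int64
print(value_dependent_dtype(2**63))      # uint64
```

The point of the sketch is that the output dtype depends on the *value* of the input, not just its Python type, which is exactly the kind of implicit behavior discussed next.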
JAX has made the deliberate decision to avoid these kinds of implicit value-dependent semantics, and raises an error in the second case:
```python
In [9]: jax.numpy.array(val - 1).dtype
Out[9]: dtype('int64')

In [10]: jax.numpy.array(val).dtype
---------------------------------------------------------------------------
OverflowError: Python int 9223372036854775808 too large to convert to int64
```
This design decision results in the array API test failures mentioned above.
I would like to address this in the JAX array API branch, so my question is this: Is this value-dependent integer casting behavior part of the Array API specification?
If the answer is yes, then I can add functionality to the jax array api wrappers to handle these corner cases.
If the answer is no, then the fact that the test suite depends on such behavior should probably be considered a bug.
Do you have thoughts on how I should proceed? Thanks!