fix convert_element_type on large Python int inputs #6165
Merged
@jekbradbury noticed that when we call lax.convert_element_type(2 ** 100, jnp.float32), we first convert to int32 when we canonicalize the input dtype. But int32 can't represent 2**100, while a float32 can!

Concretely, before this PR, this fails:
It fails because under the hood it ends up doing this, which also fails:
even though this succeeds:
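A minimal sketch of that contrast in plain NumPy, assuming int32 is the canonical integer dtype (as it is when 64-bit mode is disabled). The variable names are illustrative only and this is not necessarily the exact code from the PR:

```python
import numpy as np

x = 2 ** 100  # far outside the int32 range, but exactly representable in float32

# Going through the (assumed) canonical integer dtype first overflows loudly.
try:
    np.array(x, dtype=np.int32)
    int32_ok = True
except OverflowError:
    int32_ok = False

# Converting the Python int directly to the target dtype works.
as_float32 = np.array(x, dtype=np.float32)
```

The float32 path succeeds because 2**100 is a power of two well below the float32 maximum, so no rounding is needed.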
#6014 caused this to surface in a downstream library, because the special-case logic #6014 removed had effectively called np.array(x, to_dtype) before applying any JAX primitives.

Luckily these issues lead to loud overflow errors from NumPy, rather than a silent loss of bits.
The fix is just to have lax.convert_element_type (i.e. the 'traceable' wrapper) use NumPy to convert Python int inputs to NumPy arrays with the target dtype (like a float32), rather than the current behavior of converting to the canonical dtype for the input (like an int32), before transferring the value out of Python and to the device.

I also had to fix up some handling of float0s in host_callback logic. The logic now mirrors the analogous logic in custom_jvp/custom_vjp.

I tweaked the implementation of the fix so that these aren't needed anymore. They may still be good changes, but I'd rather keep the PR minimal.
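The idea behind the fix can be sketched roughly as follows. The helper name convert_python_scalar is hypothetical and the code uses only NumPy; it illustrates the approach described above and is not the actual change made to lax.convert_element_type:

```python
import numpy as np

def convert_python_scalar(x, new_dtype):
    """Hypothetical helper: convert x to a NumPy array of new_dtype.

    Python ints are converted straight to the target dtype, so a value such
    as 2 ** 100 never has to fit in the canonical integer dtype first.
    """
    if isinstance(x, int):
        # Skip the intermediate canonical dtype (e.g. int32) entirely.
        return np.array(x, dtype=new_dtype)
    # Other inputs take the ordinary array conversion path in this sketch.
    return np.asarray(x).astype(new_dtype)

big = convert_python_scalar(2 ** 100, np.float32)
small = convert_python_scalar(np.int32(3), np.float32)
```

Under these assumptions, overflow can still occur if the Python int does not fit the target dtype itself, but that is then a property of the requested conversion rather than of an intermediate step.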