Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[JAX] Don't pass non-integer indices to jnp.take_along_axis
Previously JAX casted non-integer indices to integers, but in the future it will issue an error (as np.take_along_axis does). This change adds an explicit integer cast to callers that were passing non-integer values. PiperOrigin-RevId: 442062233
- Loading branch information