Skip to content

Commit

Permalink
[JAX] Don't pass non-integer indices to jnp.take_along_axis
Browse files Browse the repository at this point in the history
Previously JAX casted non-integer indices to integers, but in the future it will issue an error (as np.take_along_axis does). This change adds an explicit integer cast to callers that were passing non-integer values.

PiperOrigin-RevId: 442062233
  • Loading branch information
hawkinsp authored and jax authors committed Apr 15, 2022
1 parent 7ffdac0 commit 6eec758
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax/experimental/jax2tf/tests/shape_poly_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1660,11 +1660,11 @@ def _make_harness(group_name: str, name: str,
poly_axes=[0, None], enable_and_diable_xla=True),
_make_harness("take_along_axis", "0",
lambda x, y: jnp.take_along_axis(x, y, axis=0),
[RandArg((5, 2), _f32), RandArg((5, 1), _f32)],
[RandArg((5, 2), _f32), RandArg((5, 1), np.int32)],
poly_axes=[0, 0]),
_make_harness("take_along_axis", "1",
lambda x, y: jnp.take_along_axis(x, y, axis=1),
[RandArg((5, 2), _f32), RandArg((5, 1), _f32)],
[RandArg((5, 2), _f32), RandArg((5, 1), np.int32)],
poly_axes=[0, 0]),
_make_harness("tile", "0",
lambda x: jnp.tile(x, (1, 2)),
Expand Down

0 comments on commit 6eec758

Please sign in to comment.