diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index d8cb3c57430f..6595ef99369e 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -1660,11 +1660,11 @@ def _make_harness(group_name: str, name: str, poly_axes=[0, None], enable_and_diable_xla=True), _make_harness("take_along_axis", "0", lambda x, y: jnp.take_along_axis(x, y, axis=0), - [RandArg((5, 2), _f32), RandArg((5, 1), _f32)], + [RandArg((5, 2), _f32), RandArg((5, 1), np.int32)], poly_axes=[0, 0]), _make_harness("take_along_axis", "1", lambda x, y: jnp.take_along_axis(x, y, axis=1), - [RandArg((5, 2), _f32), RandArg((5, 1), _f32)], + [RandArg((5, 2), _f32), RandArg((5, 1), np.int32)], poly_axes=[0, 0]), _make_harness("tile", "0", lambda x: jnp.tile(x, (1, 2)),