From 0c1021ad4b8200e946275e0c97a45d64a8f7f208 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 15 Apr 2022 05:44:49 -0700 Subject: [PATCH] Temporarily disable integer index check in jnp.take_along_axis. This check broke some JAX users; disable it to give time to fix them. PiperOrigin-RevId: 441993808 --- jax/_src/numpy/lax_numpy.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 1e14f9592bc5..bfdcc6ce4af7 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3425,10 +3425,11 @@ def _normalize_index(index, axis_size): @partial(jit, static_argnames=('axis',)) def take_along_axis(arr, indices, axis: Optional[int]): _check_arraylike("take_along_axis", arr, indices) - index_dtype = dtypes.dtype(indices) - if not dtypes.issubdtype(index_dtype, integer): - raise TypeError("take_along_axis indices must be of integer type, got " - f"{str(index_dtype)}") + # index_dtype = dtypes.dtype(indices) + # TODO(phawkins): reenalbe this check after fixing callers + # if not dtypes.issubdtype(index_dtype, integer): + # raise TypeError("take_along_axis indices must be of integer type, got " + # f"{str(index_dtype)}") if axis is None: if ndim(indices) != 1: msg = "take_along_axis indices must be 1D if axis=None, got shape {}"