Skip to content

Commit

Permalink
Temporarily disable integer index check in jnp.take_along_axis.
Browse files Browse the repository at this point in the history
This check broke some JAX users; disable it to give time to fix them.

PiperOrigin-RevId: 441993808
  • Loading branch information
hawkinsp authored and jax authors committed Apr 15, 2022
1 parent 375777f commit 0c1021a
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3425,10 +3425,11 @@ def _normalize_index(index, axis_size):
@partial(jit, static_argnames=('axis',))
def take_along_axis(arr, indices, axis: Optional[int]):
_check_arraylike("take_along_axis", arr, indices)
index_dtype = dtypes.dtype(indices)
if not dtypes.issubdtype(index_dtype, integer):
raise TypeError("take_along_axis indices must be of integer type, got "
f"{str(index_dtype)}")
# index_dtype = dtypes.dtype(indices)
# TODO(phawkins): reenalbe this check after fixing callers
# if not dtypes.issubdtype(index_dtype, integer):
# raise TypeError("take_along_axis indices must be of integer type, got "
# f"{str(index_dtype)}")
if axis is None:
if ndim(indices) != 1:
msg = "take_along_axis indices must be 1D if axis=None, got shape {}"
Expand Down

0 comments on commit 0c1021a

Please sign in to comment.