diff --git a/docs/jax.numpy.rst b/docs/jax.numpy.rst index b34695c82790..3b7cc316fd6d 100644 --- a/docs/jax.numpy.rst +++ b/docs/jax.numpy.rst @@ -120,6 +120,7 @@ Not every function in NumPy is implemented; contributions are welcome! cumsum deg2rad degrees + delete diag diagflat diag_indices diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6c8fdcb8486a..b299c1a2bf12 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3606,6 +3606,50 @@ def append(arr, values, axis: Optional[int] = None): return concatenate([arr, values], axis=axis) +@_wraps(np.delete) +def delete(arr, obj, axis=None): + _check_arraylike("delete", arr) + if axis is None: + arr = ravel(arr) + axis = 0 + axis = _canonicalize_axis(axis, arr.ndim) + + # Case 1: obj is a static integer. + try: + obj = operator.index(obj) + obj = _canonicalize_axis(obj, arr.shape[axis]) + except TypeError: + pass + else: + idx = tuple(slice(None) for i in range(axis)) + return concatenate([arr[idx + (slice(0, obj),)], arr[idx + (slice(obj + 1, None),)]], axis=axis) + + # Case 2: obj is a static slice. + if isinstance(obj, slice): + # TODO(jakevdp): we should be able to do this dynamically with care. + indices = np.delete(np.arange(arr.shape[axis]), obj) + return take(arr, indices, axis=axis) + + # Case 3: obj is an array + # NB: pass both arrays to check for appropriate error message. + _check_arraylike("delete", arr, obj) + obj = core.concrete_or_error(np.asarray, obj, "'obj' array argument of jnp.delete()") + + if issubdtype(obj.dtype, integer): + # TODO(jakevdp): in theory this could be done dynamically if obj has no duplicates, + # but this would require the complement of lax.gather. + mask = np.ones(arr.shape[axis], dtype=bool) + mask[obj] = False + elif obj.dtype == bool: + if obj.shape != (arr.shape[axis],): + raise ValueError("np.delete(arr, obj): for boolean indices, obj must be one-dimensional " + "with length matching specified axis.") + mask = ~obj + else: + raise ValueError(f"np.delete(arr, obj): got obj.dtype={obj.dtype}; must be integer or bool.") + return arr[tuple(slice(None) for i in range(axis)) + (mask,)] + + @_wraps(np.apply_along_axis) def apply_along_axis(func1d, axis: int, arr, *args, **kwargs): num_dims = ndim(arr) diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 4aec72f4a2b4..26065a04dc51 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -29,8 +29,8 @@ broadcast_to, can_cast, cbrt, cdouble, ceil, character, choose, clip, column_stack, complex128, complex64, complex_, complexfloating, compress, concatenate, conj, conjugate, convolve, copysign, corrcoef, correlate, cos, cosh, - count_nonzero, cov, cross, csingle, cumprod, cumproduct, cumsum, deg2rad, - degrees, diag, diagflat, diag_indices, diag_indices_from, diagonal, diff, digitize, divide, divmod, dot, + count_nonzero, cov, cross, csingle, cumprod, cumproduct, cumsum, deg2rad, degrees, + delete, diag, diagflat, diag_indices, diag_indices_from, diagonal, diff, digitize, divide, divmod, dot, double, dsplit, dstack, dtype, e, ediff1d, einsum, einsum_path, empty, empty_like, equal, euler_gamma, exp, exp2, expand_dims, expm1, extract, eye, fabs, finfo, fix, flatnonzero, flexible, flip, fliplr, flipud, float16, float32, diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 6f48fdd55624..d7b54e0664ea 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -1960,6 +1960,80 @@ def args_maker(): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}_idx={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis, idx), + "dtype": dtype, "shape": shape, "axis": axis, "idx": idx} + for shape in nonempty_nonscalar_array_shapes + for dtype in all_dtypes + for axis in [None] + list(range(-len(shape), len(shape))) + for idx in (range(-prod(shape), prod(shape)) + if axis is None else + range(-shape[axis], shape[axis])))) + def testDeleteInteger(self, shape, dtype, idx, axis): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arg: np.delete(arg, idx, axis=axis) + jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}_slc={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis, slc), + "dtype": dtype, "shape": shape, "axis": axis, "slc": slc} + for shape in nonempty_nonscalar_array_shapes + for dtype in all_dtypes + for axis in [None] + list(range(-len(shape), len(shape))) + for slc in [slice(None), slice(1, 3), slice(1, 5, 2)])) + def testDeleteSlice(self, shape, dtype, axis, slc): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arg: np.delete(arg, slc, axis=axis) + jnp_fun = lambda arg: jnp.delete(arg, slc, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}_idx={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis, + jtu.format_shape_dtype_string(idx_shape, int)), + "dtype": dtype, "shape": shape, "axis": axis, "idx_shape": idx_shape} + for shape in nonempty_nonscalar_array_shapes + for dtype in all_dtypes + for axis in [None] + list(range(-len(shape), len(shape))) + for idx_shape in all_shapes)) + def testDeleteIndexArray(self, shape, dtype, axis, idx_shape): + rng = jtu.rand_default(self.rng()) + max_idx = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] + # Previous to numpy 1.19, negative indices were ignored so we don't test this. + low = 0 if numpy_version < (1, 19, 0) else -max_idx + idx = jtu.rand_int(self.rng(), low=low, high=max_idx)(idx_shape, int) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arg: np.delete(arg, idx, axis=axis) + jnp_fun = lambda arg: jnp.delete(arg, idx, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @unittest.skipIf(numpy_version < (1, 19), "boolean mask not supported in numpy < 1.19.0") + @parameterized.named_parameters(jtu.cases_from_list( + {"testcase_name": "_{}_axis={}".format( + jtu.format_shape_dtype_string(shape, dtype), axis), + "dtype": dtype, "shape": shape, "axis": axis} + for shape in nonempty_nonscalar_array_shapes + for dtype in all_dtypes + for axis in [None] + list(range(-len(shape), len(shape))))) + def testDeleteMaskArray(self, shape, dtype, axis): + rng = jtu.rand_default(self.rng()) + mask_size = np.zeros(shape).size if axis is None else np.zeros(shape).shape[axis] + mask = jtu.rand_int(self.rng(), low=0, high=2)(mask_size, bool) + args_maker = lambda: [rng(shape, dtype)] + np_fun = lambda arg: np.delete(arg, mask, axis=axis) + jnp_fun = lambda arg: jnp.delete(arg, mask, axis=axis) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_{}_axis={}_out_dims={}".format( jtu.format_shape_dtype_string(shape, dtype),