diff --git a/ivy/functional/frontends/jax/numpy/creation.py b/ivy/functional/frontends/jax/numpy/creation.py index 65de35d7444d1..0ee3781d3c64b 100644 --- a/ivy/functional/frontends/jax/numpy/creation.py +++ b/ivy/functional/frontends/jax/numpy/creation.py @@ -1,7 +1,7 @@ import ivy from ivy.func_wrapper import with_unsupported_dtypes from ivy.functional.frontends.jax.array import Array -import ivy.functional.frontends.jax.numpy as jnp +import ivy.functional.frontends.jax.numpy as jnp_frontend from ivy.functional.frontends.jax.func_wrapper import ( to_ivy_arrays_and_back, outputs_to_frontend_arrays, @@ -270,7 +270,7 @@ def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None): if ivy.is_bool_dtype(ar1.dtype) else ivy.to_scalar(ivy.min(ar1)) ) - ar1 = jnp.unique(ar1, size=size and ar1.size, fill_value=val).ivy_array + ar1 = jnp_frontend.unique(ar1, size=size and ar1.size, fill_value=val).ivy_array mask = in1d(ar1, ar2, invert=True).ivy_array if size is None: return ar1[mask] @@ -281,7 +281,7 @@ def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None): mask = ivy.where(ivy.arange(ar1.size) < n_unique, mask, False) return ivy.where( ivy.arange(size) < mask.sum(dtype=ivy.int64), - ar1[jnp.where(mask, size=size)[0].ivy_array], + ar1[jnp_frontend.where(mask, size=size)[0].ivy_array], fill_value, )