Skip to content

Commit

Permalink
update jnp_frontend import
Browse files Browse the repository at this point in the history
  • Loading branch information
vismaysur committed Feb 6, 2024
1 parent 5a38ecf commit 3636f44
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions ivy/functional/frontends/jax/numpy/creation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ivy
from ivy.func_wrapper import with_unsupported_dtypes
from ivy.functional.frontends.jax.array import Array
import ivy.functional.frontends.jax.numpy as jnp
import ivy.functional.frontends.jax.numpy as jnp_frontend
from ivy.functional.frontends.jax.func_wrapper import (
to_ivy_arrays_and_back,
outputs_to_frontend_arrays,
Expand Down Expand Up @@ -270,7 +270,7 @@ def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None):
if ivy.is_bool_dtype(ar1.dtype)
else ivy.to_scalar(ivy.min(ar1))
)
ar1 = jnp.unique(ar1, size=size and ar1.size, fill_value=val).ivy_array
ar1 = jnp_frontend.unique(ar1, size=size and ar1.size, fill_value=val).ivy_array
mask = in1d(ar1, ar2, invert=True).ivy_array
if size is None:
return ar1[mask]
Expand All @@ -281,7 +281,7 @@ def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None):
mask = ivy.where(ivy.arange(ar1.size) < n_unique, mask, False)
return ivy.where(
ivy.arange(size) < mask.sum(dtype=ivy.int64),
ar1[jnp.where(mask, size=size)[0].ivy_array],
ar1[jnp_frontend.where(mask, size=size)[0].ivy_array],
fill_value,
)

Expand Down

0 comments on commit 3636f44

Please sign in to comment.