From 3636f44a034d66ce6e579831f77293e3bbdb9f39 Mon Sep 17 00:00:00 2001 From: Vismay Suramwar <83938053+Vismay-dev@users.noreply.github.com> Date: Tue, 6 Feb 2024 06:39:43 +0000 Subject: [PATCH] update jnp_frontend import --- ivy/functional/frontends/jax/numpy/creation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/creation.py b/ivy/functional/frontends/jax/numpy/creation.py index 65de35d7444d1..0ee3781d3c64b 100644 --- a/ivy/functional/frontends/jax/numpy/creation.py +++ b/ivy/functional/frontends/jax/numpy/creation.py @@ -1,7 +1,7 @@ import ivy from ivy.func_wrapper import with_unsupported_dtypes from ivy.functional.frontends.jax.array import Array -import ivy.functional.frontends.jax.numpy as jnp +import ivy.functional.frontends.jax.numpy as jnp_frontend from ivy.functional.frontends.jax.func_wrapper import ( to_ivy_arrays_and_back, outputs_to_frontend_arrays, @@ -270,7 +270,7 @@ def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None): if ivy.is_bool_dtype(ar1.dtype) else ivy.to_scalar(ivy.min(ar1)) ) - ar1 = jnp.unique(ar1, size=size and ar1.size, fill_value=val).ivy_array + ar1 = jnp_frontend.unique(ar1, size=size and ar1.size, fill_value=val).ivy_array mask = in1d(ar1, ar2, invert=True).ivy_array if size is None: return ar1[mask] @@ -281,7 +281,7 @@ def setdiff1d(ar1, ar2, assume_unique=False, *, size=None, fill_value=None): mask = ivy.where(ivy.arange(ar1.size) < n_unique, mask, False) return ivy.where( ivy.arange(size) < mask.sum(dtype=ivy.int64), - ar1[jnp.where(mask, size=size)[0].ivy_array], + ar1[jnp_frontend.where(mask, size=size)[0].ivy_array], fill_value, )