diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py index dd637eeea501..e65900a59df0 100644 --- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py +++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py @@ -3,8 +3,9 @@ from ivy.functional.frontends.jax.func_wrapper import ( to_ivy_arrays_and_back, ) -from ivy.func_wrapper import with_unsupported_dtypes -from ivy.functional.frontends.jax.numpy import promote_types_of_jax_inputs +from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes +from ivy.functional.frontends.jax.lax import slice_in_dim +from ivy.functional.frontends.jax.numpy import promote_types_of_jax_inputs, concatenate from ivy.functional.frontends.numpy.manipulation_routines import trim_zeros from ivy.utils.einsum_path_helpers import ( parse_einsum_input, @@ -916,6 +917,38 @@ def trunc(x): return ivy.trunc(x) +@with_supported_dtypes( + {"0.4.30 and below": ("float32", "float64", "int32", "int64")}, "jax" +) +@to_ivy_arrays_and_back +def unwrap(p, discont=None, axis=-1, period=2 * ivy.pi): + p = ivy.asarray(p) + _dtype_to_inexact = { + "int32": "float64", + "int64": "float64", + "float32": "float32", + "float64": "float64", + } + dtype = _dtype_to_inexact[p.dtype] + p = p.astype(dtype) + if discont is None: + discont = period / 2 + interval = period / 2 + dd = ivy.diff(p, axis=axis) + ddmod = ivy.remainder(dd + interval, period) - interval + ddmod = ivy.where((ddmod == -interval) & (dd > 0), interval, ddmod) + ph_correct = ivy.where(ivy.abs(dd) < discont, 0, ddmod - dd) + up = concatenate( + ( + slice_in_dim(p, 0, 1, axis=axis), + slice_in_dim(p, 1, None, axis=axis) + ivy.cumsum(ph_correct, axis=axis), + ), + axis=axis, + dtype=dtype, + ) + return up + + @to_ivy_arrays_and_back def vdot(a, b): a, b = promote_types_of_jax_inputs(a, b) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py index bf668346c1a8..71edafa650ca 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py @@ -3379,6 +3379,49 @@ def test_jax_trunc( ) +@handle_frontend_test( + fn_tree="jax.numpy.unwrap", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + max_num_dims=5, + min_dim_size=2, + max_dim_size=10, + min_value=-ivy.pi, + max_value=ivy.pi, + valid_axis=True, + force_int_axis=True, + ), + discont=st.floats(min_value=0, max_value=3.0), + period=st.floats(min_value=2 * np.pi, max_value=10.0), + test_with_out=st.just(False), +) +def test_jax_unwrap( + *, + dtype_x_axis, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, + discont, + period, +): + dtype, x, axis = dtype_x_axis + helpers.test_frontend_function( + input_dtypes=dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + p=x[0], + discont=discont, + axis=axis, + period=period, + ) + + # vdot @handle_frontend_test( fn_tree="jax.numpy.vdot",