From 2886229931084ced5d90b558de449ad9fc143010 Mon Sep 17 00:00:00 2001 From: rohitkg83 Date: Tue, 17 Oct 2023 11:53:37 +0100 Subject: [PATCH 1/3] Implemented jax-numpy-unwrap on latest code. --- .../jax/numpy/mathematical_functions.py | 38 +++++++++++++++- .../test_numpy/test_mathematical_functions.py | 43 +++++++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py index b146e5743b934..28b1492972167 100644 --- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py +++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py @@ -3,8 +3,9 @@ from ivy.functional.frontends.jax.func_wrapper import ( to_ivy_arrays_and_back, ) -from ivy.func_wrapper import with_unsupported_dtypes -from ivy.functional.frontends.jax.numpy import promote_types_of_jax_inputs +from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes +from ivy.functional.frontends.jax.lax import slice_in_dim +from ivy.functional.frontends.jax.numpy import promote_types_of_jax_inputs, concatenate from ivy.functional.frontends.numpy.manipulation_routines import trim_zeros @@ -723,6 +724,39 @@ def trunc(x): return ivy.trunc(x) +@with_supported_dtypes( + {"0.4.14 and below": ("float32", "float64", "int32", "int64")}, "jax" +) +@to_ivy_arrays_and_back +def unwrap(p, discont=None, axis=-1, period=2 * ivy.pi): + p = ivy.asarray(p) + _dtype_to_inexact = { + "int32": "float64", + "int64": "float64", + "float32": "float32", + "float64": "float64", + } + dtype = _dtype_to_inexact[p.dtype] + p = p.astype(dtype) + if discont is None: + discont = period / 2 + interval = period / 2 + dd = ivy.diff(p, axis=axis) + ddmod = ivy.remainder(dd + interval, period) - interval + ddmod = ivy.where((ddmod == -interval) & (dd > 0), interval, ddmod) + ph_correct = ivy.where(ivy.abs(dd) < discont, 0, ddmod - dd) + up = concatenate( + ( + slice_in_dim(p, 0, 1, axis=axis), + slice_in_dim(p, 1, None, axis=axis) + ivy.cumsum(ph_correct, axis=axis), + ), + axis=axis, + dtype=dtype, + ) + + return up + + @to_ivy_arrays_and_back def vdot(a, b): a, b = promote_types_of_jax_inputs(a, b) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py index 9bc5d435dce60..e85b4bc556c7a 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_mathematical_functions.py @@ -3332,6 +3332,49 @@ def test_jax_trunc( ) +@handle_frontend_test( + fn_tree="jax.numpy.unwrap", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("numeric"), + min_num_dims=2, + max_num_dims=5, + min_dim_size=2, + max_dim_size=10, + min_value=-ivy.pi, + max_value=ivy.pi, + valid_axis=True, + force_int_axis=True, + ), + discont=st.floats(min_value=0, max_value=3.0), + period=st.floats(min_value=2 * np.pi, max_value=10.0), + test_with_out=st.just(False), +) +def test_jax_unwrap( + *, + dtype_x_axis, + on_device, + fn_tree, + frontend, + backend_fw, + test_flags, + discont, + period, +): + dtype, x, axis = dtype_x_axis + helpers.test_frontend_function( + input_dtypes=dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + p=x[0], + discont=discont, + axis=axis, + period=period, + ) + + # vdot @handle_frontend_test( fn_tree="jax.numpy.vdot", From 32652001fe02a1586580ffd7c82a27f67ce1a076 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Sun, 3 Mar 2024 11:36:32 +0000 Subject: [PATCH 2/3] Update mathematical_functions.py --- ivy/functional/frontends/jax/numpy/mathematical_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py index 28b1492972167..4aea1e7ef5cb1 100644 --- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py +++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py @@ -753,7 +753,6 @@ def unwrap(p, discont=None, axis=-1, period=2 * ivy.pi): axis=axis, dtype=dtype, ) - return up From b73a7823c16ddbb016060273a98bc44c8748b6d8 Mon Sep 17 00:00:00 2001 From: Sam-Armstrong Date: Sat, 13 Jul 2024 02:45:07 +0100 Subject: [PATCH 3/3] minor fix --- ivy/functional/frontends/jax/numpy/mathematical_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/frontends/jax/numpy/mathematical_functions.py b/ivy/functional/frontends/jax/numpy/mathematical_functions.py index 348a3cfdab503..e65900a59df0f 100644 --- a/ivy/functional/frontends/jax/numpy/mathematical_functions.py +++ b/ivy/functional/frontends/jax/numpy/mathematical_functions.py @@ -918,7 +918,7 @@ def trunc(x): @with_supported_dtypes( - {"0.4.14 and below": ("float32", "float64", "int32", "int64")}, "jax" + {"0.4.30 and below": ("float32", "float64", "int32", "int64")}, "jax" ) @to_ivy_arrays_and_back def unwrap(p, discont=None, axis=-1, period=2 * ivy.pi):