Skip to content

Commit

Permalink
feat: Implemented jax.numpy.unwrap frontend (#27050)
Browse files Browse the repository at this point in the history
Co-authored-by: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com>
Co-authored-by: Sam-Armstrong <samuel_e_armstrong@yahoo.co.uk>
  • Loading branch information
3 people committed Jul 13, 2024
1 parent 81c831b commit 7d610c0
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 2 deletions.
37 changes: 35 additions & 2 deletions ivy/functional/frontends/jax/numpy/mathematical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from ivy.functional.frontends.jax.func_wrapper import (
to_ivy_arrays_and_back,
)
from ivy.func_wrapper import with_unsupported_dtypes
from ivy.functional.frontends.jax.numpy import promote_types_of_jax_inputs
from ivy.func_wrapper import with_unsupported_dtypes, with_supported_dtypes
from ivy.functional.frontends.jax.lax import slice_in_dim
from ivy.functional.frontends.jax.numpy import promote_types_of_jax_inputs, concatenate
from ivy.functional.frontends.numpy.manipulation_routines import trim_zeros
from ivy.utils.einsum_path_helpers import (
parse_einsum_input,
Expand Down Expand Up @@ -916,6 +917,38 @@ def trunc(x):
return ivy.trunc(x)


@with_supported_dtypes(
{"0.4.30 and below": ("float32", "float64", "int32", "int64")}, "jax"
)
@to_ivy_arrays_and_back
def unwrap(p, discont=None, axis=-1, period=2 * ivy.pi):
p = ivy.asarray(p)
_dtype_to_inexact = {
"int32": "float64",
"int64": "float64",
"float32": "float32",
"float64": "float64",
}
dtype = _dtype_to_inexact[p.dtype]
p = p.astype(dtype)
if discont is None:
discont = period / 2
interval = period / 2
dd = ivy.diff(p, axis=axis)
ddmod = ivy.remainder(dd + interval, period) - interval
ddmod = ivy.where((ddmod == -interval) & (dd > 0), interval, ddmod)
ph_correct = ivy.where(ivy.abs(dd) < discont, 0, ddmod - dd)
up = concatenate(
(
slice_in_dim(p, 0, 1, axis=axis),
slice_in_dim(p, 1, None, axis=axis) + ivy.cumsum(ph_correct, axis=axis),
),
axis=axis,
dtype=dtype,
)
return up


@to_ivy_arrays_and_back
def vdot(a, b):
a, b = promote_types_of_jax_inputs(a, b)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3379,6 +3379,49 @@ def test_jax_trunc(
)


@handle_frontend_test(
fn_tree="jax.numpy.unwrap",
dtype_x_axis=helpers.dtype_values_axis(
available_dtypes=helpers.get_dtypes("numeric"),
min_num_dims=2,
max_num_dims=5,
min_dim_size=2,
max_dim_size=10,
min_value=-ivy.pi,
max_value=ivy.pi,
valid_axis=True,
force_int_axis=True,
),
discont=st.floats(min_value=0, max_value=3.0),
period=st.floats(min_value=2 * np.pi, max_value=10.0),
test_with_out=st.just(False),
)
def test_jax_unwrap(
*,
dtype_x_axis,
on_device,
fn_tree,
frontend,
backend_fw,
test_flags,
discont,
period,
):
dtype, x, axis = dtype_x_axis
helpers.test_frontend_function(
input_dtypes=dtype,
frontend=frontend,
backend_to_test=backend_fw,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
p=x[0],
discont=discont,
axis=axis,
period=period,
)


# vdot
@handle_frontend_test(
fn_tree="jax.numpy.vdot",
Expand Down

0 comments on commit 7d610c0

Please sign in to comment.