Skip to content

Commit

Permalink
feat: implemented ifftshift and the corresponding test for jax fronte…
Browse files Browse the repository at this point in the history
…nd (#28707)
  • Loading branch information
VaishnaviMudaliar authored Apr 9, 2024
1 parent decb87f commit 3221d42
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 0 deletions.
24 changes: 24 additions & 0 deletions ivy/functional/frontends/jax/numpy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,30 @@ def ifftn(a, s=None, axes=None, norm=None):
return a


@to_ivy_arrays_and_back
def ifftshift(x, axes=None):
if not ivy.is_array(x):
raise ValueError("Input 'x' must be an array")

# Get the shape of x
shape = ivy.shape(x)

# If axes is None, shift all axes
if axes is None:
axes = tuple(range(x.ndim))

# Convert axes to a list if it's not already
axes = [axes] if isinstance(axes, int) else list(axes)

# Perform the shift for each axis
for axis in axes:
axis_size = shape[axis]
shift = -ivy.floor(axis_size / 2).astype(ivy.int32)
result = ivy.roll(x, shift, axis=axis)

return result


@to_ivy_arrays_and_back
@with_unsupported_dtypes({"1.25.2 and below": ("float16", "bfloat16")}, "numpy")
def rfft(a, n=None, axis=-1, norm=None):
Expand Down
37 changes: 37 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,43 @@ def test_jax_numpy_ifftn(
)


# ifftshift
@handle_frontend_test(
fn_tree="jax.numpy.fft.ifftshift",
dtype_values_axis=helpers.dtype_values_axis(
available_dtypes=helpers.get_dtypes("valid"),
num_arrays=1,
min_value=-1e5,
max_value=1e5,
min_num_dims=1,
max_num_dims=5,
min_dim_size=1,
max_dim_size=5,
allow_inf=False,
large_abs_safety_factor=2.5,
small_abs_safety_factor=2.5,
safety_factor_scale="log",
valid_axis=True,
force_int_axis=True,
),
)
def test_jax_numpy_ifftshift(
dtype_values_axis, backend_fw, frontend, test_flags, fn_tree, on_device
):
dtype, values, axis = dtype_values_axis
helpers.test_frontend_function(
input_dtypes=dtype,
frontend=frontend,
backend_to_test=backend_fw,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
test_values=True,
x=values[0],
axes=axis,
)


# rfft
@handle_frontend_test(
fn_tree="jax.numpy.fft.rfft",
Expand Down

0 comments on commit 3221d42

Please sign in to comment.