diff --git a/ivy/functional/frontends/jax/numpy/fft.py b/ivy/functional/frontends/jax/numpy/fft.py index d1500f394e5e1..fe12940cc842c 100644 --- a/ivy/functional/frontends/jax/numpy/fft.py +++ b/ivy/functional/frontends/jax/numpy/fft.py @@ -70,6 +70,33 @@ def ifft2(a, s=None, axes=(-2, -1), norm=None): return ivy.array(ivy.ifft2(a, s=s, dim=axes, norm=norm), dtype=ivy.dtype(a)) +@to_ivy_arrays_and_back +def ifftshift(x, axes=None): + # Check if an array + if not ivy.is_array(x): + raise ValueError("Input 'x' must be an array") + + # Get the shape of x + shape = ivy.shape(x) + + # If axes is None, shift all axes + if axes is None: + axes = list(range(len(shape))) + + # Initialize a list to store the shift values + shift_values = [] + + # Calculate shift values for each axis + for axis in axes: + axis_size = shape[axis] + shift = -ivy.floor(axis_size / 2).astype(ivy.int32) + shift_values.append(shift) + + # Perform the shift using Ivy's roll function + result = ivy.roll(x, shift_values, axes) + return result + + @to_ivy_arrays_and_back @with_unsupported_dtypes({"1.25.2 and below": ("float16", "bfloat16")}, "numpy") def rfft(a, n=None, axis=-1, norm=None): diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py index 1030839ae8e07..b2f59abc0e2d2 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py @@ -242,6 +242,30 @@ def test_jax_numpy_ifft2( ) +# ifftshift +@handle_frontend_test( + fn_tree="jax.numpy.fft.ifftshift", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), shape=(4,), array_api_dtypes=True + ), +) +def test_jax_numpy_ifftshift( + dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device +): + input_dtype, arr = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=True, + x=arr[0], + axes=None, + ) + + # rfft @handle_frontend_test( fn_tree="jax.numpy.fft.rfft",