diff --git a/ivy/functional/frontends/jax/numpy/fft.py b/ivy/functional/frontends/jax/numpy/fft.py index 125696d97faa3..963912e380d5f 100644 --- a/ivy/functional/frontends/jax/numpy/fft.py +++ b/ivy/functional/frontends/jax/numpy/fft.py @@ -78,6 +78,30 @@ def ifftn(a, s=None, axes=None, norm=None): return a +@to_ivy_arrays_and_back +def ifftshift(x, axes=None): + if not ivy.is_array(x): + raise ValueError("Input 'x' must be an array") + + # Get the shape of x + shape = ivy.shape(x) + + # If axes is None, shift all axes + if axes is None: + axes = tuple(range(x.ndim)) + + # Convert axes to a list if it's not already + axes = [axes] if isinstance(axes, int) else list(axes) + + # Perform the shift for each axis + for axis in axes: + axis_size = shape[axis] + shift = -ivy.floor(axis_size / 2).astype(ivy.int32) + result = ivy.roll(x, shift, axis=axis) + + return result + + @to_ivy_arrays_and_back @with_unsupported_dtypes({"1.25.2 and below": ("float16", "bfloat16")}, "numpy") def rfft(a, n=None, axis=-1, norm=None): diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py index 09925136e06aa..3993e06fb109f 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py @@ -271,6 +271,43 @@ def test_jax_numpy_ifftn( ) +# ifftshift +@handle_frontend_test( + fn_tree="jax.numpy.fft.ifftshift", + dtype_values_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=1, + min_value=-1e5, + max_value=1e5, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + allow_inf=False, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + valid_axis=True, + force_int_axis=True, + ), +) +def test_jax_numpy_ifftshift( + dtype_values_axis, backend_fw, frontend, test_flags, fn_tree, on_device +): + dtype, values, axis = dtype_values_axis + helpers.test_frontend_function( + input_dtypes=dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=True, + x=values[0], + axes=axis, + ) + + # rfft @handle_frontend_test( fn_tree="jax.numpy.fft.rfft",