diff --git a/ivy/functional/frontends/jax/numpy/fft.py b/ivy/functional/frontends/jax/numpy/fft.py index 0ba0fa31f9c94..8b2d0e17aebf6 100644 --- a/ivy/functional/frontends/jax/numpy/fft.py +++ b/ivy/functional/frontends/jax/numpy/fft.py @@ -41,3 +41,10 @@ def ifft(a, n=None, axis=-1, norm=None): if norm is None: norm = "backward" return ivy.ifft(a, axis, norm=norm, n=n) + + +@to_ivy_arrays_and_back +def ifft2(a, s=None, axes=(-2, -1), norm=None): + if norm is None: + norm = "backward" + return ivy.array(ivy.ifft2(a, s=s, dim=axes, norm=norm), dtype=ivy.dtype(a)) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py index f4fc0911ae534..a13ea4f44bb6d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py @@ -163,3 +163,54 @@ def test_jax_numpy_ifft( atol=1e-02, rtol=1e-02, ) + + +# ifft2 +@handle_frontend_test( + fn_tree="jax.numpy.fft.ifft2", + dtype_values=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=1, + min_value=-1e5, + max_value=1e5, + min_num_dims=2, + max_num_dims=5, + min_dim_size=2, + max_dim_size=5, + allow_inf=False, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + ), + axes=st.sampled_from([(0, 1), (-1, -2), (1, 0)]), + s=st.tuples( + st.integers(min_value=2, max_value=256), st.integers(min_value=2, max_value=256) + ), + norm=st.sampled_from(["backward", "ortho", "forward", None]), +) +def test_jax_numpy_ifft2( + dtype_values, + s, + axes, + norm, + frontend, + backend_fw, + test_flags, + fn_tree, + on_device, +): + dtype, values = dtype_values + helpers.test_frontend_function( + input_dtypes=dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=values[0], + s=s, + axes=axes, + norm=norm, + atol=1e-02, + rtol=1e-02, + )