From 4f05a09ff892de70d72e4facfc69496830d5237a Mon Sep 17 00:00:00 2001 From: Vaishnavi Mudaliar Date: Sat, 30 Mar 2024 06:11:59 +0000 Subject: [PATCH 1/4] added ifftshift and corresponding test --- ivy/functional/frontends/jax/numpy/fft.py | 27 ++++++++++++ .../test_jax/test_numpy/test_fft.py | 42 +++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/ivy/functional/frontends/jax/numpy/fft.py b/ivy/functional/frontends/jax/numpy/fft.py index 125696d97faa3..c2e3a0cf9b2c7 100644 --- a/ivy/functional/frontends/jax/numpy/fft.py +++ b/ivy/functional/frontends/jax/numpy/fft.py @@ -78,6 +78,33 @@ def ifftn(a, s=None, axes=None, norm=None): return a +@to_ivy_arrays_and_back +def ifftshift(x, axes=None): + # Check if an array + if not ivy.is_array(x): + raise ValueError("Input 'x' must be an array") + + # Get the shape of x + shape = ivy.shape(x) + + # If axes is None, shift all axes + if axes is None: + axes = list(range(len(shape))) + + # Initialize a list to store the shift values + shift_values = [] + + # Calculate shift values for each axis + for axis in axes: + axis_size = shape[axis] + shift = -ivy.floor(axis_size / 2).astype(ivy.int32) + shift_values.append(shift) + + # Perform the shift using Ivy's roll function + result = ivy.roll(x, shift_values, axes) + return result + + @to_ivy_arrays_and_back @with_unsupported_dtypes({"1.25.2 and below": ("float16", "bfloat16")}, "numpy") def rfft(a, n=None, axis=-1, norm=None): diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py index 09925136e06aa..1e8953e177e47 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py @@ -271,6 +271,48 @@ def test_jax_numpy_ifftn( ) +# ifftshift +@handle_frontend_test( + fn_tree="jax.numpy.fft.ifftshift", + dtype_values_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("complex"), + num_arrays=1, + min_value=-1e5, + max_value=1e5, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + allow_inf=False, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + valid_axis=True, + force_int_axis=True, + ), + n=st.integers(min_value=2, max_value=10), + norm=st.sampled_from(["backward", "ortho", "forward", None]), +) +def test_jax_numpy_ifftshift( + dtype_values_axis, n, norm, frontend, backend_fw, test_flags, fn_tree, on_device +): + dtype, values, axis = dtype_values_axis + helpers.test_frontend_function( + input_dtypes=dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=values[0], + n=n, + axis=axis, + norm=norm, + atol=1e-02, + rtol=1e-02, + ) + + # rfft @handle_frontend_test( fn_tree="jax.numpy.fft.rfft", From 7e4b6ae3db86fc7623712eb24763c03f1eecee68 Mon Sep 17 00:00:00 2001 From: Vaishnavi Mudaliar Date: Sat, 30 Mar 2024 11:27:17 +0000 Subject: [PATCH 2/4] updated test function --- ivy/functional/frontends/jax/numpy/fft.py | 2 +- .../test_jax/test_numpy/test_fft.py | 34 +++++-------------- 2 files changed, 9 insertions(+), 27 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/fft.py b/ivy/functional/frontends/jax/numpy/fft.py index c2e3a0cf9b2c7..a94dbf580b2c1 100644 --- a/ivy/functional/frontends/jax/numpy/fft.py +++ b/ivy/functional/frontends/jax/numpy/fft.py @@ -89,7 +89,7 @@ def ifftshift(x, axes=None): # If axes is None, shift all axes if axes is None: - axes = list(range(len(shape))) + axes = tuple(range(x.ndim)) # Initialize a list to store the shift values shift_values = [] diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py index 1e8953e177e47..231c7d5e91ab5 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py @@ -274,42 +274,24 @@ def test_jax_numpy_ifftn( # ifftshift @handle_frontend_test( fn_tree="jax.numpy.fft.ifftshift", - dtype_values_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("complex"), - num_arrays=1, - min_value=-1e5, - max_value=1e5, - min_num_dims=1, - max_num_dims=5, - min_dim_size=1, - max_dim_size=5, - allow_inf=False, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - valid_axis=True, - force_int_axis=True, + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), shape=(4,), array_api_dtypes=True ), - n=st.integers(min_value=2, max_value=10), - norm=st.sampled_from(["backward", "ortho", "forward", None]), ) def test_jax_numpy_ifftshift( - dtype_values_axis, n, norm, frontend, backend_fw, test_flags, fn_tree, on_device + dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device ): - dtype, values, axis = dtype_values_axis + input_dtype, arr = dtype_and_x helpers.test_frontend_function( - input_dtypes=dtype, + input_dtypes=input_dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - a=values[0], - n=n, - axis=axis, - norm=norm, - atol=1e-02, - rtol=1e-02, + test_values=True, + x=arr[0], + axes=None, ) From 858bc8a71a882a87cad70aba7936cdc2331298a7 Mon Sep 17 00:00:00 2001 From: Vaishnavi Mudaliar Date: Wed, 3 Apr 2024 11:25:29 +0000 Subject: [PATCH 3/4] updated fft.py and test_fft.py --- ivy/functional/frontends/jax/numpy/fft.py | 4 ++-- .../test_ivy/test_frontends/test_jax/test_numpy/test_fft.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/fft.py b/ivy/functional/frontends/jax/numpy/fft.py index a94dbf580b2c1..99a1165f5c82d 100644 --- a/ivy/functional/frontends/jax/numpy/fft.py +++ b/ivy/functional/frontends/jax/numpy/fft.py @@ -89,7 +89,7 @@ def ifftshift(x, axes=None): # If axes is None, shift all axes if axes is None: - axes = tuple(range(x.ndim)) + axes = list(range(len(shape))) # Initialize a list to store the shift values shift_values = [] @@ -101,7 +101,7 @@ def ifftshift(x, axes=None): shift_values.append(shift) # Perform the shift using Ivy's roll function - result = ivy.roll(x, shift_values, axes) + result = ivy.roll(x, shift_values, axis=axes) return result diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py index 231c7d5e91ab5..bbf8bd6f1bbd6 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py @@ -291,7 +291,7 @@ def test_jax_numpy_ifftshift( on_device=on_device, test_values=True, x=arr[0], - axes=None, + axes=None, # You can change this to test specific axes if needed ) From 1c9df345a637901a9c770a283b9f0d5ee72563b7 Mon Sep 17 00:00:00 2001 From: Vaishnavi Mudaliar Date: Tue, 9 Apr 2024 08:37:12 +0000 Subject: [PATCH 4/4] updated test function to include multiple axes --- ivy/functional/frontends/jax/numpy/fft.py | 15 +++++------ .../test_jax/test_numpy/test_fft.py | 27 ++++++++++++++----- 2 files changed, 26 insertions(+), 16 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/fft.py b/ivy/functional/frontends/jax/numpy/fft.py index 99a1165f5c82d..963912e380d5f 100644 --- a/ivy/functional/frontends/jax/numpy/fft.py +++ b/ivy/functional/frontends/jax/numpy/fft.py @@ -80,28 +80,25 @@ def ifftn(a, s=None, axes=None, norm=None): @to_ivy_arrays_and_back def ifftshift(x, axes=None): - # Check if an array if not ivy.is_array(x): - raise ValueError("Input 'x' must be an array") + raise ValueError("Input 'x' must be an array") # Get the shape of x shape = ivy.shape(x) # If axes is None, shift all axes if axes is None: - axes = list(range(len(shape))) + axes = tuple(range(x.ndim)) - # Initialize a list to store the shift values - shift_values = [] + # Convert axes to a list if it's not already + axes = [axes] if isinstance(axes, int) else list(axes) - # Calculate shift values for each axis + # Perform the shift for each axis for axis in axes: axis_size = shape[axis] shift = -ivy.floor(axis_size / 2).astype(ivy.int32) - shift_values.append(shift) + result = ivy.roll(x, shift, axis=axis) - # Perform the shift using Ivy's roll function - result = ivy.roll(x, shift_values, axis=axes) return result diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py index bbf8bd6f1bbd6..3993e06fb109f 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py @@ -274,24 +274,37 @@ def test_jax_numpy_ifftn( # ifftshift @handle_frontend_test( fn_tree="jax.numpy.fft.ifftshift", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), shape=(4,), array_api_dtypes=True + dtype_values_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + num_arrays=1, + min_value=-1e5, + max_value=1e5, + min_num_dims=1, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + allow_inf=False, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + valid_axis=True, + force_int_axis=True, ), ) def test_jax_numpy_ifftshift( - dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device + dtype_values_axis, backend_fw, frontend, test_flags, fn_tree, on_device ): - input_dtype, arr = dtype_and_x + dtype, values, axis = dtype_values_axis helpers.test_frontend_function( - input_dtypes=input_dtype, + input_dtypes=dtype, frontend=frontend, backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, test_values=True, - x=arr[0], - axes=None, # You can change this to test specific axes if needed + x=values[0], + axes=axis, )