From 49b827147868e330738da1844ae9b6f1bfe63934 Mon Sep 17 00:00:00 2001 From: Vaishnavi Mudaliar Date: Mon, 18 Sep 2023 10:17:16 +0000 Subject: [PATCH 01/12] added ifftshift and test --- ivy/functional/frontends/jax/numpy/fft.py | 28 +++++++++++++++++++ .../test_jax/test_numpy/test_fft.py | 25 +++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/ivy/functional/frontends/jax/numpy/fft.py b/ivy/functional/frontends/jax/numpy/fft.py index 8b2d0e17aebf6..30130efcb9baf 100644 --- a/ivy/functional/frontends/jax/numpy/fft.py +++ b/ivy/functional/frontends/jax/numpy/fft.py @@ -48,3 +48,31 @@ def ifft2(a, s=None, axes=(-2, -1), norm=None): if norm is None: norm = "backward" return ivy.array(ivy.ifft2(a, s=s, dim=axes, norm=norm), dtype=ivy.dtype(a)) + +@to_ivy_arrays_and_back + +def ifftshift(x, axes=None): + # Check if an array + if not ivy.is_array(x): + raise ValueError("Input 'x' must be an array") + + # Get the shape of x + shape = ivy.shape(x) + + # If axes is None, shift all axes + if axes is None: + axes = list(range(ivy.ndims(x))) + + # Initialize a list to store the shift values + shift_values = [] + + # Calculate shift values for each axis + for axis in axes: + axis_size = shape[axis] + shift = -ivy.floor(axis_size / 2).astype(ivy.int32) + shift_values.append(shift) + + # Perform the shift using Ivy's roll function + result = ivy.roll(x, shift_values, axes) + return result + diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py index a13ea4f44bb6d..5649a7605fc4f 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py @@ -214,3 +214,28 @@ def test_jax_numpy_ifft2( atol=1e-02, rtol=1e-02, ) + + +# ifftshift +@handle_frontend_test( + fn_tree="jax.numpy.fft.ifftshift", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), shape=(4,), array_api_dtypes=True + ), +) +def test_jax_numpy_ifftshift( + dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device +): + input_dtype, arr = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=True, + x=arr[0], + axes=None, + ) + From 5146d2855b522c53724bce1af6d09a78982640dd Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Tue, 26 Sep 2023 04:28:03 +0000 Subject: [PATCH 02/12] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ivy/functional/frontends/jax/numpy/fft.py | 5 +---- .../test_ivy/test_frontends/test_jax/test_numpy/test_fft.py | 1 - 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/fft.py b/ivy/functional/frontends/jax/numpy/fft.py index 71c436ed11a90..f1c2097477da3 100644 --- a/ivy/functional/frontends/jax/numpy/fft.py +++ b/ivy/functional/frontends/jax/numpy/fft.py @@ -71,7 +71,6 @@ def ifft2(a, s=None, axes=(-2, -1), norm=None): @to_ivy_arrays_and_back - def ifftshift(x, axes=None): # Check if an array if not ivy.is_array(x): @@ -79,7 +78,7 @@ def ifftshift(x, axes=None): # Get the shape of x shape = ivy.shape(x) - + # If axes is None, shift all axes if axes is None: axes = list(range(ivy.ndims(x))) @@ -98,7 +97,6 @@ def ifftshift(x, axes=None): return result - @to_ivy_arrays_and_back @with_unsupported_dtypes({"1.25.2 and below": ("float16", "bfloat16")}, "numpy") def rfft(a, n=None, axis=-1, norm=None): @@ -113,4 +111,3 @@ def rfft(a, n=None, axis=-1, norm=None): slices[axis] = slice(0, int(ivy.shape(result, as_array=True)[axis] // 2 + 1)) result = result[tuple(slices)] return result - diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py index 7d85fd451c744..0cb1af06f408b 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py @@ -304,4 +304,3 @@ def test_jax_numpy_rfft( atol=1e-04, rtol=1e-04, ) - From 5728e55f843e19d2da1bc0c22c81e914e91c4586 Mon Sep 17 00:00:00 2001 From: Vaishnavi Mudaliar Date: Tue, 26 Sep 2023 05:04:39 +0000 Subject: [PATCH 03/12] added the missing code --- .../test_frontends/test_jax/test_numpy/test_fft.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py index 0cb1af06f408b..2d9e610bb6466 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py @@ -254,6 +254,19 @@ def test_jax_numpy_ifftshift( dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device ): input_dtype, arr = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=True, + x=arr[0], + axes=None, + ) + + # rfft @handle_frontend_test( From e8bc76ff794e50035294ce45a046d6d79f01606d Mon Sep 17 00:00:00 2001 From: Vaishnavi Mudaliar Date: Sat, 30 Sep 2023 11:19:47 +0000 Subject: [PATCH 04/12] deleted the unwanted code --- ivy/functional/frontends/torch/comparison_ops.py | 2 +- .../test_frontends/test_jax/test_numpy/test_fft.py | 9 --------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/ivy/functional/frontends/torch/comparison_ops.py b/ivy/functional/frontends/torch/comparison_ops.py index 0481f00ddca57..4a52f22da10a6 100644 --- a/ivy/functional/frontends/torch/comparison_ops.py +++ b/ivy/functional/frontends/torch/comparison_ops.py @@ -292,7 +292,7 @@ def topk(input, k, dim=None, largest=True, sorted=True, *, out=None): gt = greater +ne = not_equal ge = greater_equal le = less_equal lt = less -ne = not_equal diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py index 2d9e610bb6466..3fe6ec6f1c9a5 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py @@ -242,7 +242,6 @@ def test_jax_numpy_ifft2( ) - # ifftshift @handle_frontend_test( fn_tree="jax.numpy.fft.ifftshift", @@ -267,7 +266,6 @@ def test_jax_numpy_ifftshift( ) - # rfft @handle_frontend_test( fn_tree="jax.numpy.fft.rfft", @@ -303,13 +301,6 @@ def test_jax_numpy_rfft( test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - - test_values=True, - x=arr[0], - axes=None, - ) - - a=x[0], n=n, axis=axis, From 20663aec084ca2410eaf5322dab2a0d31b00b1ad Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Sat, 30 Sep 2023 11:20:43 +0000 Subject: [PATCH 05/12] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ivy/functional/frontends/torch/comparison_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/frontends/torch/comparison_ops.py b/ivy/functional/frontends/torch/comparison_ops.py index 4a52f22da10a6..0481f00ddca57 100644 --- a/ivy/functional/frontends/torch/comparison_ops.py +++ b/ivy/functional/frontends/torch/comparison_ops.py @@ -292,7 +292,7 @@ def topk(input, k, dim=None, largest=True, sorted=True, *, out=None): gt = greater -ne = not_equal ge = greater_equal le = less_equal lt = less +ne = not_equal From bda0555cad7fc2e4889700f131e37d343c903155 Mon Sep 17 00:00:00 2001 From: Vaishnavi Mudaliar Date: Wed, 4 Oct 2023 23:08:16 +0530 Subject: [PATCH 06/12] Update test_fft.py --- .../test_jax/test_numpy/test_fft.py | 42 +------------------ 1 file changed, 1 insertion(+), 41 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py index 3fe6ec6f1c9a5..8bef9d7ce3372 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py @@ -266,45 +266,5 @@ def test_jax_numpy_ifftshift( ) -# rfft -@handle_frontend_test( - fn_tree="jax.numpy.fft.rfft", - dtype_input_axis=helpers.dtype_values_axis( - available_dtypes=helpers.get_dtypes("float"), - num_arrays=1, - min_value=-1e5, - max_value=1e5, - min_num_dims=1, - min_dim_size=2, - allow_inf=False, - large_abs_safety_factor=2.5, - small_abs_safety_factor=2.5, - safety_factor_scale="log", - valid_axis=True, - force_int_axis=True, - ), - n=st.one_of( - st.integers(min_value=2, max_value=10), - st.just(None), - ), - norm=st.sampled_from(["backward", "ortho", "forward", None]), -) -def test_jax_numpy_rfft( - dtype_input_axis, n, norm, frontend, backend_fw, test_flags, fn_tree, on_device -): - input_dtype, x, axis = dtype_input_axis - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - a=x[0], - n=n, - axis=axis, - norm=norm, - atol=1e-04, - rtol=1e-04, - ) + From 83975fc7852ef2e96b3a1f3e47f088723e82f57a Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Wed, 4 Oct 2023 17:39:29 +0000 Subject: [PATCH 07/12] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_ivy/test_frontends/test_jax/test_numpy/test_fft.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py index 8bef9d7ce3372..09492abc3709d 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py @@ -264,7 +264,3 @@ def test_jax_numpy_ifftshift( x=arr[0], axes=None, ) - - - - From b7bce228f9257074b8d9e4383a69cf0dde5d15e5 Mon Sep 17 00:00:00 2001 From: Vaishnavi Mudaliar Date: Thu, 5 Oct 2023 12:30:40 +0530 Subject: [PATCH 08/12] Update test_fft.py --- .../test_jax/test_numpy/test_fft.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py index 09492abc3709d..b578cc8de6dcf 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py @@ -241,6 +241,47 @@ def test_jax_numpy_ifft2( rtol=1e-02, ) +# rfft +@handle_frontend_test( + fn_tree="jax.numpy.fft.rfft", + dtype_input_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=1, + min_value=-1e5, + max_value=1e5, + min_num_dims=1, + min_dim_size=2, + allow_inf=False, + large_abs_safety_factor=2.5, + small_abs_safety_factor=2.5, + safety_factor_scale="log", + valid_axis=True, + force_int_axis=True, + ), + n=st.one_of( + st.integers(min_value=2, max_value=10), + st.just(None), + ), + norm=st.sampled_from(["backward", "ortho", "forward", None]), +) +def test_jax_numpy_rfft( + dtype_input_axis, n, norm, frontend, backend_fw, test_flags, fn_tree, on_device +): + input_dtype, x, axis = dtype_input_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + a=x[0], + n=n, + axis=axis, + norm=norm, + atol=1e-04, + rtol=1e-04, + ) # ifftshift @handle_frontend_test( From 5810bc7079ca38a2e5a396aa6222610f7a04782c Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Thu, 5 Oct 2023 07:01:34 +0000 Subject: [PATCH 09/12] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_jax/test_numpy/test_fft.py | 48 ++++++++++--------- 1 file changed, 25 insertions(+), 23 deletions(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py index b578cc8de6dcf..b2f59abc0e2d2 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py @@ -241,6 +241,31 @@ def test_jax_numpy_ifft2( rtol=1e-02, ) + +# ifftshift +@handle_frontend_test( + fn_tree="jax.numpy.fft.ifftshift", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), shape=(4,), array_api_dtypes=True + ), +) +def test_jax_numpy_ifftshift( + dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device +): + input_dtype, arr = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + backend_to_test=backend_fw, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=True, + x=arr[0], + axes=None, + ) + + # rfft @handle_frontend_test( fn_tree="jax.numpy.fft.rfft", @@ -282,26 +307,3 @@ def test_jax_numpy_rfft( atol=1e-04, rtol=1e-04, ) - -# ifftshift -@handle_frontend_test( - fn_tree="jax.numpy.fft.ifftshift", - dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("valid"), shape=(4,), array_api_dtypes=True - ), -) -def test_jax_numpy_ifftshift( - dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device -): - input_dtype, arr = dtype_and_x - helpers.test_frontend_function( - input_dtypes=input_dtype, - frontend=frontend, - backend_to_test=backend_fw, - test_flags=test_flags, - fn_tree=fn_tree, - on_device=on_device, - test_values=True, - x=arr[0], - axes=None, - ) From 30e8ee203c283f6d2c7cce6b8353365040aaab0f Mon Sep 17 00:00:00 2001 From: Vaishnavi Mudaliar Date: Thu, 12 Oct 2023 12:35:35 +0530 Subject: [PATCH 10/12] Updated the code Since Ivy's documentation does not have function 'ndim' , i am using the length of the shape instead. --- ivy/functional/frontends/jax/numpy/fft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/frontends/jax/numpy/fft.py b/ivy/functional/frontends/jax/numpy/fft.py index f1c2097477da3..c7b0a35dafb2d 100644 --- a/ivy/functional/frontends/jax/numpy/fft.py +++ b/ivy/functional/frontends/jax/numpy/fft.py @@ -81,7 +81,7 @@ def ifftshift(x, axes=None): # If axes is None, shift all axes if axes is None: - axes = list(range(ivy.ndims(x))) + axes = list(range(len(shape)) # Initialize a list to store the shift values shift_values = [] From 3c821f3eb37cc4b03e1b10299c1799937af15312 Mon Sep 17 00:00:00 2001 From: Vaishnavi Mudaliar Date: Thu, 12 Oct 2023 07:23:21 +0000 Subject: [PATCH 11/12] code updated --- ivy/functional/frontends/jax/numpy/fft.py | 2 +- ivy/functional/frontends/torch/comparison_ops.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy/functional/frontends/jax/numpy/fft.py b/ivy/functional/frontends/jax/numpy/fft.py index c7b0a35dafb2d..f8a0d24d1ad7c 100644 --- a/ivy/functional/frontends/jax/numpy/fft.py +++ b/ivy/functional/frontends/jax/numpy/fft.py @@ -81,7 +81,7 @@ def ifftshift(x, axes=None): # If axes is None, shift all axes if axes is None: - axes = list(range(len(shape)) + axes = list(range(len(shape))) # Initialize a list to store the shift values shift_values = [] diff --git a/ivy/functional/frontends/torch/comparison_ops.py b/ivy/functional/frontends/torch/comparison_ops.py index e4bb1e4b2382c..a40db0d334182 100644 --- a/ivy/functional/frontends/torch/comparison_ops.py +++ b/ivy/functional/frontends/torch/comparison_ops.py @@ -292,7 +292,7 @@ def topk(input, k, dim=None, largest=True, sorted=True, *, out=None): gt = greater +ne = not_equal ge = greater_equal le = less_equal lt = less -ne = not_equal From 6fe2b505a1dd34aa0d0fb66017126a8551cb74db Mon Sep 17 00:00:00 2001 From: ivy-branch Date: Thu, 12 Oct 2023 07:24:31 +0000 Subject: [PATCH 12/12] =?UTF-8?q?=F0=9F=A4=96=20Lint=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ivy/functional/frontends/torch/comparison_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/functional/frontends/torch/comparison_ops.py b/ivy/functional/frontends/torch/comparison_ops.py index a40db0d334182..e4bb1e4b2382c 100644 --- a/ivy/functional/frontends/torch/comparison_ops.py +++ b/ivy/functional/frontends/torch/comparison_ops.py @@ -292,7 +292,7 @@ def topk(input, k, dim=None, largest=True, sorted=True, *, out=None): gt = greater -ne = not_equal ge = greater_equal le = less_equal lt = less +ne = not_equal