From f0dd60ba3eafa049863e11506d3b48ee33c5168a Mon Sep 17 00:00:00 2001 From: Abdulkadir Gokce Date: Tue, 4 Jul 2023 15:35:13 +0300 Subject: [PATCH] Add fftshift to Paddle Frontend --- ivy/functional/frontends/paddle/fft.py | 30 +++++++++++++++++++ .../test_paddle/test_paddle_fft.py | 25 ++++++++++++++++ 2 files changed, 55 insertions(+) diff --git a/ivy/functional/frontends/paddle/fft.py b/ivy/functional/frontends/paddle/fft.py index b2f4214d2a971..eb9d9dec4c825 100644 --- a/ivy/functional/frontends/paddle/fft.py +++ b/ivy/functional/frontends/paddle/fft.py @@ -14,3 +14,33 @@ def fft(x, n=None, axis=-1.0, norm="backward", name=None): ret = ivy.fft(ivy.astype(x, "complex128"), axis, norm=norm, n=n) return ivy.astype(ret, x.dtype) + + +@with_supported_dtypes( + { + "2.5.0 and below": ( + "int32", + "int64", + "float32", + "float64", + "complex64", + "complex128", + ) + }, + "paddle", +) +@to_ivy_arrays_and_back +def fftshift(x, axes=None, name=None): + shape = x.shape + + if axes is None: + axes = tuple(range(x.ndim)) + shifts = [(dim // 2) for dim in shape] + elif isinstance(axes, int): + shifts = shape[axes] // 2 + else: + shifts = ivy.concat([shape[ax] // 2 for ax in axes]) + + roll = ivy.roll(x, shifts, axis=axes) + + return roll diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_paddle_fft.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_paddle_fft.py index 8e7e08b62f4a4..0fa652f27d4f3 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_paddle_fft.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_paddle_fft.py @@ -42,3 +42,28 @@ def test_paddle_fft( axis=axis, norm=norm, ) + + +@handle_frontend_test( + fn_tree="paddle.fft.fftshift", + dtype_x_axis=helpers.dtype_values_axis( + available_dtypes=helpers.get_dtypes("valid"), + min_value=-10, + max_value=10, + min_num_dims=1, + valid_axis=True, + force_int_axis=True, + ), +) +def test_paddle_fttshift(dtype_x_axis, frontend, test_flags, fn_tree, on_device): + input_dtype, x, axes = dtype_x_axis + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + test_values=True, + x=x[0], + axes=axes, + )