Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: irfft2 function and test implemented. #23353

Merged
merged 13 commits into from
Sep 22, 2023
Merged
49 changes: 45 additions & 4 deletions ivy/functional/frontends/paddle/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,19 @@ def fftshift(x, axes=None, name=None):
"paddle",
)
@to_ivy_arrays_and_back
def hfft(x, n=None, axis=-1, norm="backward", name=None):
def hfft(x, n=None, axes=-1, norm="backward", name=None):
"""Compute the FFT of a signal that has Hermitian symmetry, resulting in a real
spectrum."""
# Determine the input shape and axis length
input_shape = x.shape
input_len = input_shape[axis]
input_len = input_shape[axes]

# Calculate n if not provided
if n is None:
n = 2 * (input_len - 1)

# Perform the FFT along the specified axis
result = ivy.fft(x, axis, n=n, norm=norm)
result = ivy.fft(x, axes, n=n, norm=norm)

return ivy.real(result)

Expand Down Expand Up @@ -146,7 +146,48 @@ def irfft(x, n=None, axis=-1.0, norm="backward", name=None):


@with_supported_dtypes(
{"2.5.1 and below": ("complex64", "complex128")},
{
"2.5.1 and below": (
"int32",
"int64",
"float16",
"float32",
"float64",
"complex64",
"complex128",
)
},
"paddle",
)
@to_ivy_arrays_and_back
def irfft2(x, s=None, axes=(-2, -1), norm="backward"):
# Handle values if None
if s is None:
s = x.shape
if axes is None:
axes = (-2, -1)

# Calculate the normalization factor 'n' based on the shape 's'
n = ivy.prod(ivy.array(s))

result = ivy.ifftn(x, dim=axes[0], norm=norm)

# Normalize the result based on the 'norm' parameter
if norm == "backward":
result /= n
elif norm == "forward":
result *= n
elif norm == "ortho":
result /= ivy.sqrt(n)
return result

@with_supported_dtypes(
{
"2.5.1 and below": (
"complex64",
"complex128"
)
},
"paddle",
)
@to_ivy_arrays_and_back
Expand Down
58 changes: 57 additions & 1 deletion ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from ivy_tests.test_ivy.helpers import handle_frontend_test


# Custom Hypothesis strategy for generating sequences of 2 integers
def sequence_of_two_integers():
return st.lists(st.integers(), min_size=2, max_size=2)


@handle_frontend_test(
fn_tree="paddle.fft.fft",
dtype_x_axis=helpers.dtype_values_axis(
Expand Down Expand Up @@ -260,9 +265,60 @@ def test_paddle_irfft(
n=n,
axis=axis,
norm=norm,
)


@handle_frontend_test(
fn_tree="paddle.fft.irfft2",
dtype_x_axis=helpers.dtype_values_axis(
available_dtypes=helpers.get_dtypes("valid"),
min_value=-10,
max_value=10,
min_num_dims=2,
valid_axis=True,
force_int_axis=True,
)
),
)
@given(st.data())
def test_paddle_irfft2(
data,
dtype_x_axis,
frontend,
test_flags,
fn_tree,
on_device,
backend_fw,
):
input_dtype, x, axes = dtype_x_axis
for norm in ["backward", "forward", "ortho"]:
s_values = data.draw(s_strategy)
axes_values = data.draw(axes_strategy)

# Ensure s and axes are sequences of 2 integers
assert len(s_values) == 2
assert len(axes_values) == 2

# Convert s and axes to tuples as needed
s = tuple(s_values)
axes = tuple(axes_values)

helpers.test_frontend_function(
input_dtypes=input_dtype,
frontend=frontend,
backend_to_test=backend_fw,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
test_values=True,
x=x[0],
s=s,
axes=axes,
norm=norm,
)

# Use the custom strategy for s and axes
axes_strategy = sequence_of_two_integers()
s_strategy = sequence_of_two_integers()


@handle_frontend_test(
Expand Down
Loading