Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(frontends): Implement rfft in PaddlePaddle frontend and fix fft for Tensorflow backend #19454

Merged
merged 12 commits into from
Sep 28, 2023
9 changes: 8 additions & 1 deletion ivy/functional/backends/tensorflow/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,9 @@ def _ifft_norm(
raise ivy.utils.exceptions.IvyError(f"Unrecognized normalization mode {norm}")


@with_supported_dtypes({"2.13.0 and below": ("complex",)}, backend_version)
@with_supported_dtypes(
{"2.13.0 and below": ("complex", "float32", "float64")}, backend_version
)
def fft(
x: Union[tf.Tensor, tf.Variable],
dim: int,
Expand All @@ -658,6 +660,11 @@ def fft(
n: Union[int, Tuple[int]] = None,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
# ToDo: Remove conversion from float to complex when casting mode is working
if x.dtype == "float32":
x = tf.cast(x, tf.complex64)
elif x.dtype == "float64":
x = tf.cast(x, tf.complex128)
if not isinstance(dim, int):
raise ivy.utils.exceptions.IvyError(
f"Expecting <class 'int'> instead of {type(dim)}"
Expand Down
6 changes: 6 additions & 0 deletions ivy/functional/frontends/paddle/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,12 @@ def irfftn(x, s=None, axes=None, norm="backward", name=None):
return result_t


@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle")
@to_ivy_arrays_and_back
def rfft(x, n=None, axis=-1, norm="backward", name=None):
return ivy.dft(x, axis=axis, inverse=False, onesided=True, dft_length=n, norm=norm)


@to_ivy_arrays_and_back
def rfftfreq(n, d=1.0, dtype=None, name=None):
dtype = ivy.default_dtype()
Expand Down
41 changes: 41 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_paddle/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,47 @@ def test_paddle_irfftn(
)


# rfft
@handle_frontend_test(
fn_tree="paddle.fft.rfft",
dtype_input_axis=helpers.dtype_values_axis(
available_dtypes=helpers.get_dtypes("valid"),
min_num_dims=1,
min_dim_size=2,
shape=helpers.get_shape(
min_num_dims=1,
max_num_dims=2,
min_dim_size=2,
max_dim_size=4,
),
large_abs_safety_factor=12,
small_abs_safety_factor=12,
safety_factor_scale="log",
force_int_axis=True,
valid_axis=True,
allow_neg_axes=True,
),
norm=st.sampled_from(["backward", "ortho", "forward"]),
n=st.integers(min_value=2, max_value=10) | st.none(),
)
def test_paddle_rfft(
dtype_input_axis, norm, n, frontend, backend_fw, test_flags, fn_tree, on_device
):
input_dtype, x, axis = dtype_input_axis
helpers.test_frontend_function(
input_dtypes=input_dtype,
frontend=frontend,
backend_to_test=backend_fw,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
x=x[0],
n=n,
axis=axis,
norm=norm,
)


@handle_frontend_test(
fn_tree="paddle.fft.rfftfreq",
n=st.integers(min_value=1, max_value=1000),
Expand Down
Loading