Skip to content

Commit

Permalink
Simplify fft full normalization (#663)
Browse files Browse the repository at this point in the history
* Make a 1d full normalization

* Remove usage of the select function

* Rename template parameters

* Use rlength

* Reorganize computation

* Reuse forward_full_norm_coef

* Use value instead of get

* Review from Yuuichi
  • Loading branch information
tpadioleau authored Oct 15, 2024
1 parent 1a436bb commit fc0aa21
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions include/ddc/kernels/fft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,19 @@ void rescale(
});
}

template <class DDim>
Real forward_full_norm_coef(DiscreteDomain<DDim> const& ddom) noexcept
{
return rlength(ddom) / Kokkos::sqrt(2 * Kokkos::numbers::pi_v<Real>)
/ (ddom.extents() - 1).value();
}

template <class DDim>
Real backward_full_norm_coef(DiscreteDomain<DDim> const& ddom) noexcept
{
return 1 / (forward_full_norm_coef(ddom) * ddom.extents().value());
}

/// @brief Core internal function to perform the FFT.
template <
typename Tin,
Expand Down Expand Up @@ -255,25 +268,16 @@ void impl(

// The FULL normalization is mesh-dependant and thus handled by DDC
if (kwargs.normalization == ddc::FFT_Normalization::FULL) {
real_type_t<Tout> norm_coef;
Real norm_coef;
if (kwargs.direction == ddc::FFT_Direction::FORWARD) {
norm_coef
= (((coordinate(ddc::select<DDimIn>(in.domain()).back())
- coordinate(ddc::select<DDimIn>(in.domain()).front()))
/ (ddc::get<DDimIn>(in.domain().extents()) - 1)
/ Kokkos::sqrt(2 * Kokkos::numbers::pi))
* ...);
DiscreteDomain<DDimIn...> const ddom_in = in.domain();
norm_coef = (forward_full_norm_coef(DiscreteDomain<DDimIn>(ddom_in)) * ...);
} else {
norm_coef
= ((Kokkos::sqrt(2 * Kokkos::numbers::pi)
/ (coordinate(ddc::select<DDimOut>(out.domain()).back())
- coordinate(ddc::select<DDimOut>(out.domain()).front()))
* (ddc::get<DDimOut>(out.domain().extents()) - 1)
/ ddc::get<DDimOut>(out.domain().extents()))
* ...);
DiscreteDomain<DDimOut...> const ddom_out = out.domain();
norm_coef = (backward_full_norm_coef(DiscreteDomain<DDimOut>(ddom_out)) * ...);
}

rescale(exec_space, out, norm_coef);
rescale(exec_space, out, static_cast<real_type_t<Tout>>(norm_coef));
}
}

Expand Down

0 comments on commit fc0aa21

Please sign in to comment.