Skip to content

Commit

Permalink
FFT doc (#535)
Browse files Browse the repository at this point in the history
* init

* wip

* rc

* restore

* enums

* fixes

* autoreview

* precision

* Update include/ddc/kernels/fft.hpp

* precisions

* Update include/ddc/kernels/fft.hpp

* Apply suggestions from code review

* Apply suggestions from code review

Co-authored-by: Thomas Padioleau <thomas.padioleau@cea.fr>

* remove "spatial" terminology

* Update include/ddc/kernels/fft.hpp

* Update include/ddc/kernels/fft.hpp

* core -> impl

---------

Co-authored-by: Thomas Padioleau <thomas.padioleau@cea.fr>
  • Loading branch information
blegouix and tpadioleau authored Jul 20, 2024
1 parent 24e28f0 commit 21fa29b
Showing 1 changed file with 149 additions and 17 deletions.
166 changes: 149 additions & 17 deletions include/ddc/kernels/fft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,40 @@ static_assert(alignof(hipfftDoubleComplex) <= alignof(Kokkos::complex<double>));
#endif

namespace ddc {
// TODO : maybe transfert this somewhere else because Fourier space is not specific to FFT
/**
* @brief A templated tag representing a continuous dimension in the Fourier space associated to the original continuous dimension.
*
* @tparam The tag representing the original dimension.
*/
template <typename Dim>
struct Fourier;

// named arguments for FFT (and their default values)
enum class FFT_Direction { FORWARD, BACKWARD };
enum class FFT_Normalization { OFF, FORWARD, BACKWARD, ORTHO, FULL };
/**
* @brief A named argument to choose the direction of the FFT.
*
* @see kwArgs_impl, kwArgs_fft
*/
enum class FFT_Direction {
FORWARD, ///< Forward, corresponds to direct FFT up to normalization
BACKWARD ///< Backward, corresponds to inverse FFT up to normalization
};

/**
* @brief A named argument to choose the type of normalization of the FFT.
*
* @see kwArgs_impl, kwArgs_fft
*/
enum class FFT_Normalization {
OFF, ///< No normalization. Un-normalized FFT is sum_j f(x_j)*e^-ikx_j
FORWARD, ///< Multiply by 1/N for forward FFT, no normalization for backward FFT
BACKWARD, ///< No normalization for forward FFT, multiply by 1/N for backward FFT
ORTHO, ///< Multiply by 1/sqrt(N)
FULL /**<
* Multiply by dx/sqrt(2*pi) for forward FFT and dk/sqrt(2*pi) for backward
* FFT. It is aligned with the usual definition of the (continuous) Fourier transform
* 1/sqrt(2*pi)*int f(x)*e^-ikx*dx, and thus may be relevant for spectral analysis applications.
*/
};
} // namespace ddc

namespace ddc::detail::fft {
Expand Down Expand Up @@ -115,7 +142,11 @@ KOKKOS_FUNCTION constexpr T LastSelector(const T a, const T b)
return LastSelector<T, Dim, Second, Tail...>(a, b);
}

// transform_type : trait to determine the type of transformation (R2C, C2R, C2C...) <- no information about base type (float or double)
/**
* @brief A trait to identify the type of transformation (R2C, C2R, C2C...).
*
* It does not contain the information about the base type (float or double).
*/
enum class TransformType { R2R, R2C, C2R, C2C };

template <typename T1, typename T2>
Expand All @@ -142,6 +173,14 @@ struct transform_type<Kokkos::complex<T1>, Kokkos::complex<T2>>
static constexpr TransformType value = TransformType::C2C;
};

/**
* @brief A trait to get the TransformType for the input and output types.
*
* Internally check if T1 and T2 are Kokkos::complex<something> or not.
*
* @tparam T1 The input type.
* @tparam T2 The output type.
*/
template <typename T1, typename T2>
constexpr TransformType transform_type_v = transform_type<T1, T2>::value;

Expand Down Expand Up @@ -318,14 +357,26 @@ hipfftResult _hipfftExec([[maybe_unused]] LastArg lastArg, Args... args)
}
#endif

struct kwArgs_core
/*
* @brief A structure embedding the configuration of the impl FFT function: direction and type of normalization.
*
* @see FFT_impl
*/
struct kwArgs_impl
{
ddc::FFT_Direction
direction; // Only effective for C2C transform and for normalization BACKWARD and FORWARD
ddc::FFT_Normalization normalization;
};

// N,a,b from x_mesh
/**
* @brief Get the mesh size along a given dimension.
*
* @tparam DDim The dimension along which the mesh size is returned.
* @param x_mesh The mesh.
*
* @return The mesh size along the required dimension.
*/
template <typename DDim, typename... DDimX>
int N(ddc::DiscreteDomain<DDimX...> x_mesh)
{
Expand All @@ -335,14 +386,14 @@ int N(ddc::DiscreteDomain<DDimX...> x_mesh)
return static_cast<int>(x_mesh.template extent<DDim>());
}

// core
/// @brief Core internal function to perform the FFT.
template <typename Tin, typename Tout, typename ExecSpace, typename MemorySpace, typename... DDimX>
void core(
void impl(
ExecSpace const& execSpace,
Tout* out_data,
Tin* in_data,
ddc::DiscreteDomain<DDimX...> mesh,
const kwArgs_core& kwargs)
const kwArgs_impl& kwargs)
{
static_assert(
std::is_same_v<real_type_t<Tin>, float> || std::is_same_v<real_type_t<Tin>, double>,
Expand Down Expand Up @@ -559,6 +610,25 @@ void core(

namespace ddc {

/**
* @brief Initialize a Fourier discrete dimension.
*
* Initialize the (1D) discrete space representing the Fourier discrete dimension associated
* to the (1D) mesh passed as argument. It is a N-periodic PeriodicSampling with a periodic window of width 2*pi/dx.
*
* This value comes from the Nyquist-Shannon theorem: the period of the spectral domain is N*dk = 2*pi/dx.
* Adding to this the relations dx = (xmax-xmin)/(N-1), and dk = (kmax-kmin)/(N-1), we get kmax-kmin = 2*pi*(N-1)^2/N/(xmax-xmin),
* which is used in the implementation (xmax, xmin, kmin and kmax are the centers of lower and upper cells inside a single period of the meshes).
*
* @tparam DDimFx A PeriodicSampling representing the Fourier discrete dimension.
* @tparam DDimX The type of the original discrete dimension.
*
* @param x_mesh The DiscreteDomain representing the (1D) original mesh.
*
* @return The initialized Impl representing the discrete Fourier space.
*
* @see PeriodicSampling
*/
template <typename DDimFx, typename DDimX>
typename DDimFx::template Impl<DDimFx, Kokkos::HostSpace> init_fourier_space(
ddc::DiscreteDomain<DDimX> x_mesh)
Expand All @@ -568,7 +638,7 @@ typename DDimFx::template Impl<DDimFx, Kokkos::HostSpace> init_fourier_space(
"DDimX dimensions should derive from UniformPointSampling");
static_assert(
is_periodic_sampling_v<DDimFx>,
"DDimFx dimensions should derive from PeriodicPointSampling");
"DDimFx dimensions should derive from PeriodicSampling");
auto [impl, ddom] = DDimFx::template init<DDimFx>(
ddc::Coordinate<typename DDimFx::continuous_dimension_type>(0),
ddc::Coordinate<typename DDimFx::continuous_dimension_type>(
Expand All @@ -583,7 +653,21 @@ typename DDimFx::template Impl<DDimFx, Kokkos::HostSpace> init_fourier_space(
return std::move(impl);
}

// FourierMesh, first element corresponds to mode 0
/**
* @brief Get the Fourier mesh.
*
* Compute the Fourier (or spectral) mesh on which the Discrete Fourier Transform of a
* discrete function is defined.
*
* @param x_mesh The DiscreteDomain representing the original mesh.
* @param C2C A flag indicating if a complex-to-complex DFT is going to be performed. Indeed,
* in this case the two meshes have same number of points, whereas for real-to-complex
* or complex-to-real DFT, each complex value of the Fourier-transformed function contains twice more
* information, and thus only half (actually Nx*Ny*(Nz/2+1) for 3D R2C FFT to take in account mode 0)
* values are needed (cf. DFT conjugate symmetry property for more information about this).
*
* @return The domain representing the Fourier mesh.
*/
template <typename... DDimFx, typename... DDimX>
ddc::DiscreteDomain<DDimFx...> FourierMesh(ddc::DiscreteDomain<DDimX...> x_mesh, bool C2C)
{
Expand All @@ -602,12 +686,38 @@ ddc::DiscreteDomain<DDimFx...> FourierMesh(ddc::DiscreteDomain<DDimX...> x_mesh,
ddc::detail::fft::N<DDimX>(x_mesh)))))...);
}

/**
* @brief A structure embedding the configuration of the exposed FFT function with the type of normalization.
*
* @see fft, ifft
*/
struct kwArgs_fft
{
ddc::FFT_Normalization normalization;
ddc::FFT_Normalization
normalization; ///< Enum member to identify the type of normalization performed
};

// FFT
/**
* @brief Perform a direct Fast Fourier Transform.
*
* Compute the discrete Fourier transform of a function using the specialized implementation for the Kokkos::ExecutionSpace
* of the FFT algorithm.
*
* @tparam Tin The type of the input elements (float, Kokkos::complex<float>, double or Kokkos::complex<double>).
* @tparam Tout The type of the output elements (Kokkos::complex<float> or Kokkos::complex<double>).
* @tparam DDimFx... The parameter pack of the Fourier discrete dimensions.
* @tparam DDimX... The parameter pack of the original discrete dimensions.
* @tparam ExecSpace The type of the Kokkos::ExecutionSpace on which the FFT is performed. It determines which specialized
* backend is used (ie. fftw, cuFFT...).
* @tparam MemorySpace The type of the Kokkos::MemorySpace on which are stored the input and output discrete functions.
* @tparam layout_in The layout of the Chunkspan representing the input discrete function.
* @tparam layout_out The layout of the Chunkspan representing the output discrete function.
*
* @param execSpace The Kokkos::ExecutionSpace on which the FFT is performed.
* @param out The output discrete function, represented as a ChunkSpan storing values on a spectral mesh.
* @param in The input discrete function, represented as a ChunkSpan storing values on a mesh.
* @param kwargs The kwArgs_fft configuring the FFT.
*/
template <
typename Tin,
typename Tout,
Expand Down Expand Up @@ -636,15 +746,37 @@ void fft(
(is_periodic_sampling_v<DDimFx> && ...),
"DDimFx dimensions should derive from PeriodicPointSampling");

ddc::detail::fft::core<Tin, Tout, ExecSpace, MemorySpace, DDimX...>(
ddc::detail::fft::impl<Tin, Tout, ExecSpace, MemorySpace, DDimX...>(
execSpace,
out.data_handle(),
in.data_handle(),
in.domain(),
{ddc::FFT_Direction::FORWARD, kwargs.normalization});
}

// iFFT (deduced from the fact that "in" is identified as a function on the Fourier space)
/**
* @brief Perform an inverse Fast Fourier Transform.
*
* Compute the inverse discrete Fourier transform of a spectral function using the specialized implementation for the Kokkos::ExecutionSpace
* of the iFFT algorithm.
*
* /!\ C2R iFFT does NOT preserve input !
*
* @tparam Tin The type of the input elements (Kokkos::complex<float> or Kokkos::complex<double>).
* @tparam Tout The type of the output elements (float, Kokkos::complex<float>, double or Kokkos::complex<double>).
* @tparam DDimX... The parameter pack of the original discrete dimensions.
* @tparam DDimFx... The parameter pack of the Fourier discrete dimensions.
* @tparam ExecSpace The type of the Kokkos::ExecutionSpace on which the iFFT is performed. It determines which specialized
* backend is used (ie. fftw, cuFFT...).
* @tparam MemorySpace The type of the Kokkos::MemorySpace on which are stored the input and output discrete functions.
* @tparam layout_in The layout of the Chunkspan representing the input discrete function.
* @tparam layout_out The layout of the Chunkspan representing the output discrete function.
*
* @param execSpace The Kokkos::ExecutionSpace on which the iFFT is performed.
* @param out The output discrete function, represented as a ChunkSpan storing values on a mesh.
* @param in The input discrete function, represented as a ChunkSpan storing values on a spectral mesh.
* @param kwargs The kwArgs_fft configuring the iFFT.
*/
template <
typename Tin,
typename Tout,
Expand Down Expand Up @@ -673,7 +805,7 @@ void ifft(
(is_periodic_sampling_v<DDimFx> && ...),
"DDimFx dimensions should derive from PeriodicPointSampling");

ddc::detail::fft::core<Tin, Tout, ExecSpace, MemorySpace, DDimX...>(
ddc::detail::fft::impl<Tin, Tout, ExecSpace, MemorySpace, DDimX...>(
execSpace,
out.data_handle(),
in.data_handle(),
Expand Down

0 comments on commit 21fa29b

Please sign in to comment.