Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FFT doc #535

Merged
merged 19 commits into from
Jul 20, 2024
166 changes: 149 additions & 17 deletions include/ddc/kernels/fft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,40 @@ static_assert(alignof(hipfftDoubleComplex) <= alignof(Kokkos::complex<double>));
#endif

namespace ddc {
// TODO : maybe transfert this somewhere else because Fourier space is not specific to FFT
/**
* @brief A templated tag representing a continuous dimension in the Fourier space associated to the original continuous dimension.
*
* @tparam The tag representing the original dimension.
*/
template <typename Dim>
struct Fourier;

// named arguments for FFT (and their default values)
enum class FFT_Direction { FORWARD, BACKWARD };
enum class FFT_Normalization { OFF, FORWARD, BACKWARD, ORTHO, FULL };
/**
* @brief A named argument to choose the direction of the FFT.
*
* @see kwArgs_impl, kwArgs_fft
*/
enum class FFT_Direction {
FORWARD, ///< Forward, corresponds to direct FFT up to normalization
BACKWARD ///< Backward, corresponds to inverse FFT up to normalization
};

/**
* @brief A named argument to choose the type of normalization of the FFT.
*
* @see kwArgs_impl, kwArgs_fft
*/
enum class FFT_Normalization {
OFF, ///< No normalization. Un-normalized FFT is sum_j f(x_j)*e^-ikx_j
FORWARD, ///< Multiply by 1/N for forward FFT, no normalization for backward FFT
BACKWARD, ///< No normalization for forward FFT, multiply by 1/N for backward FFT
ORTHO, ///< Multiply by 1/sqrt(N)
FULL /**<
* Multiply by dx/sqrt(2*pi) for forward FFT and dk/sqrt(2*pi) for backward
* FFT. It is aligned with the usual definition of the (continuous) Fourier transform
* 1/sqrt(2*pi)*int f(x)*e^-ikx*dx, and thus may be relevant for spectral analysis applications.
*/
};
} // namespace ddc

namespace ddc::detail::fft {
Expand Down Expand Up @@ -115,7 +142,11 @@ KOKKOS_FUNCTION constexpr T LastSelector(const T a, const T b)
return LastSelector<T, Dim, Second, Tail...>(a, b);
}

// transform_type : trait to determine the type of transformation (R2C, C2R, C2C...) <- no information about base type (float or double)
/**
* @brief A trait to identify the type of transformation (R2C, C2R, C2C...).
*
* It does not contain the information about the base type (float or double).
*/
enum class TransformType { R2R, R2C, C2R, C2C };

template <typename T1, typename T2>
Expand All @@ -142,6 +173,14 @@ struct transform_type<Kokkos::complex<T1>, Kokkos::complex<T2>>
static constexpr TransformType value = TransformType::C2C;
};

/**
* @brief A trait to get the TransformType for the input and output types.
*
* Internally check if T1 and T2 are Kokkos::complex<something> or not.
*
* @tparam T1 The input type.
* @tparam T2 The output type.
*/
template <typename T1, typename T2>
constexpr TransformType transform_type_v = transform_type<T1, T2>::value;

Expand Down Expand Up @@ -318,14 +357,26 @@ hipfftResult _hipfftExec([[maybe_unused]] LastArg lastArg, Args... args)
}
#endif

struct kwArgs_core
/*
* @brief A structure embedding the configuration of the impl FFT function: direction and type of normalization.
*
* @see FFT_impl
*/
struct kwArgs_impl
{
ddc::FFT_Direction
direction; // Only effective for C2C transform and for normalization BACKWARD and FORWARD
ddc::FFT_Normalization normalization;
};

// N,a,b from x_mesh
/**
* @brief Get the mesh size along a given dimension.
*
* @tparam DDim The dimension along which the mesh size is returned.
* @param x_mesh The mesh.
*
* @return The mesh size along the required dimension.
*/
template <typename DDim, typename... DDimX>
int N(ddc::DiscreteDomain<DDimX...> x_mesh)
{
Expand All @@ -335,14 +386,14 @@ int N(ddc::DiscreteDomain<DDimX...> x_mesh)
return static_cast<int>(x_mesh.template extent<DDim>());
}

// core
/// @brief Core internal function to perform the FFT.
template <typename Tin, typename Tout, typename ExecSpace, typename MemorySpace, typename... DDimX>
void core(
void impl(
ExecSpace const& execSpace,
Tout* out_data,
Tin* in_data,
ddc::DiscreteDomain<DDimX...> mesh,
const kwArgs_core& kwargs)
const kwArgs_impl& kwargs)
{
static_assert(
std::is_same_v<real_type_t<Tin>, float> || std::is_same_v<real_type_t<Tin>, double>,
Expand Down Expand Up @@ -559,6 +610,25 @@ void core(

namespace ddc {

/**
* @brief Initialize a Fourier discrete dimension.
*
* Initialize the (1D) discrete space representing the Fourier discrete dimension associated
* to the (1D) mesh passed as argument. It is a N-periodic PeriodicSampling with a periodic window of width 2*pi/dx.
*
* This value comes from the Nyquist-Shannon theorem: the period of the spectral domain is N*dk = 2*pi/dx.
* Adding to this the relations dx = (xmax-xmin)/(N-1), and dk = (kmax-kmin)/(N-1), we get kmax-kmin = 2*pi*(N-1)^2/N/(xmax-xmin),
* which is used in the implementation (xmax, xmin, kmin and kmax are the centers of lower and upper cells inside a single period of the meshes).
*
* @tparam DDimFx A PeriodicSampling representing the Fourier discrete dimension.
* @tparam DDimX The type of the original discrete dimension.
*
* @param x_mesh The DiscreteDomain representing the (1D) original mesh.
*
* @return The initialized Impl representing the discrete Fourier space.
*
* @see PeriodicSampling
*/
template <typename DDimFx, typename DDimX>
typename DDimFx::template Impl<DDimFx, Kokkos::HostSpace> init_fourier_space(
ddc::DiscreteDomain<DDimX> x_mesh)
Expand All @@ -568,7 +638,7 @@ typename DDimFx::template Impl<DDimFx, Kokkos::HostSpace> init_fourier_space(
"DDimX dimensions should derive from UniformPointSampling");
static_assert(
is_periodic_sampling_v<DDimFx>,
"DDimFx dimensions should derive from PeriodicPointSampling");
"DDimFx dimensions should derive from PeriodicSampling");
auto [impl, ddom] = DDimFx::template init<DDimFx>(
ddc::Coordinate<typename DDimFx::continuous_dimension_type>(0),
ddc::Coordinate<typename DDimFx::continuous_dimension_type>(
Expand All @@ -583,7 +653,21 @@ typename DDimFx::template Impl<DDimFx, Kokkos::HostSpace> init_fourier_space(
return std::move(impl);
}

// FourierMesh, first element corresponds to mode 0
/**
* @brief Get the Fourier mesh.
*
* Compute the Fourier (or spectral) mesh on which the Discrete Fourier Transform of a
* discrete function is defined.
*
* @param x_mesh The DiscreteDomain representing the original mesh.
* @param C2C A flag indicating if a complex-to-complex DFT is going to be performed. Indeed,
* in this case the two meshes have same number of points, whereas for real-to-complex
* or complex-to-real DFT, each complex value of the Fourier-transformed function contains twice more
* information, and thus only half (actually Nx*Ny*(Nz/2+1) for 3D R2C FFT to take in account mode 0)
* values are needed (cf. DFT conjugate symmetry property for more information about this).
*
* @return The domain representing the Fourier mesh.
*/
template <typename... DDimFx, typename... DDimX>
ddc::DiscreteDomain<DDimFx...> FourierMesh(ddc::DiscreteDomain<DDimX...> x_mesh, bool C2C)
{
Expand All @@ -602,12 +686,38 @@ ddc::DiscreteDomain<DDimFx...> FourierMesh(ddc::DiscreteDomain<DDimX...> x_mesh,
ddc::detail::fft::N<DDimX>(x_mesh)))))...);
}

/**
* @brief A structure embedding the configuration of the exposed FFT function with the type of normalization.
*
* @see fft, ifft
*/
struct kwArgs_fft
{
ddc::FFT_Normalization normalization;
ddc::FFT_Normalization
normalization; ///< Enum member to identify the type of normalization performed
};

// FFT
/**
* @brief Perform a direct Fast Fourier Transform.
*
* Compute the discrete Fourier transform of a function using the specialized implementation for the Kokkos::ExecutionSpace
* of the FFT algorithm.
*
* @tparam Tin The type of the input elements (float, Kokkos::complex<float>, double or Kokkos::complex<double>).
* @tparam Tout The type of the output elements (Kokkos::complex<float> or Kokkos::complex<double>).
* @tparam DDimFx... The parameter pack of the Fourier discrete dimensions.
* @tparam DDimX... The parameter pack of the original discrete dimensions.
* @tparam ExecSpace The type of the Kokkos::ExecutionSpace on which the FFT is performed. It determines which specialized
* backend is used (ie. fftw, cuFFT...).
* @tparam MemorySpace The type of the Kokkos::MemorySpace on which are stored the input and output discrete functions.
* @tparam layout_in The layout of the Chunkspan representing the input discrete function.
* @tparam layout_out The layout of the Chunkspan representing the output discrete function.
*
* @param execSpace The Kokkos::ExecutionSpace on which the FFT is performed.
* @param out The output discrete function, represented as a ChunkSpan storing values on a spectral mesh.
* @param in The input discrete function, represented as a ChunkSpan storing values on a mesh.
* @param kwargs The kwArgs_fft configuring the FFT.
*/
template <
typename Tin,
typename Tout,
Expand Down Expand Up @@ -636,15 +746,37 @@ void fft(
(is_periodic_sampling_v<DDimFx> && ...),
"DDimFx dimensions should derive from PeriodicPointSampling");

ddc::detail::fft::core<Tin, Tout, ExecSpace, MemorySpace, DDimX...>(
ddc::detail::fft::impl<Tin, Tout, ExecSpace, MemorySpace, DDimX...>(
execSpace,
out.data_handle(),
in.data_handle(),
in.domain(),
{ddc::FFT_Direction::FORWARD, kwargs.normalization});
}

// iFFT (deduced from the fact that "in" is identified as a function on the Fourier space)
/**
* @brief Perform an inverse Fast Fourier Transform.
*
* Compute the inverse discrete Fourier transform of a spectral function using the specialized implementation for the Kokkos::ExecutionSpace
* of the iFFT algorithm.
*
* /!\ C2R iFFT does NOT preserve input !
*
* @tparam Tin The type of the input elements (Kokkos::complex<float> or Kokkos::complex<double>).
* @tparam Tout The type of the output elements (float, Kokkos::complex<float>, double or Kokkos::complex<double>).
* @tparam DDimX... The parameter pack of the original discrete dimensions.
* @tparam DDimFx... The parameter pack of the Fourier discrete dimensions.
* @tparam ExecSpace The type of the Kokkos::ExecutionSpace on which the iFFT is performed. It determines which specialized
* backend is used (ie. fftw, cuFFT...).
* @tparam MemorySpace The type of the Kokkos::MemorySpace on which are stored the input and output discrete functions.
* @tparam layout_in The layout of the Chunkspan representing the input discrete function.
* @tparam layout_out The layout of the Chunkspan representing the output discrete function.
*
* @param execSpace The Kokkos::ExecutionSpace on which the iFFT is performed.
* @param out The output discrete function, represented as a ChunkSpan storing values on a mesh.
* @param in The input discrete function, represented as a ChunkSpan storing values on a spectral mesh.
* @param kwargs The kwArgs_fft configuring the iFFT.
*/
template <
typename Tin,
typename Tout,
Expand Down Expand Up @@ -673,7 +805,7 @@ void ifft(
(is_periodic_sampling_v<DDimFx> && ...),
"DDimFx dimensions should derive from PeriodicPointSampling");

ddc::detail::fft::core<Tin, Tout, ExecSpace, MemorySpace, DDimX...>(
ddc::detail::fft::impl<Tin, Tout, ExecSpace, MemorySpace, DDimX...>(
execSpace,
out.data_handle(),
in.data_handle(),
Expand Down