Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FFT doc #535

Merged
merged 19 commits into from
Jul 20, 2024
195 changes: 184 additions & 11 deletions include/ddc/kernels/fft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,40 @@ static_assert(alignof(hipfftDoubleComplex) <= alignof(Kokkos::complex<double>));

namespace ddc {
// TODO : maybe transfert this somewhere else because Fourier space is not specific to FFT
blegouix marked this conversation as resolved.
Show resolved Hide resolved
/**
* @brief A templated tag representing a continuous dimension in the Fourier space associated to a continuous spatial dimension.
*
* @tparam The tag representing the spatial dimensions.
blegouix marked this conversation as resolved.
Show resolved Hide resolved
*/
template <typename Dim>
struct Fourier;

// named arguments for FFT (and their default values)
enum class FFT_Direction { FORWARD, BACKWARD };
enum class FFT_Normalization { OFF, FORWARD, BACKWARD, ORTHO, FULL };
/**
* @brief A named argument to choose the direction of the FFT.
*
* @see kwArgs_core, kwArgs_fft
*/
enum class FFT_Direction {
FORWARD, ///< Forward, corresponds to direct FFT up to normalization
BACKWARD ///< Backward, corresponds to inverse FFT up to normalization
};

/**
* @brief A named argument to choose the type of normalization of the FFT.
*
* @see kwArgs_core, kwArgs_fft
*/
enum class FFT_Normalization {
OFF, ///< No normalization
FORWARD, ///< Multiply by 1/N for forward FFT, no normalization for backward FFT
BACKWARD, ///< No normalization for forward FFT, multiply by 1/N for backward FFT
ORTHO, ///< Multiply by 1/sqrt(N)
FULL /**<
* Multiply by (b-a)/N/sqrt(2*pi) for forward FFT and sqrt(2*pi)/(b-a) for forward
blegouix marked this conversation as resolved.
Show resolved Hide resolved
* FFT. It is aligned with the usual definition of the (continuous) Fourier transform
* 1/sqrt(2*pi)*int f(x)*e^-ikx*dx, and thus preserves the gaussian function exp(-x^2/2) numerically.
*/
blegouix marked this conversation as resolved.
Show resolved Hide resolved
};
} // namespace ddc

namespace ddc::detail::fft {
Expand Down Expand Up @@ -115,7 +143,12 @@ KOKKOS_FUNCTION constexpr T LastSelector(const T a, const T b)
return LastSelector<T, Dim, Second, Tail...>(a, b);
}

// transform_type : trait to determine the type of transformation (R2C, C2R, C2C...) <- no information about base type (float or double)
// transform_type :
blegouix marked this conversation as resolved.
Show resolved Hide resolved
/**
* @brief A trait to identify the type of transformation (R2C, C2R, C2C...).
*
* It does not contain the information about the base type (float or double).
*/
enum class TransformType { R2R, R2C, C2R, C2C };

template <typename T1, typename T2>
Expand All @@ -142,6 +175,14 @@ struct transform_type<Kokkos::complex<T1>, Kokkos::complex<T2>>
static constexpr TransformType value = TransformType::C2C;
};

/**
* @brief A trait to get the TransformType for the input and output types.
*
* Internally check if T1 and T2 are Kokkos::complex<something> or not.
*
* @tparam T1 The input type.
* @tparam T2 The output type.
*/
template <typename T1, typename T2>
constexpr TransformType transform_type_v = transform_type<T1, T2>::value;

Expand Down Expand Up @@ -318,6 +359,11 @@ hipfftResult _hipfftExec([[maybe_unused]] LastArg lastArg, Args... args)
}
#endif

/*
* @brief A structure embedding the configuration of the core FFT function: direction and type of normalization.
*
* @see FFT_core
*/
struct kwArgs_core
{
ddc::FFT_Direction
Expand All @@ -326,15 +372,42 @@ struct kwArgs_core
};

// N,a,b from x_mesh
blegouix marked this conversation as resolved.
Show resolved Hide resolved
/**
* @brief Get the mesh size along a given dimension.
*
* @tparam DDim The dimension along which the mesh size is returned.
* @param x_mesh The mesh.
*
* @return The mesh size along the required dimension.
*/
template <typename DDim, typename... DDimX>
int N(ddc::DiscreteDomain<DDimX...> x_mesh)
{
static_assert(
(is_uniform_point_sampling_v<DDimX> && ...),
"DDimX dimensions should derive from UniformPointSampling");
return ddc::get<DDim>(x_mesh.extents());
return static_cast<int>(x_mesh.template extent<DDim>());
}

/**
* @brief Get the lower boundary coordinate along a given dimension.
*
* The lower boundary of the spatial domain (which appears in Nyquist-Shannon theorem) is not
blegouix marked this conversation as resolved.
Show resolved Hide resolved
* xmin=ddc::coordinate(x_mesh.front()). Indeed, this coordinate identifies the lower cell, but
* the lower boundary is the left side of this lowest cell, which is a = xmin - cell_size/2, with
blegouix marked this conversation as resolved.
Show resolved Hide resolved
* cell_size = (b-a)/N. It leads to a = xmin-(b-a)/2N. The same derivation for the
* upper boundary coordinate gives b = xmax+(b-a)/2N. Inverting this linear system leads to:
*
* a = ((2N-1)*xmin-xmax)/2/(N-1)
* b = ((2N-1)*xmax-xmin)/2/(N-1)
*
* The current function implements the first equation.
*
* @tparam DDim The dimension along which the lower cell coordinate of the Fourier mesh is returned.
* @param x_mesh The spatial mesh.
*
* @return The mesh size along the required dimension.
blegouix marked this conversation as resolved.
Show resolved Hide resolved
*/
template <typename DDim, typename... DDimX>
double a(ddc::DiscreteDomain<DDimX...> x_mesh)
{
Expand All @@ -346,6 +419,25 @@ double a(ddc::DiscreteDomain<DDimX...> x_mesh)
/ 2 / (N<DDim>(x_mesh) - 1);
}

/**
* @brief Get the upper boundary coordinate along a given dimension.
*
* The upper boundary of the spatial domain (which appears in Nyquist-Shannon theorem) is not
* xmax=ddc::coordinate(x_mesh.back()). Indeed, this coordinate identifies the upper cell, but
* the upper boundary is the right side of this upper cell, which is b = xmax + cell_size/2, with
* cell_size = (b-a)/N. It leads to b = xmax+(b-a)/2N. The same derivation for the
* lower boundary coordinate gives a = xmin-(b-a)/2N. Inverting this linear system leads to:
*
* a = ((2N-1)*xmin-xmax)/2/(N-1)
* b = ((2N-1)*xmax-xmin)/2/(N-1)
*
* The current function implements the second equation.
*
* @tparam DDim The dimension along which the upper cell coordinate of the Fourier mesh is returned.
* @param x_mesh The spatial mesh.
*
* @return The mesh size along the required dimension.
blegouix marked this conversation as resolved.
Show resolved Hide resolved
*/
template <typename DDim, typename... DDimX>
double b(ddc::DiscreteDomain<DDimX...> x_mesh)
{
Expand All @@ -357,7 +449,7 @@ double b(ddc::DiscreteDomain<DDimX...> x_mesh)
/ 2 / (N<DDim>(x_mesh) - 1);
}

// core
/// @brief Core internal function to perform the FFT.
template <typename Tin, typename Tout, typename ExecSpace, typename MemorySpace, typename... DDimX>
void core(
ExecSpace const& execSpace,
Expand Down Expand Up @@ -581,6 +673,26 @@ void core(

namespace ddc {

/**
* @brief Initialize a discrete Fourier space.
blegouix marked this conversation as resolved.
Show resolved Hide resolved
*
* Initialize the (1D) discrete space representing the Fourier discrete dimension associated
* to the (1D) spatial mesh passed as argument. It is a N-periodic PeriodicSampling defined between
* ka=0 and kb=2*N/(b-a)*pi.
blegouix marked this conversation as resolved.
Show resolved Hide resolved
*
* This value for kb comes from the Nyquist-Shannon theorem: the period of the spectral domain
* is kb-ka = 2*pi/cell_size = 2*pi*N/(b-a). The PeriodicSampling then contains cells between coordinates
* k=0 and k=2*pi*(N-1)/(b-a), because the cell at coordinate k=2*pi*N/(b-a) is a periodic point (f(ka)=f(kb)).
blegouix marked this conversation as resolved.
Show resolved Hide resolved
*
* @tparam DDimFx A PeriodicSampling representing the Fourier discrete dimension.
* @tparam DDimX The type of the spatial discrete dimension.
*
* @param x_mesh The DiscreteDomain representing the (1D) spatial mesh.
*
* @return The initialized Impl representing the discrete Fourier space.
*
* @see PeriodicSampling
*/
template <typename DDimFx, typename DDimX>
typename DDimFx::template Impl<DDimFx, Kokkos::HostSpace> init_fourier_space(
ddc::DiscreteDomain<DDimX> x_mesh)
Expand All @@ -590,7 +702,7 @@ typename DDimFx::template Impl<DDimFx, Kokkos::HostSpace> init_fourier_space(
"DDimX dimensions should derive from UniformPointSampling");
static_assert(
is_periodic_sampling_v<DDimFx>,
"DDimFx dimensions should derive from PeriodicPointSampling");
"DDimFx dimensions should derive from PeriodicSampling");
auto [impl, ddom] = DDimFx::template init<DDimFx>(
ddc::Coordinate<typename DDimFx::continuous_dimension_type>(0),
ddc::Coordinate<typename DDimFx::continuous_dimension_type>(
Expand All @@ -602,7 +714,22 @@ typename DDimFx::template Impl<DDimFx, Kokkos::HostSpace> init_fourier_space(
return std::move(impl);
}

// FourierMesh, first element corresponds to mode 0
/**
* @brief Get the Fourier mesh.
*
* Compute the Fourier (or spectral) mesh on which the Discrete Fourier Transform of a
* spatial discrete function is defined.
*
* The uid identifies the mode (ie. ddc::DiscreteElement<DDimFx>(0) corresponds to mode 0).
blegouix marked this conversation as resolved.
Show resolved Hide resolved
*
* @param x_mesh The DiscreteDomain representing the spatial mesh.
* @param C2C A flag indicating if a complex-to-complex DFT is going to be performed. Indeed,
* in this case the spatial and spectral meshes have same number of points, whereas for real-to-complex
* or complex-to-real DFT, each complex value of the Fourier-transformed function contains twice more
* information, and thus only N/2+1 points are needed.
*
* @return The domain representing the Fourier mesh.
*/
template <typename... DDimFx, typename... DDimX>
ddc::DiscreteDomain<DDimFx...> FourierMesh(ddc::DiscreteDomain<DDimX...> x_mesh, bool C2C)
{
Expand All @@ -621,12 +748,38 @@ ddc::DiscreteDomain<DDimFx...> FourierMesh(ddc::DiscreteDomain<DDimX...> x_mesh,
ddc::detail::fft::N<DDimX>(x_mesh)))))...);
}

/**
* @brief A structure embedding the configuration of the exposed FFT function with the type of normalization.
*
* @see fft, ifft
*/
struct kwArgs_fft
{
ddc::FFT_Normalization normalization;
ddc::FFT_Normalization
normalization; ///< Enum member to identify the type of normalization performed
};

// FFT
/**
* @brief Perform a direct Fast Fourier Transform.
*
* Compute the discrete Fourier transform of a spatial function using the specialized implementation for the Kokkos::ExecutionSpace
* of the FFT algorithm.
*
* @tparam Tin The type of the input elements (float, Kokkos::complex<float>, double or Kokkos::complex<double>).
* @tparam Tout The type of the output elements (Kokkos::complex<float> or Kokkos::complex<double>).
* @tparam DDimFx... The parameter pack of the Fourier discrete dimensions.
* @tparam DDimX... The parameter pack of the spatial discrete dimensions.
* @tparam ExecSpace The type of the Kokkos::ExecutionSpace on which the FFT is performed. It determines which specialized
* backend is used (ie. fftw, cuFFT...).
* @tparam MemorySpace The type of the Kokkos::MemorySpace on which are stored the input and output discrete functions.
* @tparam layout_in The layout of the Chunkspan representing the input discrete function.
* @tparam layout_out The layout of the Chunkspan representing the output discrete function.
*
* @param execSpace The Kokkos::ExecutionSpace on which the FFT is performed.
* @param out The output discrete function, represented as a ChunkSpan storing values on a spectral mesh.
* @param in The input discrete function, represented as a ChunkSpan storing values on a spatial mesh.
* @param kwargs The kwArgs_fft configuring the FFT.
*/
template <
typename Tin,
typename Tout,
Expand Down Expand Up @@ -663,7 +816,27 @@ void fft(
{ddc::FFT_Direction::FORWARD, kwargs.normalization});
}

// iFFT (deduced from the fact that "in" is identified as a function on the Fourier space)
/**
* @brief Perform an inverse Fast Fourier Transform.
*
* Compute the inverse discrete Fourier transform of a spectral function using the specialized implementation for the Kokkos::ExecutionSpace
* of the iFFT algorithm.
*
* @tparam Tin The type of the input elements (Kokkos::complex<float> or Kokkos::complex<double>).
* @tparam Tout The type of the output elements (float, Kokkos::complex<float>, double or Kokkos::complex<double>).
* @tparam DDimX... The parameter pack of the spatial discrete dimensions.
* @tparam DDimFx... The parameter pack of the Fourier discrete dimensions.
* @tparam ExecSpace The type of the Kokkos::ExecutionSpace on which the iFFT is performed. It determines which specialized
* backend is used (ie. fftw, cuFFT...).
* @tparam MemorySpace The type of the Kokkos::MemorySpace on which are stored the input and output discrete functions.
* @tparam layout_in The layout of the Chunkspan representing the input discrete function.
* @tparam layout_out The layout of the Chunkspan representing the output discrete function.
*
* @param execSpace The Kokkos::ExecutionSpace on which the iFFT is performed.
* @param out The output discrete function, represented as a ChunkSpan storing values on a spatial mesh.
* @param in The input discrete function, represented as a ChunkSpan storing values on a spectral mesh.
* @param kwargs The kwArgs_fft configuring the iFFT.
*/
template <
typename Tin,
typename Tout,
Expand Down
Loading