Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove switches and catch bad values #631

Merged
merged 1 commit into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 29 additions & 31 deletions include/ddc/kernels/fft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <array>
#include <cstddef>
#include <stdexcept>
#include <type_traits>
#include <utility>

Expand All @@ -22,7 +23,6 @@
#if cufft_AVAIL
#include <functional>
#include <memory>
#include <stdexcept>

#include <cuda_runtime_api.h>
#include <cufft.h>
Expand All @@ -31,7 +31,6 @@
#if hipfft_AVAIL
#include <functional>
#include <memory>
#include <stdexcept>

#include <hip/hip_runtime_api.h>
#include <hipfft/hipfft.h>
Expand Down Expand Up @@ -564,36 +563,35 @@ void impl(

if (kwargs.normalization != ddc::FFT_Normalization::OFF) {
real_type_t<Tout> norm_coef = 1;
switch (kwargs.normalization) {
case ddc::FFT_Normalization::OFF:
break;
EmilyBourne marked this conversation as resolved.
Show resolved Hide resolved
case ddc::FFT_Normalization::FORWARD:
norm_coef = kwargs.direction == ddc::FFT_Direction::FORWARD
? 1. / (ddc::get<DDimX>(mesh.extents()) * ...)
: 1.;
break;
case ddc::FFT_Normalization::BACKWARD:
norm_coef = kwargs.direction == ddc::FFT_Direction::BACKWARD
? 1. / (ddc::get<DDimX>(mesh.extents()) * ...)
: 1.;
break;
case ddc::FFT_Normalization::ORTHO:
if (kwargs.normalization == ddc::FFT_Normalization::FORWARD) {
if (kwargs.direction == ddc::FFT_Direction::FORWARD) {
norm_coef = 1. / (ddc::get<DDimX>(mesh.extents()) * ...);
}
} else if (kwargs.normalization == ddc::FFT_Normalization::BACKWARD) {
if (kwargs.direction == ddc::FFT_Direction::BACKWARD) {
norm_coef = 1. / (ddc::get<DDimX>(mesh.extents()) * ...);
}
} else if (kwargs.normalization == ddc::FFT_Normalization::ORTHO) {
norm_coef = 1. / Kokkos::sqrt((ddc::get<DDimX>(mesh.extents()) * ...));
break;
case ddc::FFT_Normalization::FULL:
norm_coef = kwargs.direction == ddc::FFT_Direction::FORWARD
? (((coordinate(ddc::select<DDimX>(mesh).back())
- coordinate(ddc::select<DDimX>(mesh).front()))
/ (ddc::get<DDimX>(mesh.extents()) - 1)
/ Kokkos::sqrt(2 * Kokkos::numbers::pi))
* ...)
: ((Kokkos::sqrt(2 * Kokkos::numbers::pi)
/ (coordinate(ddc::select<DDimX>(mesh).back())
- coordinate(ddc::select<DDimX>(mesh).front()))
* (ddc::get<DDimX>(mesh.extents()) - 1)
/ ddc::get<DDimX>(mesh.extents()))
* ...);
break;
} else if (kwargs.normalization == ddc::FFT_Normalization::FULL) {
if (kwargs.direction == ddc::FFT_Direction::FORWARD) {
norm_coef
= (((coordinate(ddc::select<DDimX>(mesh).back())
- coordinate(ddc::select<DDimX>(mesh).front()))
/ (ddc::get<DDimX>(mesh.extents()) - 1)
/ Kokkos::sqrt(2 * Kokkos::numbers::pi))
* ...);
} else {
norm_coef
= ((Kokkos::sqrt(2 * Kokkos::numbers::pi)
/ (coordinate(ddc::select<DDimX>(mesh).back())
- coordinate(ddc::select<DDimX>(mesh).front()))
* (ddc::get<DDimX>(mesh.extents()) - 1)
/ ddc::get<DDimX>(mesh.extents()))
* ...);
}
} else {
throw std::runtime_error("ddc::FFT_Normalization not handled");
}

Kokkos::parallel_for(
Expand Down
15 changes: 9 additions & 6 deletions include/ddc/kernels/splines/spline_boundary_conditions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,19 @@ enum class BoundCond {
**/
static inline std::ostream& operator<<(std::ostream& out, ddc::BoundCond const bc)
{
switch (bc) {
case ddc::BoundCond::PERIODIC:
if (bc == ddc::BoundCond::PERIODIC) {
return out << "PERIODIC";
case ddc::BoundCond::HERMITE:
}

if (bc == ddc::BoundCond::HERMITE) {
return out << "HERMITE";
case ddc::BoundCond::GREVILLE:
}

if (bc == ddc::BoundCond::GREVILLE) {
return out << "GREVILLE";
default:
throw std::runtime_error("ddc::BoundCond not handled");
}

throw std::runtime_error("ddc::BoundCond not handled");
}

/**
Expand Down
85 changes: 28 additions & 57 deletions include/ddc/kernels/splines/spline_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,11 @@ class SplineBuilder
int lower_block_size;
int upper_block_size;
if constexpr (bsplines_type::is_uniform()) {
compute_block_sizes_uniform(lower_block_size, upper_block_size);
upper_block_size = compute_block_sizes_uniform(BcLower, s_nbc_xmin);
lower_block_size = compute_block_sizes_uniform(BcUpper, s_nbc_xmax);
} else {
compute_block_sizes_non_uniform(lower_block_size, upper_block_size);
upper_block_size = compute_block_sizes_non_uniform(BcLower, s_nbc_xmin);
lower_block_size = compute_block_sizes_non_uniform(BcUpper, s_nbc_xmax);
}
allocate_matrix(
lower_block_size,
Expand Down Expand Up @@ -452,9 +454,9 @@ class SplineBuilder
quadrature_coefficients() const;

private:
void compute_block_sizes_uniform(int& lower_block_size, int& upper_block_size) const;
static int compute_block_sizes_uniform(ddc::BoundCond bound_cond, int nbc);

void compute_block_sizes_non_uniform(int& lower_block_size, int& upper_block_size) const;
static int compute_block_sizes_non_uniform(ddc::BoundCond bound_cond, int nbc);

void allocate_matrix(
int lower_block_size,
Expand Down Expand Up @@ -518,42 +520,29 @@ template <
ddc::BoundCond BcUpper,
SplineSolver Solver,
class... IDimX>
void SplineBuilder<
int SplineBuilder<
ExecSpace,
MemorySpace,
BSplines,
InterpolationDDim,
BcLower,
BcUpper,
Solver,
IDimX...>::compute_block_sizes_uniform(int& lower_block_size, int& upper_block_size) const
IDimX...>::compute_block_sizes_uniform(ddc::BoundCond const bound_cond, int const nbc)
{
switch (BcLower) {
case ddc::BoundCond::PERIODIC:
upper_block_size = (bsplines_type::degree()) / 2;
break;
case ddc::BoundCond::HERMITE:
upper_block_size = s_nbc_xmin;
break;
case ddc::BoundCond::GREVILLE:
upper_block_size = bsplines_type::degree() - 1;
break;
default:
throw std::runtime_error("ddc::BoundCond not handled");
if (bound_cond == ddc::BoundCond::PERIODIC) {
return static_cast<int>(bsplines_type::degree()) / 2;
}
switch (BcUpper) {
case ddc::BoundCond::PERIODIC:
lower_block_size = (bsplines_type::degree()) / 2;
break;
case ddc::BoundCond::HERMITE:
lower_block_size = s_nbc_xmax;
break;
case ddc::BoundCond::GREVILLE:
lower_block_size = bsplines_type::degree() - 1;
break;
default:
throw std::runtime_error("ddc::BoundCond not handled");

if (bound_cond == ddc::BoundCond::HERMITE) {
return nbc;
}

if (bound_cond == ddc::BoundCond::GREVILLE) {
return static_cast<int>(bsplines_type::degree()) - 1;
}

throw std::runtime_error("ddc::BoundCond not handled");
}

template <
Expand All @@ -565,43 +554,25 @@ template <
ddc::BoundCond BcUpper,
SplineSolver Solver,
class... IDimX>
void SplineBuilder<
int SplineBuilder<
ExecSpace,
MemorySpace,
BSplines,
InterpolationDDim,
BcLower,
BcUpper,
Solver,
IDimX...>::compute_block_sizes_non_uniform(int& lower_block_size, int& upper_block_size)
const
IDimX...>::compute_block_sizes_non_uniform(ddc::BoundCond const bound_cond, int const nbc)
{
switch (BcLower) {
case ddc::BoundCond::PERIODIC:
upper_block_size = bsplines_type::degree() - 1;
break;
case ddc::BoundCond::HERMITE:
upper_block_size = s_nbc_xmin + 1;
break;
case ddc::BoundCond::GREVILLE:
upper_block_size = bsplines_type::degree() - 1;
break;
default:
throw std::runtime_error("ddc::BoundCond not handled");
if (bound_cond == ddc::BoundCond::PERIODIC || bound_cond == ddc::BoundCond::GREVILLE) {
return static_cast<int>(bsplines_type::degree()) - 1;
}
switch (BcUpper) {
case ddc::BoundCond::PERIODIC:
lower_block_size = bsplines_type::degree() - 1;
break;
case ddc::BoundCond::HERMITE:
lower_block_size = s_nbc_xmax + 1;
break;
case ddc::BoundCond::GREVILLE:
lower_block_size = bsplines_type::degree() - 1;
break;
default:
throw std::runtime_error("ddc::BoundCond not handled");

if (bound_cond == ddc::BoundCond::HERMITE) {
return nbc + 1;
}

throw std::runtime_error("ddc::BoundCond not handled");
}

template <
Expand Down
19 changes: 8 additions & 11 deletions tests/fft/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// SPDX-License-Identifier: MIT

#include <cstddef>
#include <stdexcept>
#include <type_traits>

#include <ddc/ddc.hpp>
Expand Down Expand Up @@ -197,27 +198,23 @@ static void test_fft_norm(ddc::FFT_Normalization const norm)

double Ff0_expected;
double FFf_expected;
switch (norm) {
case ddc::FFT_Normalization::OFF:
if (norm == ddc::FFT_Normalization::OFF) {
Ff0_expected = f_sum;
FFf_expected = f_sum;
break;
case ddc::FFT_Normalization::FORWARD:
} else if (norm == ddc::FFT_Normalization::FORWARD) {
Ff0_expected = 1;
FFf_expected = 1;
break;
case ddc::FFT_Normalization::BACKWARD:
} else if (norm == ddc::FFT_Normalization::BACKWARD) {
Ff0_expected = f_sum;
FFf_expected = 1;
break;
case ddc::FFT_Normalization::ORTHO:
} else if (norm == ddc::FFT_Normalization::ORTHO) {
Ff0_expected = Kokkos::sqrt(f_sum);
FFf_expected = 1;
break;
case ddc::FFT_Normalization::FULL:
} else if (norm == ddc::FFT_Normalization::FULL) {
Ff0_expected = 1 / Kokkos::sqrt(2 * Kokkos::numbers::pi);
FFf_expected = 1;
break;
} else {
throw std::runtime_error("ddc::FFT_Normalization not handled");
}

double const epsilon = 1e-6;
Expand Down