Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CTOParallelFor with BoxND / add AnyCTO #4109

Merged
merged 7 commits into from
Sep 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 174 additions & 87 deletions Src/Base/AMReX_CTOParallelForImpl.H
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#include <AMReX_BLassert.H>
#include <AMReX_Box.H>
#include <AMReX_Tuple.H>
#include <AMReX_TypeList.H>

#include <array>
#include <type_traits>
Expand All @@ -18,125 +18,212 @@ namespace amrex {

template <int... ctr>
struct CompileTimeOptions {
// TypeList is defined in AMReX_Tuple.H
// TypeList is defined in AMReX_TypeList.H
using list_type = TypeList<std::integral_constant<int, ctr>...>;
};

#if (__cplusplus >= 201703L)

namespace detail
{
template <int MT, typename T, class F, typename... As>
std::enable_if_t<std::is_integral_v<T> || std::is_same_v<T,Box>, bool>
ParallelFor_helper2 (T const& N, F const& f, TypeList<As...>,
std::array<int,sizeof...(As)> const& runtime_options)
{
if (runtime_options == std::array<int,sizeof...(As)>{As::value...}) {
if constexpr (std::is_integral_v<T>) {
ParallelFor<MT>(N, [f] AMREX_GPU_DEVICE (T i) noexcept
{
f(i, As{}...);
});
} else {
ParallelFor<MT>(N, [f] AMREX_GPU_DEVICE (int i, int j, int k) noexcept
{
f(i, j, k, As{}...);
});
}
return true;
} else {
return false;
template<class F, int... ctr>
struct CTOWrapper {
F f;

template<class... Args>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto operator() (Args... args) const noexcept
-> decltype(f(args..., std::integral_constant<int, ctr>{}...)) {
return f(args..., std::integral_constant<int, ctr>{}...);
}
}

template <int MT, typename T, class F, typename... As>
std::enable_if_t<std::is_integral_v<T>, bool>
ParallelFor_helper2 (Box const& box, T ncomp, F const& f, TypeList<As...>,
std::array<int,sizeof...(As)> const& runtime_options)
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
static constexpr
std::array<int, sizeof...(ctr)> GetOptions () noexcept {
return {ctr...};
}
};

template <class L, class F, typename... As>
bool
AnyCTO_helper2 (const L& l, const F& f, TypeList<As...>,
std::array<int,sizeof...(As)> const& runtime_options)
{
if (runtime_options == std::array<int,sizeof...(As)>{As::value...}) {
ParallelFor<MT>(box, ncomp, [f] AMREX_GPU_DEVICE (int i, int j, int k, T n) noexcept
{
f(i, j, k, n, As{}...);
});
l(CTOWrapper<F, As::value...>{f});
return true;
} else {
return false;
}
}

template <int MT, typename T, class F, typename... PPs, typename RO>
std::enable_if_t<std::is_integral_v<T> || std::is_same_v<T,Box>>
ParallelFor_helper1 (T const& N, F const& f, TypeList<PPs...>,
RO const& runtime_options)
{
bool found_option = (false || ... ||
ParallelFor_helper2<MT>(N, f,
PPs{}, runtime_options));
amrex::ignore_unused(found_option);
AMREX_ASSERT(found_option);
}

template <int MT, typename T, class F, typename... PPs, typename RO>
std::enable_if_t<std::is_integral_v<T>>
ParallelFor_helper1 (Box const& box, T ncomp, F const& f, TypeList<PPs...>,
RO const& runtime_options)
template <class L, class F, typename... PPs, typename RO>
void
AnyCTO_helper1 (const L& l, const F& f, TypeList<PPs...>, RO const& runtime_options)
{
bool found_option = (false || ... ||
ParallelFor_helper2<MT>(box, ncomp, f,
PPs{}, runtime_options));
bool found_option = (false || ... || AnyCTO_helper2(l, f, PPs{}, runtime_options));
amrex::ignore_unused(found_option);
AMREX_ASSERT(found_option);
}
}

#endif

template <int MT, typename T, class F, typename... CTOs>
std::enable_if_t<std::is_integral_v<T>>
ParallelFor (TypeList<CTOs...> /*list_of_compile_time_options*/,
/**
* \brief Compile time optimization of kernels with run time options.
*
* This is a generalized version of ParallelFor with CTOs that can support any function that
* takes in one lambda to launch a GPU kernel such as ParallelFor, ParallelForRNG, launch, etc.
* It uses fold expression to generate kernel launches for all combinations
* of the run time options. The kernel function can use constexpr if to
* discard unused code blocks for better run time performance. In the
* example below, the code will be expanded into 4*2=8 normal ParallelForRNGs
* for all combinations of the run time parameters.
\verbatim
int A_runtime_option = ...;
int B_runtime_option = ...;
enum A_options : int { A0, A1, A2, A3 };
enum B_options : int { B0, B1 };
AnyCTO(TypeList<CompileTimeOptions<A0,A1,A2,A3>,
CompileTimeOptions<B0,B1>>{},
{A_runtime_option, B_runtime_option},
[&](auto cto_func){
ParallelForRNG(N, cto_func);
},
[=] AMREX_GPU_DEVICE (int i, const RandomEngine& engine,
auto A_control, auto B_control)
{
...
if constexpr (A_control.value == A0) {
...
} else if constexpr (A_control.value == A1) {
...
} else if constexpr (A_control.value == A2) {
...
} else {
...
}
if constexpr (A_control.value != A3 && B_control.value == B1) {
...
}
...
}
);

constexpr int nthreads_per_block = ...;
int nblocks = ...;
AnyCTO(TypeList<CompileTimeOptions<A0,A1,A2,A3>,
CompileTimeOptions<B0,B1>>{},
{A_runtime_option, B_runtime_option},
[&](auto cto_func){
launch<nthreads_per_block>(nblocks, Gpu::gpuStream(), cto_func);
},
[=] AMREX_GPU_DEVICE (auto A_control, auto B_control){
...
}
);
\endverbatim
* The static member function cto_func.GetOptions() can be used to obtain the runtime_options
* passed into AnyCTO, but at compile time. This enables some advanced use cases,
* such as changing the number of threads per block or the dimensionality of ParallelFor at runtime.
In the second example, the trailing return type -> decltype(void(intvect.size())) is necessary to
disambiguate between IntVectND<1> and int for the first argument of the kernel function.
\verbatim
int nthreads_per_block = ...;
AnyCTO(TypeList<CompileTimeOptions<128,256,512,1024>>{},
{nthreads_per_block},
[&](auto cto_func){
constexpr std::array<int, 1> ctos = cto_func.GetOptions();
constexpr int c_nthreads_per_block = ctos[0];
ParallelFor<c_nthreads_per_block>(N, cto_func);
},
[=] AMREX_GPU_DEVICE (int i, auto){
...
}
);

BoxND<6> box6D = ...;
int dims_needed = ...;
AnyCTO(TypeList<CompileTimeOptions<1,2,3,4,5,6>>{},
{dims_needed},
[&](auto cto_func){
constexpr std::array<int, 1> ctos = cto_func.GetOptions();
constexpr int c_dims_needed = ctos[0];
const auto box = BoxShrink<c_dims_needed>(box6D);
ParallelFor(box, cto_func);
},
[=] AMREX_GPU_DEVICE (auto intvect, auto) -> decltype(void(intvect.size())) {
...
}
);
\endverbatim

* Note that due to a limitation of CUDA's extended device lambda, the
* constexpr if block cannot be the one that captures a variable first.
* If nvcc complains about it, you will have to manually capture it outside
* constexpr if. Alternatively, the constexpr if can be replaced with a regular if.
* Compilers can still perform the same optimizations since the condition is known at compile time.
* The data type for the parameters is int.
*
* \param list_of_compile_time_options list of all possible values of the parameters.
* \param runtime_options the run time parameters.
* \param l a callable object containing a CPU function that launches the provided GPU kernel.
* \param f a callable object containing the GPU kernel with optimizations.
*/
template <class L, class F, typename... CTOs>
void AnyCTO ([[maybe_unused]] TypeList<CTOs...> list_of_compile_time_options,
std::array<int,sizeof...(CTOs)> const& runtime_options,
T N, F&& f)
L&& l, F&& f)
{
#if (__cplusplus >= 201703L)
detail::ParallelFor_helper1<MT>(N, std::forward<F>(f),
CartesianProduct(typename CTOs::list_type{}...),
runtime_options);
detail::AnyCTO_helper1(std::forward<L>(l), std::forward<F>(f),
CartesianProduct(typename CTOs::list_type{}...),
runtime_options);
#else
amrex::ignore_unused(N, f, runtime_options);
amrex::ignore_unused(runtime_options, l, f);
static_assert(std::is_integral<F>::value, "This requires C++17");
#endif
}

template <int MT, class F, typename... CTOs>
void ParallelFor (TypeList<CTOs...> /*list_of_compile_time_options*/,
template <int MT, typename T, class F, typename... CTOs>
std::enable_if_t<std::is_integral_v<T>>
ParallelFor (TypeList<CTOs...> ctos,
std::array<int,sizeof...(CTOs)> const& runtime_options,
T N, F&& f)
{
AnyCTO(ctos, runtime_options,
[&](auto cto_func){
ParallelFor<MT>(N, cto_func);
},
std::forward<F>(f)
);
}

template <int MT, class F, int dim, typename... CTOs>
void ParallelFor (TypeList<CTOs...> ctos,
std::array<int,sizeof...(CTOs)> const& runtime_options,
Box const& box, F&& f)
BoxND<dim> const& box, F&& f)
{
#if (__cplusplus >= 201703L)
detail::ParallelFor_helper1<MT>(box, std::forward<F>(f),
CartesianProduct(typename CTOs::list_type{}...),
runtime_options);
#else
amrex::ignore_unused(box, f, runtime_options);
static_assert(std::is_integral<F>::value, "This requires C++17");
#endif
AnyCTO(ctos, runtime_options,
[&](auto cto_func){
ParallelFor<MT>(box, cto_func);
},
std::forward<F>(f)
);
}

template <int MT, typename T, class F, typename... CTOs>
template <int MT, typename T, class F, int dim, typename... CTOs>
std::enable_if_t<std::is_integral_v<T>>
ParallelFor (TypeList<CTOs...> /*list_of_compile_time_options*/,
ParallelFor (TypeList<CTOs...> ctos,
std::array<int,sizeof...(CTOs)> const& runtime_options,
Box const& box, T ncomp, F&& f)
BoxND<dim> const& box, T ncomp, F&& f)
{
#if (__cplusplus >= 201703L)
detail::ParallelFor_helper1<MT>(box, ncomp, std::forward<F>(f),
CartesianProduct(typename CTOs::list_type{}...),
runtime_options);
#else
amrex::ignore_unused(box, ncomp, f, runtime_options);
static_assert(std::is_integral<F>::value, "This requires C++17");
#endif
AnyCTO(ctos, runtime_options,
[&](auto cto_func){
ParallelFor<MT>(box, ncomp, cto_func);
},
std::forward<F>(f)
);
}

/**
Expand Down Expand Up @@ -164,7 +251,7 @@ ParallelFor (TypeList<CTOs...> /*list_of_compile_time_options*/,
...
} else if constexpr (A_control.value == A2) {
...
else {
} else {
...
}
if constexpr (A_control.value != A3 && B_control.value == B1) {
Expand Down Expand Up @@ -218,7 +305,7 @@ ParallelFor (TypeList<CTOs...> ctos,
...
} else if constexpr (A_control.value == A2) {
...
else {
} else {
...
}
if constexpr (A_control.value != A3 && B_control.value == B1) {
Expand All @@ -237,10 +324,10 @@ ParallelFor (TypeList<CTOs...> ctos,
* \param box a Box specifying the 3D for loop's range.
* \param f a callable object taking three integers and working on the given cell.
*/
template <class F, typename... CTOs>
template <class F, int dim, typename... CTOs>
void ParallelFor (TypeList<CTOs...> ctos,
std::array<int,sizeof...(CTOs)> const& option,
Box const& box, F&& f)
BoxND<dim> const& box, F&& f)
{
ParallelFor<AMREX_GPU_MAX_THREADS>(ctos, option, box, std::forward<F>(f));
}
Expand Down Expand Up @@ -271,7 +358,7 @@ void ParallelFor (TypeList<CTOs...> ctos,
...
} else if constexpr (A_control.value == A2) {
...
else {
} else {
...
}
if constexpr (A_control.value != A3 && B_control.value == B1) {
Expand All @@ -291,11 +378,11 @@ void ParallelFor (TypeList<CTOs...> ctos,
* \param ncomp an integer specifying the range for iteration over components.
* \param f a callable object taking three integers and working on the given cell.
*/
template <typename T, class F, typename... CTOs>
template <typename T, class F, int dim, typename... CTOs>
std::enable_if_t<std::is_integral_v<T>>
ParallelFor (TypeList<CTOs...> ctos,
std::array<int,sizeof...(CTOs)> const& option,
Box const& box, T ncomp, F&& f)
BoxND<dim> const& box, T ncomp, F&& f)
{
ParallelFor<AMREX_GPU_MAX_THREADS>(ctos, option, box, ncomp, std::forward<F>(f));
}
Expand Down