Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AnyCTO with arbitrary number of functions #4135

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions Src/Base/AMReX_CTOParallelForImpl.H
Original file line number Diff line number Diff line change
Expand Up @@ -44,24 +44,33 @@ namespace detail
}
};

template <class L, class F, typename... As>
template <class L, typename... As, class... Fs>
bool
AnyCTO_helper2 (const L& l, const F& f, TypeList<As...>,
std::array<int,sizeof...(As)> const& runtime_options)
AnyCTO_helper2 (const L& l, TypeList<As...>,
std::array<int,sizeof...(As)> const& runtime_options, const Fs&...cto_functs)
{
if (runtime_options == std::array<int,sizeof...(As)>{As::value...}) {
l(CTOWrapper<F, As::value...>{f});
if constexpr (sizeof...(cto_functs) != 0) {
// Apply the CTOWrapper to each function that was given in cto_functs
// and call the CPU function l with all of them
l(CTOWrapper<Fs, As::value...>{cto_functs}...);
} else {
// No functions in cto_functs, so call l directly with the compile-time option types (As{}...) as its arguments
l(As{}...);
}
return true;
} else {
return false;
}
}

template <class L, class F, typename... PPs, typename RO>
template <class L, typename... PPs, typename RO, class...Fs>
void
AnyCTO_helper1 (const L& l, const F& f, TypeList<PPs...>, RO const& runtime_options)
AnyCTO_helper1 (const L& l, TypeList<PPs...>,
RO const& runtime_options, const Fs&...cto_functs)
{
bool found_option = (false || ... || AnyCTO_helper2(l, f, PPs{}, runtime_options));
bool found_option = (false || ... ||
AnyCTO_helper2(l, PPs{}, runtime_options, cto_functs...));
amrex::ignore_unused(found_option);
AMREX_ASSERT(found_option);
}
Expand Down Expand Up @@ -168,17 +177,18 @@ namespace detail
* \param list_of_compile_time_options list of all possible values of the parameters.
* \param runtime_options the run time parameters.
* \param l a callable object containing a CPU function that launches the provided GPU kernel.
* \param f a callable object containing the GPU kernel with optimizations.
* \param cto_functs zero or more callable objects, each containing a GPU kernel with optimizations.
*/
template <class L, class F, typename... CTOs>
template <class L, class... Fs, typename... CTOs>
void AnyCTO ([[maybe_unused]] TypeList<CTOs...> list_of_compile_time_options,
std::array<int,sizeof...(CTOs)> const& runtime_options,
L&& l, F&& f)
L&& l, Fs&&...cto_functs)
{
#if (__cplusplus >= 201703L)
detail::AnyCTO_helper1(std::forward<L>(l), std::forward<F>(f),
detail::AnyCTO_helper1(std::forward<L>(l),
CartesianProduct(typename CTOs::list_type{}...),
runtime_options);
runtime_options,
std::forward<Fs>(cto_functs)...);
#else
amrex::ignore_unused(runtime_options, l, f);
static_assert(std::is_integral<F>::value, "This requires C++17");
Expand Down
Loading