Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions Src/Base/AMReX_CTOParallelForImpl.H
Original file line number Diff line number Diff line change
Expand Up @@ -44,24 +44,33 @@ namespace detail
}
};

template <class L, class F, typename... As>
template <class L, typename... As, class... Fs>
bool
AnyCTO_helper2 (const L& l, const F& f, TypeList<As...>,
std::array<int,sizeof...(As)> const& runtime_options)
AnyCTO_helper2 (const L& l, TypeList<As...>,
std::array<int,sizeof...(As)> const& runtime_options, const Fs&...cto_functs)
{
if (runtime_options == std::array<int,sizeof...(As)>{As::value...}) {
l(CTOWrapper<F, As::value...>{f});
if constexpr (sizeof...(cto_functs) != 0) {
// Apply the CTOWrapper to each function that was given in cto_functs
// and call the CPU function l with all of them
l(CTOWrapper<Fs, As::value...>{cto_functs}...);
} else {
// No functions in cto_functs so we call l directly with the compile time arguments
l(As{}...);
}
return true;
} else {
return false;
}
}

template <class L, class F, typename... PPs, typename RO>
template <class L, typename... PPs, typename RO, class...Fs>
void
AnyCTO_helper1 (const L& l, const F& f, TypeList<PPs...>, RO const& runtime_options)
AnyCTO_helper1 (const L& l, TypeList<PPs...>,
RO const& runtime_options, const Fs&...cto_functs)
{
bool found_option = (false || ... || AnyCTO_helper2(l, f, PPs{}, runtime_options));
bool found_option = (false || ... ||
AnyCTO_helper2(l, PPs{}, runtime_options, cto_functs...));
amrex::ignore_unused(found_option);
AMREX_ASSERT(found_option);
}
Expand Down Expand Up @@ -168,17 +177,18 @@ namespace detail
* \param list_of_compile_time_options list of all possible values of the parameters.
* \param runtime_options the run time parameters.
* \param l a callable object containing a CPU function that launches the provided GPU kernel.
* \param f a callable object containing the GPU kernel with optimizations.
* \param cto_functs a callable object containing the GPU kernel with optimizations.
*/
template <class L, class F, typename... CTOs>
template <class L, class... Fs, typename... CTOs>
void AnyCTO ([[maybe_unused]] TypeList<CTOs...> list_of_compile_time_options,
std::array<int,sizeof...(CTOs)> const& runtime_options,
L&& l, F&& f)
L&& l, Fs&&...cto_functs)
{
#if (__cplusplus >= 201703L)
detail::AnyCTO_helper1(std::forward<L>(l), std::forward<F>(f),
detail::AnyCTO_helper1(std::forward<L>(l),
CartesianProduct(typename CTOs::list_type{}...),
runtime_options);
runtime_options,
std::forward<Fs>(cto_functs)...);
#else
amrex::ignore_unused(runtime_options, l, f);
static_assert(std::is_integral<F>::value, "This requires C++17");
Expand Down
Loading