Skip to content

Commit

Permalink
Refactor call_f in ParticleTransformation and WriteBinaryParticleData…
Browse files Browse the repository at this point in the history
… to use constexpr if (#3448)

This results in much easier-to-interpret error messages, which were useful in
finding #3449 and #3450.

Co-authored-by: Weiqun Zhang <WeiqunZhang@lbl.gov>
  • Loading branch information
atmyers and WeiqunZhang authored Jul 26, 2023
1 parent a3b69fa commit 8aa7dc6
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 45 deletions.
37 changes: 26 additions & 11 deletions Src/Particle/AMReX_ParticleTransformation.H
Original file line number Diff line number Diff line change
Expand Up @@ -379,10 +379,10 @@ int filterParticles (DstTile& dst, const SrcTile& src, Pred&& p) noexcept
* \param n the number of particles to apply the operation to
*
*/
template <typename DstTile, typename SrcTile, typename Pred, typename Index, typename N>
auto filterParticles (DstTile& dst, const SrcTile& src, Pred&& p,
Index src_start, Index dst_start, N n) noexcept
-> decltype(Index(particle_detail::call_f(p, typename SrcTile::ConstParticleTileDataType(), int{}, RandomEngine{})))
template <typename DstTile, typename SrcTile, typename Pred, typename Index, typename N,
std::enable_if_t<!std::is_pointer_v<std::decay_t<Pred>>,Index> nvccfoo = 0>
Index filterParticles (DstTile& dst, const SrcTile& src, Pred&& p,
Index src_start, Index dst_start, N n) noexcept
{
Gpu::DeviceVector<Index> mask(n);

Expand All @@ -392,7 +392,12 @@ auto filterParticles (DstTile& dst, const SrcTile& src, Pred&& p,
amrex::ParallelForRNG(n,
[=] AMREX_GPU_DEVICE (int i, amrex::RandomEngine const& engine) noexcept
{
p_mask[i] = particle_detail::call_f(p, src_data, src_start+i, engine);
amrex::ignore_unused(p, p_mask, src_data, src_start, engine);
if constexpr (IsCallable<Pred,decltype(src_data),Index,RandomEngine>::value) {
p_mask[i] = p(src_data, src_start+i, engine);
} else {
p_mask[i] = p(src_data, src_start+i);
}
});
return filterParticles(dst, src, mask.dataPtr(), src_start, dst_start, n);
}
Expand Down Expand Up @@ -558,7 +563,12 @@ int filterAndTransformParticles (DstTile1& dst1, DstTile2& dst2, const SrcTile&
amrex::ParallelForRNG(np,
[=] AMREX_GPU_DEVICE (int i, amrex::RandomEngine const& engine) noexcept
{
p_mask[i] = particle_detail::call_f(p, src_data, i, engine);
amrex::ignore_unused(p, p_mask, src_data, engine);
if constexpr (IsCallable<Pred,decltype(src_data),int,RandomEngine>::value) {
p_mask[i] = p(src_data, i, engine);
} else {
p_mask[i] = p(src_data, i);
}
});
return filterAndTransformParticles(dst1, dst2, src, mask.dataPtr(), std::forward<F>(f));
}
Expand All @@ -582,10 +592,10 @@ int filterAndTransformParticles (DstTile1& dst1, DstTile2& dst2, const SrcTile&
*
*/

template <typename DstTile, typename SrcTile, typename Pred, typename F, typename Index>
auto filterAndTransformParticles (DstTile& dst, const SrcTile& src, Pred&& p, F&& f,
Index src_start, Index dst_start) noexcept
-> decltype(Index(particle_detail::call_f(p, typename SrcTile::ConstParticleTileDataType(), int{}, RandomEngine{})))
template <typename DstTile, typename SrcTile, typename Pred, typename F, typename Index,
std::enable_if_t<!std::is_pointer_v<std::decay_t<Pred>>,Index> nvccfoo = 0>
Index filterAndTransformParticles (DstTile& dst, const SrcTile& src, Pred&& p, F&& f,
Index src_start, Index dst_start) noexcept
{
auto np = src.numParticles();
Gpu::DeviceVector<Index> mask(np);
Expand All @@ -596,7 +606,12 @@ auto filterAndTransformParticles (DstTile& dst, const SrcTile& src, Pred&& p, F&
amrex::ParallelForRNG(np,
[=] AMREX_GPU_DEVICE (int i, amrex::RandomEngine const& engine) noexcept
{
p_mask[i] = particle_detail::call_f(p, src_data, src_start+i, engine);
amrex::ignore_unused(p, p_mask, src_data, src_start, engine);
if constexpr (IsCallable<Pred,decltype(src_data),Index,RandomEngine>::value) {
p_mask[i] = p(src_data, src_start+i, engine);
} else {
p_mask[i] = p(src_data, src_start+i);
}
});
return filterAndTransformParticles(dst, src, mask.dataPtr(), std::forward<F>(f), src_start, dst_start);
}
Expand Down
32 changes: 0 additions & 32 deletions Src/Particle/AMReX_ParticleUtil.H
Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,6 @@ namespace amrex

namespace particle_detail {

// Invoke f on a single particle-like object together with the random engine,
// and return whatever f returns.
//
// This overload is only a candidate when the expression in the trailing
// return type, f(P{}, RandomEngine{}), is well-formed; otherwise it is
// dropped from overload resolution. Because the check uses P{} and
// RandomEngine{}, both types must be constructible from an empty brace list
// for this overload to be selected.
//
// NOTE(review): a sibling overload with the same parameter list calls f(p)
// and ignores the engine. If f accepts both forms, both overloads look
// viable and the call is presumably ambiguous -- confirm.
template <typename F, typename P>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f, P const& p, amrex::RandomEngine const& engine) noexcept
-> decltype(f(P{},RandomEngine{}))
{
return f(p,engine);
}

// Invoke f on a single particle-like object and return whatever f returns.
// The random engine argument is accepted (unnamed) but never used, so that
// callers can pass an engine regardless of which form f supports.
//
// This overload is only a candidate when the expression in the trailing
// return type, f(P{}), is well-formed; otherwise it is dropped from overload
// resolution. Because the check uses P{}, P must be constructible from an
// empty brace list for this overload to be selected.
template <typename F, typename P>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f, P const& p, amrex::RandomEngine const&) noexcept
-> decltype(f(P{}))
{
return f(p);
}

// Invoke f with a source-data object, an index i, and the random engine,
// and return whatever f returns.
//
// This overload is only a candidate when the expression in the trailing
// return type, f(SrcData{}, N{}, RandomEngine{}), is well-formed; otherwise
// it is dropped from overload resolution. Because the check uses SrcData{},
// N{} and RandomEngine{}, all three types must be constructible from an
// empty brace list for this overload to be selected.
//
// NOTE(review): a sibling overload with the same parameter list calls
// f(src, i) and ignores the engine. If f accepts both forms, both overloads
// look viable and the call is presumably ambiguous -- confirm.
template <typename F, typename SrcData, typename N>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f, SrcData const& src, N i, amrex::RandomEngine const& engine) noexcept
-> decltype(f(SrcData{},N{},RandomEngine{}))
{
return f(src,i,engine);
}

// Invoke f with a source-data object and an index i, and return whatever f
// returns. The random engine argument is accepted (unnamed) but never used,
// so that callers can pass an engine regardless of which form f supports.
//
// This overload is only a candidate when the expression in the trailing
// return type, f(SrcData{}, N{}), is well-formed; otherwise it is dropped
// from overload resolution. Because the check uses SrcData{} and N{}, both
// types must be constructible from an empty brace list for this overload to
// be selected.
template <typename F, typename SrcData, typename N>
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
auto call_f (F const& f, SrcData const& src, N i, amrex::RandomEngine const&) noexcept
-> decltype(f(SrcData{},N{}))
{
return f(src,i);
}

// The next several functions are used by ParticleReduce

// Lambda takes a Particle
Expand Down
13 changes: 11 additions & 2 deletions Src/Particle/AMReX_WriteBinaryParticleData.H
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@ fillFlags (Container<int, Allocator>& pflags, const PTile& ptile, F&& f)
[=] AMREX_GPU_DEVICE (int k, amrex::RandomEngine const& engine) noexcept
{
const auto p = ptd.getSuperParticle(k);
flag_ptr[k] = particle_detail::call_f(f,p,engine);
amrex::ignore_unused(flag_ptr, f, engine);
if constexpr (IsCallable<F,decltype(p),RandomEngine>::value) {
flag_ptr[k] = f(p,engine);
} else {
flag_ptr[k] = f(p);
}
});
}

Expand All @@ -59,7 +64,11 @@ fillFlags (Container<int, Allocator>& pflags, const PTile& ptile, F&& f)
auto flag_ptr = pflags.data();
for (int k = 0; k < np; ++k) {
const auto p = ptd.getSuperParticle(k);
flag_ptr[k] = particle_detail::call_f(f,p,RandomEngine{});
if constexpr (IsCallable<F,decltype(p),RandomEngine>::value) {
flag_ptr[k] = f(p,RandomEngine{});
} else {
flag_ptr[k] = f(p);
}
}
}

Expand Down

0 comments on commit 8aa7dc6

Please sign in to comment.