Skip to content

Commit

Permalink
Workaround: NVCC 11 Nested Device Lambda
Browse files Browse the repository at this point in the history
CUDA builds for NVCC 11.0.2 fail with
```
  error #3206-D: An extended __device__ lambda cannot be defined inside a generic lambda expression("operator()").
```
if we try to nest our `std::visit` and `ParallelFor` device lambdas.
So, as a work-around, we isolate the single particle push out into a
C++ functor.

Wow, 2010 feels.
  • Loading branch information
ax3l committed Dec 5, 2021
1 parent d94a212 commit 97711fc
Showing 1 changed file with 68 additions and 12 deletions.
80 changes: 68 additions & 12 deletions src/particles/Push.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,70 @@

namespace impactx
{
namespace detail
{
    /** Strip top-level reference and const qualifiers from a type
     *
     * Minimal local stand-in for C++20's std::remove_cvref, used below to
     * normalize the deduced element type. Kept local so this spot does not
     * depend on a (possibly only transitive) <type_traits> include.
     */
    template <typename T> struct remove_cvref          { using type = T; };
    template <typename T> struct remove_cvref<T const> { using type = T; };
    template <typename T> struct remove_cvref<T &>     { using type = typename remove_cvref<T>::type; };
    template <typename T> struct remove_cvref<T &&>    { using type = typename remove_cvref<T>::type; };

    /** Push a single particle through an element
     *
     * Note: we usually would just write a C++ lambda below in ParallelFor. But, due to restrictions
     * in NVCC as of 11.0.2, we cannot write a lambda in a lambda as we also std::visit the element
     * types of our beamline_element list.
     * error #3206-D: An extended __device__ lambda cannot be defined inside a generic lambda expression("operator()").
     * Thus, we fall back to writing a C++ functor here, instead of nesting two lambdas.
     *
     * @tparam T_Element This can be a \see Drift, \see Quad, \see Sbend, etc.
     *                   A (const) reference type is accepted and decayed: the
     *                   element is always stored BY VALUE so the functor stays
     *                   self-contained when it is copied for device execution.
     */
    template <typename T_Element>
    struct PushSingleParticle
    {
        using PType = ImpactXParticleContainer::ParticleType;

        //! element type stored by value. T_Element may be deduced as a
        //! reference (e.g. decltype of a generic-lambda parameter inside
        //! std::visit); storing a reference member would dangle / point to
        //! host memory once this functor is copied into the ParallelFor
        //! kernel launch, so we decay it here.
        using element_type = typename remove_cvref<T_Element>::type;

        /** Constructor taking in pointers to particle data
         *
         * @param element the beamline element to push through (copied into the functor)
         * @param aos_ptr the array-of-struct with position and ids
         * @param part_px the array to the particle momentum (x)
         * @param part_py the array to the particle momentum (y)
         * @param part_pt the array to the particle momentum (t)
         */
        PushSingleParticle (T_Element element,
                            PType* aos_ptr,
                            amrex::ParticleReal* part_px,
                            amrex::ParticleReal* part_py,
                            amrex::ParticleReal* part_pt)
        : m_element(element), m_aos_ptr(aos_ptr),
          m_part_px(part_px), m_part_py(part_py), m_part_pt(part_pt)
        {
        }

        PushSingleParticle () = delete;
        PushSingleParticle (PushSingleParticle const &) = default;
        PushSingleParticle (PushSingleParticle &&) = default;
        ~PushSingleParticle () = default;

        /** Push the i-th particle of the current box through m_element
         *
         * @param i particle index within the box
         */
        AMREX_GPU_DEVICE AMREX_FORCE_INLINE
        void
        operator() (long i) const
        {
            // access AoS data such as positions and cpu/id
            PType& p = m_aos_ptr[i];

            // access SoA Real data
            amrex::ParticleReal & px = m_part_px[i];
            amrex::ParticleReal & py = m_part_py[i];
            amrex::ParticleReal & pt = m_part_pt[i];

            m_element(p, px, py, pt);
        }

    private:
        //! beamline element, held by value (see element_type above)
        element_type const m_element;
        PType* const AMREX_RESTRICT m_aos_ptr;
        amrex::ParticleReal* const AMREX_RESTRICT m_part_px;
        amrex::ParticleReal* const AMREX_RESTRICT m_part_py;
        amrex::ParticleReal* const AMREX_RESTRICT m_part_pt;
    };
} // namespace detail

void Push (ImpactXParticleContainer & pc,
std::list<KnownElements> const & beamline_elements)
{
Expand Down Expand Up @@ -49,19 +113,11 @@ namespace impactx
for (auto & element_variant : beamline_elements) {
// here we just access the element by its respective type
std::visit([=](auto&& element) {
detail::PushSingleParticle<decltype(element)> const pushSingleParticle(
element, aos_ptr, part_px, part_py, part_pt);

// loop over particles in the box
amrex::ParallelFor( np, [=] AMREX_GPU_DEVICE (long i)
{
// access AoS data such as positions and cpu/id
PType& p = aos_ptr[i];

// access SoA Real data
amrex::ParticleReal & px = part_px[i];
amrex::ParticleReal & py = part_py[i];
amrex::ParticleReal & pt = part_pt[i];

element(p, px, py, pt);
});
amrex::ParallelFor(np, pushSingleParticle);
}, element_variant);
}; // end loop over all beamline elements

Expand Down

0 comments on commit 97711fc

Please sign in to comment.