Skip to content

Commit

Permalink
make warpReduce work for general types
Browse files Browse the repository at this point in the history
  • Loading branch information
WeiqunZhang committed Jul 22, 2022
1 parent 62419a7 commit ef0d395
Showing 1 changed file with 42 additions and 2 deletions.
44 changes: 42 additions & 2 deletions Src/Base/AMReX_GpuReduce.H
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <AMReX_GpuAtomic.H>
#include <AMReX_GpuUtility.H>
#include <AMReX_Functional.H>
#include <AMReX_TypeTraits.H>

#if !defined(AMREX_USE_CUB) && defined(AMREX_USE_CUDA) && defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 11)
#define AMREX_USE_CUB 1
Expand Down Expand Up @@ -249,15 +250,54 @@ void deviceReduceLogicalOr (int * dest, int source, Gpu::Handler const& h) noexc

#elif defined(AMREX_USE_CUDA) || defined(AMREX_USE_HIP)

namespace detail {

// Warp shuffle wrapper: hands back the value of `x` obtained through the
// backend's shuffle-down intrinsic with the given lane `offset`.
//
// AMREX_HIP_OR_CUDA chooses between the two spellings below (presumably the
// first for HIP builds and the second for CUDA builds -- the macro is defined
// elsewhere, confirm there). The `_sync` variant is called with the mask
// 0xffffffff.
//
// T must be a type the underlying intrinsic accepts directly.
template <typename T>
AMREX_GPU_DEVICE AMREX_FORCE_INLINE
T shuffle_down (T x, int offset) noexcept
{
    T shifted = AMREX_HIP_OR_CUDA(__shfl_down(x, offset),
                                  __shfl_down_sync(0xffffffff, x, offset));
    return shifted;
}

// Shuffle-down for types the intrinsic cannot take directly: the object is
// viewed as a sequence of `unsigned int` words and each word is passed through
// shuffle_down separately, then reassembled into the returned T.
//
// Only enabled when sizeof(T) is a whole multiple of sizeof(unsigned int).
// Support for other sizes can be added later if it is ever needed.
template <class T, std::enable_if_t<sizeof(T)%sizeof(unsigned int) == 0, int> = 0>
AMREX_GPU_DEVICE AMREX_FORCE_INLINE
T multi_shuffle_down (T x, int offset) noexcept
{
    using Word = unsigned int;
    constexpr int word_count = (sizeof(T) + sizeof(Word) - 1) / sizeof(Word);

    T result;
    Word* dst = reinterpret_cast<Word*>(&result);
    Word* src = reinterpret_cast<Word*>(&x);

    // Move the object one 32-bit word at a time.
    int w = 0;
    while (w < word_count) {
        dst[w] = shuffle_down(src[w], offset);
        ++w;
    }
    return result;
}

}

template <int warpSize, typename T, typename F>
struct warpReduce
{
// Overload for arithmetic T: the value is handed straight to
// detail::shuffle_down. Not every arithmetic type is necessarily accepted by
// the shuffle intrinsic, but this covers the cases needed so far.
//
// Starting at warpSize/2 and halving the offset each round, the local value is
// combined via F with the value obtained from shuffle_down at that offset.
// Returns the combined value held by the calling thread after the last round.
template <class U=T, std::enable_if_t<std::is_arithmetic<U>::value,int> = 0>
AMREX_GPU_DEVICE AMREX_FORCE_INLINE
T operator() (T x) const noexcept
{
    int stride = warpSize/2;
    while (stride > 0) {
        T incoming = detail::shuffle_down(x, stride);
        x = F()(x, incoming);
        stride /= 2;
    }
    return x;
}

template <class U=T, std::enable_if_t<!std::is_arithmetic<U>::value,int> = 0>
AMREX_GPU_DEVICE AMREX_FORCE_INLINE
T operator() (T x) const noexcept
{
for (int offset = warpSize/2; offset > 0; offset /= 2) {
AMREX_HIP_OR_CUDA(T y = __shfl_down(x, offset);,
T y = __shfl_down_sync(0xffffffff, x, offset); )
T y = detail::multi_shuffle_down(x, offset);
x = F()(x,y);
}
return x;
Expand Down

0 comments on commit ef0d395

Please sign in to comment.