From ef0d395c76c47c8a6d6017a4282f2e7ee7a65e3d Mon Sep 17 00:00:00 2001 From: Weiqun Zhang Date: Thu, 21 Jul 2022 18:14:25 -0700 Subject: [PATCH] make warpReduce work for general types --- Src/Base/AMReX_GpuReduce.H | 44 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/Src/Base/AMReX_GpuReduce.H b/Src/Base/AMReX_GpuReduce.H index 9b48138940c..3907ca385f6 100644 --- a/Src/Base/AMReX_GpuReduce.H +++ b/Src/Base/AMReX_GpuReduce.H @@ -8,6 +8,7 @@ #include #include #include +#include #if !defined(AMREX_USE_CUB) && defined(AMREX_USE_CUDA) && defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 11) #define AMREX_USE_CUB 1 @@ -249,15 +250,54 @@ void deviceReduceLogicalOr (int * dest, int source, Gpu::Handler const& h) noexc #elif defined(AMREX_USE_CUDA) || defined(AMREX_USE_HIP) +namespace detail { + +template +AMREX_GPU_DEVICE AMREX_FORCE_INLINE +T shuffle_down (T x, int offset) noexcept +{ + return AMREX_HIP_OR_CUDA(__shfl_down(x, offset), + __shfl_down_sync(0xffffffff, x, offset)); +} + +// If other sizeof is needed, we can implement it later. +template = 0> +AMREX_GPU_DEVICE AMREX_FORCE_INLINE +T multi_shuffle_down (T x, int offset) noexcept +{ + constexpr int nwords = (sizeof(T) + sizeof(unsigned int) - 1) / sizeof(unsigned int); + T y; + auto py = reinterpret_cast(&y); + auto px = reinterpret_cast(&x); + for (int i = 0; i < nwords; ++i) { + py[i] = shuffle_down(px[i],offset); + } + return y; +} + +} + template struct warpReduce { + // Not all arithmetic types can be taken by shuffle_down, but it's good enough. + template ::value,int> = 0> + AMREX_GPU_DEVICE AMREX_FORCE_INLINE + T operator() (T x) const noexcept + { + for (int offset = warpSize/2; offset > 0; offset /= 2) { + T y = detail::shuffle_down(x, offset); + x = F()(x,y); + } + return x; + } + + template ::value,int> = 0> AMREX_GPU_DEVICE AMREX_FORCE_INLINE T operator() (T x) const noexcept { for (int offset = warpSize/2; offset > 0; offset /= 2) { - AMREX_HIP_OR_CUDA(T y = __shfl_down(x, offset);, - T y = __shfl_down_sync(0xffffffff, x, offset); ) + T y = detail::multi_shuffle_down(x, offset); x = F()(x,y); } return x;