From a89c6e504f3e6a0e72a2ded45db92d9416bedc38 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=82=85=E5=89=91=E5=AF=92?=
Date: Thu, 19 Oct 2023 14:11:35 +0800
Subject: [PATCH] Optim cinn block_reduce (#58196)

* optim cinn block_reduce
* fix bugs
* simplify code
* replace tree reduce with butterfly reduce when active mask is 0x1f
* fix bugs
* fix sync bugs
* remove shared memory when blockdim less than 32
---
 .../runtime/cuda/cinn_cuda_runtime_source.cuh | 25 ++++++++-----------
 1 file changed, 11 insertions(+), 14 deletions(-)

diff --git a/paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh b/paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
index a7e4dc6e1de1a3..aef8907b81a431 100644
--- a/paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
+++ b/paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh
@@ -474,11 +474,11 @@ __device__ inline bool cinn_any(const bool left, const bool right) { return left
       tmp_val = __shfl_sync(mask, tmp_val, 0, 32); \
       return tmp_val; \
     } else { \
-      tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 16, 32)); \
-      tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 8, 32)); \
-      tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 4, 32)); \
-      tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 2, 32)); \
-      tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_down_sync(mask, tmp_val, 1, 32)); \
+      tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 16, 32)); \
+      tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 8, 32)); \
+      tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 4, 32)); \
+      tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 2, 32)); \
+      tmp_val = cinn_##REDUCE_TYPE(tmp_val, __shfl_xor_sync(mask, tmp_val, 1, 32)); \
       return tmp_val; \
     } \
   }
@@ -530,25 +530,22 @@ __device__ inline float cinn_warp_reduce_avg_fp32(const float *buf, int offset,
 
 #define CINN_BLOCK_REDUCE_INTERNAL_IMPL(TYPE, value, init_value, cinn_warp_shuffle_internal) \
   int warp_id = threadIdx.x / 32; \
-  __shared__ TYPE tmp[32]; \
-  if (warp_id == 0) { \
-    tmp[threadIdx.x] = init_value; \
-  } \
   TYPE tmp_val = cinn_warp_shuffle_internal(value); \
   if (blockDim.x <= 32) { \
     return tmp_val; \
   } \
+  __shared__ TYPE tmp[32]; \
+  if (warp_id == 0) { \
+    tmp[threadIdx.x] = init_value; \
+  } \
   __syncthreads(); \
-  if (threadIdx.x % 32 == 0) { \
+  if ((threadIdx.x & 31) == 0) { \
     tmp[warp_id] = tmp_val; \
   } \
   __syncthreads(); \
   if (warp_id == 0) { \
     tmp_val = tmp[threadIdx.x]; \
-    tmp_val = cinn_warp_shuffle_internal(tmp_val); \
-    if (threadIdx.x == 0) { \
-      tmp[0] = tmp_val; \
-    } \
+    tmp[threadIdx.x] = cinn_warp_shuffle_internal(tmp_val); \
   } \
   __syncthreads(); \
   return tmp[0];