From 5c5ae07b903785122a712632acc51aec42f66e74 Mon Sep 17 00:00:00 2001
From: danthe3rd
Date: Mon, 3 Oct 2022 16:21:08 +0000
Subject: [PATCH] bwinline: Forceinline all functions of the bw

[ghstack-poisoned]
---
 .../mem_eff_attention/gemm_kernel_utils.h    |  2 +-
 .../cuda/mem_eff_attention/kernel_backward.h | 49 +++++++++++++------
 2 files changed, 35 insertions(+), 16 deletions(-)

diff --git a/xformers/components/attention/csrc/cuda/mem_eff_attention/gemm_kernel_utils.h b/xformers/components/attention/csrc/cuda/mem_eff_attention/gemm_kernel_utils.h
index cba6e24b85..614cf2fc7b 100644
--- a/xformers/components/attention/csrc/cuda/mem_eff_attention/gemm_kernel_utils.h
+++ b/xformers/components/attention/csrc/cuda/mem_eff_attention/gemm_kernel_utils.h
@@ -118,7 +118,7 @@ struct TypeTraits {
 };
 
 template <typename integer>
-constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) {
+constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
   return (n + m - 1) / m;
 }
 
diff --git a/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_backward.h b/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_backward.h
index c55c7500ec..4bcc66b4fe 100644
--- a/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_backward.h
+++ b/xformers/components/attention/csrc/cuda/mem_eff_attention/kernel_backward.h
@@ -95,7 +95,7 @@ struct AttentionBackwardKernel {
     int32_t num_batches;
     bool causal;
 
-    __device__ void advance_batches(int32_t batch_id) {
+    CUTLASS_DEVICE void advance_batches(int32_t batch_id) {
       constexpr int32_t kAlignLSE = 32; // block size of backward
       auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;
 
@@ -110,6 +110,23 @@ struct AttentionBackwardKernel {
       grad_query_ptr += batch_id * head_dim * num_queries;
       grad_key_ptr += batch_id * head_dim * num_keys;
       grad_value_ptr += batch_id * head_dim_value * num_keys;
+
+      head_dim = warp_uniform(head_dim);
+      head_dim_value = warp_uniform(head_dim_value);
+      num_queries = warp_uniform(num_queries);
+      num_keys = warp_uniform(num_keys);
+
+      query_ptr = warp_uniform(query_ptr);
+      key_ptr = warp_uniform(key_ptr);
+      value_ptr = warp_uniform(value_ptr);
+      logsumexp_ptr = warp_uniform(logsumexp_ptr);
+      output_ptr = warp_uniform(output_ptr);
+      grad_output_ptr = warp_uniform(grad_output_ptr);
+      delta_ptr = warp_uniform(delta_ptr);
+
+      grad_query_ptr = warp_uniform(grad_query_ptr);
+      grad_key_ptr = warp_uniform(grad_key_ptr);
+      grad_value_ptr = warp_uniform(grad_value_ptr);
     }
 
     __host__ dim3 getBlocksGrid() const {
@@ -600,7 +617,7 @@ struct AttentionBackwardKernel {
     typename MatmulGradV::Mma::FragmentC gradV;
     typename MatmulGradK::Mma::FragmentC gradK;
 
-    __device__ __forceinline__ void clear() {
+    CUTLASS_DEVICE void clear() {
      gradV.clear();
      gradK.clear();
    }
@@ -620,7 +637,7 @@ struct AttentionBackwardKernel {
         "value is not correctly aligned");
   }
 
-  static __device__ void kernel(Params& p_) {
+  static CUTLASS_DEVICE void kernel(Params& p_) {
     // Hint to nvcc to store points & tensor shapes in registers
     // as we use them a lot
     register const Params p = p_;
@@ -657,7 +674,7 @@ struct AttentionBackwardKernel {
       __syncthreads();
     }
 
-    OutputFragments output_frags;
+    OutputFragments register output_frags;
     int32_t key_start = 0;
     int32_t key_end = p.num_keys / kBlockSizeJ * kBlockSizeJ;
     for (; key_start < key_end; key_start += kBlockSizeJ) {
@@ -694,7 +711,7 @@ struct AttentionBackwardKernel {
     }
   }
 
-  static __device__ __forceinline__ void loadDi(
+  static CUTLASS_DEVICE void loadDi(
       cutlass::Array<accum_t, kBlockSizeI>& di,
       Params const& p,
       int32_t query_start) {
@@ -709,7 +726,7 @@ struct AttentionBackwardKernel {
   }
 
   template <bool skipBoundsChecks>
-  static __device__ __forceinline__ void processBlockIJ(
+  static CUTLASS_DEVICE void processBlockIJ(
       SharedStorage& shared_storage,
       OutputFragments& output_frags,
       Params const& p,
@@ -717,9 +734,9 @@ struct AttentionBackwardKernel {
       int32_t key_start) {
     cutlass::MatrixCoord no_offset{0, 0};
     accum_t scale = accum_t(1.0 / std::sqrt(float(p.head_dim)));
-    int32_t thread_id = threadIdx.x + threadIdx.y * blockDim.x;
-    int32_t warp_id = threadIdx.y;
-    int32_t lane_id = threadIdx.x;
+    int16_t thread_id = threadIdx.x + threadIdx.y * blockDim.x;
+    int8_t warp_id = warp_uniform(threadIdx.y);
+    int8_t lane_id = threadIdx.x;
 
     __syncthreads();
     loadDi(shared_storage.di(), p, query_start);
@@ -1258,7 +1275,7 @@ struct AttentionBackwardKernel {
   }
 
   template <bool skipBoundsChecks>
-  static __device__ __forceinline__ void writeFragsToGmem(
+  static CUTLASS_DEVICE void writeFragsToGmem(
      SharedStorage& shared_storage,
      OutputFragments& output_frags,
      Params const& p,
@@ -1291,7 +1308,7 @@ struct AttentionBackwardKernel {
   }
 
   template <typename MatmulT>
-  static __device__ __forceinline__ void accumulateInGmem(
+  static CUTLASS_DEVICE void accumulateInGmem(
      typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem,
      typename MatmulT::Mma::FragmentC const& accum,
      typename MatmulT::OutputTileIterator output_it,
@@ -1333,7 +1350,9 @@ struct AttentionBackwardKernel {
   }
 
   template <int kElementsPerAccess>
-  static __device__ void computeDelta(Params const& p, int32_t query_start) {
+  static CUTLASS_DEVICE void computeDelta(
+      Params const& p,
+      int32_t query_start) {
     // Each thread computes one value for Delta
     // Depending on warp configuration, we might have multiple
     // threads of the same warp working on the same row
@@ -1429,13 +1448,13 @@ struct AttentionBackwardKernel {
     }
   }
 
-  static __device__ __forceinline__ int8_t get_lane_id() {
+  static CUTLASS_DEVICE int8_t get_lane_id() {
     return threadIdx.x;
   }
-  static __device__ __forceinline__ int8_t get_warp_id() {
+  static CUTLASS_DEVICE int8_t get_warp_id() {
     return threadIdx.y;
   }
-  static __device__ __forceinline__ int16_t get_thread_id() {
+  static CUTLASS_DEVICE int16_t get_thread_id() {
     return threadIdx.x + threadIdx.y * blockDim.x;
   }
 };
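
Note on the macros: CUTLASS_DEVICE and CUTLASS_HOST_DEVICE come from cutlass/cutlass.h and (modulo CUTLASS version details) expand to

    #define CUTLASS_DEVICE      __forceinline__ __device__
    #define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__

so replacing the plain __device__ / __device__ __forceinline__ qualifiers with them is what force-inlines every function this patch touches, matching the subject line, while also keeping the decoration consistent with the CUTLASS code these kernels build on.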
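Note on warp_uniform: the helper itself is not part of this diff (it is expected to be defined alongside ceil_div in gemm_kernel_utils.h elsewhere in the stack). The sketch below is a minimal illustration of the usual form of such a hint, assuming a full converged warp (mask 0xffffffff); the exact xformers definition may differ:

    #include <cstdint>
    #include <cstring>

    // Broadcast lane 0's copy of a 32-bit value to the whole warp. After the
    // shuffle every lane provably holds the same bits, so the compiler can
    // treat the value as warp-uniform (one register per warp, no per-lane
    // recomputation of derived addresses).
    template <typename T>
    __forceinline__ __device__ T warp_uniform(T value) {
      static_assert(sizeof(T) == sizeof(uint32_t), "sketch covers 32-bit payloads");
      uint32_t bits;
      memcpy(&bits, &value, sizeof(bits));       // type-pun without UB
      bits = __shfl_sync(0xffffffffu, bits, 0);  // broadcast from lane 0
      memcpy(&value, &bits, sizeof(bits));
      return value;
    }

    // 64-bit pointers are broadcast as two 32-bit halves.
    template <typename T>
    __forceinline__ __device__ T* warp_uniform(T* ptr) {
      uint64_t bits = reinterpret_cast<uint64_t>(ptr);
      uint32_t lo = __shfl_sync(0xffffffffu, static_cast<uint32_t>(bits), 0);
      uint32_t hi = __shfl_sync(0xffffffffu, static_cast<uint32_t>(bits >> 32), 0);
      return reinterpret_cast<T*>((static_cast<uint64_t>(hi) << 32) | lo);
    }

Calling it once per batch in advance_batches, right after the per-batch pointer bumps, means everything downstream (including the new warp_uniform(threadIdx.y) for warp_id in processBlockIJ) computes addresses from values the compiler knows are uniform across the warp.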