diff --git a/xformers/components/attention/csrc/cuda/attention.cu b/xformers/components/attention/csrc/cuda/attention.cu
index a4a14d25f0..3e8d7c8236 100644
--- a/xformers/components/attention/csrc/cuda/attention.cu
+++ b/xformers/components/attention/csrc/cuda/attention.cu
@@ -8,6 +8,8 @@
 #include
 #include
 
+#include "sputnik/vector_utils.h"
+
 namespace {
 
 
@@ -29,19 +31,16 @@ __global__ void attention_kernel(
     at::PackedTensorAccessor<scalar_t, 3> key,
     at::PackedTensorAccessor<scalar_t, 3> value
     ) {
-  //constexpr int64_t BLOCK = 32;
-  //constexpr int64_t BLOCK2 = 2;
+  constexpr int kVecSize = sizeof(float4) / sizeof(float);
   int64_t K = query.size(2);
   int64_t B = query.size(0);
   int64_t M = query.size(1);
   int64_t N = key.size(1);
   int64_t i = blockIdx.y;
-  //int64_t j = blockIdx.x;
   int64_t j = blockIdx.x * (blockDim.y * BLOCK2) + threadIdx.y * BLOCK2;
 
   {{
 
-    //auto aar = query[i][j].data();
     float4* aar[BLOCK2];
     float4* oo[BLOCK2];
     float4 buffer[BLOCK2][8] = {0}; // TODO == K / 4
@@ -55,7 +54,7 @@ __global__ void attention_kernel(
     for (int64_t l = threadIdx.x * BLOCK; l < N; l+=BLOCK * blockDim.x) {
       auto bar = reinterpret_cast<float4*>(key[i][l].data());
       scalar_t si[BLOCK2][BLOCK] = {0};
-      for (int64_t k = 0; k < K / 4; k+=1) {
+      for (int64_t k = 0; k < K / kVecSize; k+=1) {
        float4 aaar[BLOCK2];
 #pragma unroll
        for (int64_t rr = 0; rr < BLOCK2; rr++) {
@@ -63,10 +62,10 @@ __global__ void attention_kernel(
        }
 #pragma unroll
        for (int64_t rr = 0; rr < BLOCK; rr++) {
-          float4 bbb = bar[k + K / 4 * rr];
+          float4 bbb = bar[k + K / kVecSize * rr];
 #pragma unroll
          for (int64_t rr2 = 0; rr2 < BLOCK2; rr2++) {
-            si[rr2][rr] += aaar[rr2].x * bbb.x + aaar[rr2].y * bbb.y + aaar[rr2].z * bbb.z + aaar[rr2].w * bbb.w;
+            sputnik::VectorCompute<float4>::Dot(aaar[rr2], bbb, &si[rr2][rr]);
          }
        }
      }
@@ -96,7 +95,7 @@ __global__ void attention_kernel(
      for (int64_t rr = 0; rr < BLOCK; rr++)
        s_delta[rr2][rr] = std::exp(si[rr2][rr] - m_i[rr2]);
 
-      for (int64_t k = 0; k < K/4; k+=1) {
+      for (int64_t k = 0; k < K/kVecSize; k+=1) {
 #pragma unroll
        for (int64_t rr2 = 0; rr2 < BLOCK2; rr2++) {
          buffer[rr2][k].x *= m_delta[rr2];
@@ -106,14 +105,11 @@ __global__ void attention_kernel(
        }
 #pragma unroll
        for (int64_t rr = 0; rr < BLOCK; rr++) {
-          float4 tmp2 = vi[k + K / 4 * rr];
+          float4 tmp2 = vi[k + K / kVecSize * rr];
 #pragma unroll
          for (int64_t rr2 = 0; rr2 < BLOCK2; rr2++) {
-            buffer[rr2][k].x += tmp2.x * s_delta[rr2][rr];
-            buffer[rr2][k].y += tmp2.y * s_delta[rr2][rr];
-            buffer[rr2][k].z += tmp2.z * s_delta[rr2][rr];
-            buffer[rr2][k].w += tmp2.w * s_delta[rr2][rr];
+            sputnik::VectorCompute<float4>::FMA(s_delta[rr2][rr], tmp2, &buffer[rr2][k]);
          }
        }
      }
 
@@ -141,7 +137,7 @@ __global__ void attention_kernel(
        s_delta += __shfl_xor_sync(0xffffffff, s_delta, stride, 4);
      }
      s_prime[rr] = s_delta;
-      for (int64_t k = 0; k < K / 4; k+=1) {
+      for (int64_t k = 0; k < K / kVecSize; k+=1) {
        float4 tmp = buffer[rr][k];
        tmp.x *= m_delta;
        tmp.y *= m_delta;
@@ -158,7 +154,7 @@ __global__ void attention_kernel(
      }
    }
 
-    for (int64_t k = threadIdx.x; k < K / 4; k+=blockDim.x) {
+    for (int64_t k = threadIdx.x; k < K / kVecSize; k+=blockDim.x) {
      float4 tmp;
 
 #pragma unroll
@@ -231,8 +227,9 @@ at::Tensor attention(
   dim3 block(4, 16);
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  AT_DISPATCH_FLOATING_TYPES(
-      query.scalar_type(), "attention_kernel", [&] {
+  using scalar_t = float;
+  //AT_DISPATCH_FLOATING_TYPES(
+  //    query.scalar_type(), "attention_kernel", [&] {
   attention_kernel<scalar_t><<<grid, block, 0, stream>>>(
      res.packed_accessor<scalar_t, 3>(),
      query.packed_accessor<scalar_t, 3>(),
@@ -241,7 +238,7 @@ at::Tensor attention(
      //buffer.accessor()
      //idxs.accessor()
      );
-  });
+  //});
 
   return res;
 }
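
Note for reviewers: kVecSize evaluates to sizeof(float4) / sizeof(float) == 4, so every K / kVecSize loop bound matches the old hard-coded K / 4. The two sputnik helpers swapped in above are drop-in replacements for the hand-written float4 arithmetic they delete. Below is a minimal sketch of the semantics this patch assumes for sputnik::VectorCompute<float4>; the helper names dot_f4 and fma_f4 are illustrative stand-ins, not sputnik's actual implementation (see sputnik/vector_utils.h for that):

    #include <cuda_runtime.h>

    // Assumed semantics of sputnik::VectorCompute<float4>::Dot: accumulate a
    // 4-wide dot product into *out, matching the deleted line
    //   si[rr2][rr] += a.x*b.x + a.y*b.y + a.z*b.z + a.w*b.w;
    __device__ __forceinline__ void dot_f4(const float4 a, const float4 b, float* out) {
      *out += a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
    }

    // Assumed semantics of sputnik::VectorCompute<float4>::FMA: scale v by the
    // scalar x1 and accumulate into *out component-wise, matching the four
    // deleted buffer[rr2][k].{x,y,z,w} += tmp2.{x,y,z,w} * s_delta[rr2][rr] lines.
    __device__ __forceinline__ void fma_f4(float x1, const float4 v, float4* out) {
      out->x += x1 * v.x;
      out->y += x1 * v.y;
      out->z += x1 * v.z;
      out->w += x1 * v.w;
    }

Either way the generated code should be identical to the deleted scalar expansion; the helpers just centralize the per-lane arithmetic so the kernel body stays readable as the vector width changes.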