Remove hard-coded constants and use some sputnik helpers
The use of Dot makes it 2.5% faster already.
fmassa committed Apr 5, 2022
1 parent b150cf4 commit cec04e9
Showing 1 changed file with 15 additions and 18 deletions.
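Context for the diff below: sputnik::VectorCompute<float4>::Dot and ::FMA replace hand-written per-component arithmetic, and kVecSize (sizeof(float4) / sizeof(float), i.e. 4) replaces the hard-coded "/ 4" in loop bounds and indexing. The snippet that follows is a minimal sketch of what the two helpers are assumed to compute, inferred from the scalar lines they replace in this diff rather than from sputnik/vector_utils.h itself; the names dot_accumulate and fma_accumulate are placeholders for illustration only.

#include <cuda_runtime.h> // float4

// Assumed behaviour of sputnik::VectorCompute<float4>::Dot(a, b, &acc):
// accumulate a 4-wide dot product into a scalar accumulator.
__device__ __forceinline__ void dot_accumulate(float4 a, float4 b, float* acc) {
  *acc += a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
}

// Assumed behaviour of sputnik::VectorCompute<float4>::FMA(s, v, &out):
// out += s * v, applied component-wise.
__device__ __forceinline__ void fma_accumulate(float s, float4 v, float4* out) {
  out->x += s * v.x;
  out->y += s * v.y;
  out->z += s * v.z;
  out->w += s * v.w;
}

Functionally the kernel computes the same thing; the helpers centralize the vectorized arithmetic, and per the commit message the Dot path alone is worth about 2.5%.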
33 changes: 15 additions & 18 deletions xformers/components/attention/csrc/cuda/attention.cu
@@ -8,6 +8,8 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "sputnik/vector_utils.h"


namespace {

@@ -29,19 +31,16 @@ __global__ void attention_kernel(
at::PackedTensorAccessor<scalar_t, 3> key,
at::PackedTensorAccessor<scalar_t, 3> value
) {
-//constexpr int64_t BLOCK = 32;
-//constexpr int64_t BLOCK2 = 2;
+constexpr int kVecSize = sizeof(float4) / sizeof(float);
int64_t K = query.size(2);
int64_t B = query.size(0);
int64_t M = query.size(1);
int64_t N = key.size(1);

int64_t i = blockIdx.y;
-//int64_t j = blockIdx.x;
int64_t j = blockIdx.x * (blockDim.y * BLOCK2) + threadIdx.y * BLOCK2;

{{
-//auto aar = query[i][j].data();
float4* aar[BLOCK2];
float4* oo[BLOCK2];
float4 buffer[BLOCK2][8] = {0}; // TODO == K / 4
@@ -55,18 +54,18 @@ __global__ void attention_kernel(
for (int64_t l = threadIdx.x * BLOCK; l < N; l+=BLOCK * blockDim.x) {
auto bar = reinterpret_cast<float4 *>(key[i][l].data());
scalar_t si[BLOCK2][BLOCK] = {0};
-for (int64_t k = 0; k < K / 4; k+=1) {
+for (int64_t k = 0; k < K / kVecSize; k+=1) {
float4 aaar[BLOCK2];
#pragma unroll
for (int64_t rr = 0; rr < BLOCK2; rr++) {
aaar[rr] = __ldg(aar[rr] + k);
}
#pragma unroll
for (int64_t rr = 0; rr < BLOCK; rr++) {
-float4 bbb = bar[k + K / 4 * rr];
+float4 bbb = bar[k + K / kVecSize * rr];
#pragma unroll
for (int64_t rr2 = 0; rr2 < BLOCK2; rr2++) {
-si[rr2][rr] += aaar[rr2].x * bbb.x + aaar[rr2].y * bbb.y + aaar[rr2].z * bbb.z + aaar[rr2].w * bbb.w;
+sputnik::VectorCompute<float4>::Dot(aaar[rr2], bbb, &si[rr2][rr]);
}
}
}
@@ -96,7 +95,7 @@ __global__ void attention_kernel(
for (int64_t rr = 0; rr < BLOCK; rr++)
s_delta[rr2][rr] = std::exp(si[rr2][rr] - m_i[rr2]);

-for (int64_t k = 0; k < K/4; k+=1) {
+for (int64_t k = 0; k < K/kVecSize; k+=1) {
#pragma unroll
for (int64_t rr2 = 0; rr2 < BLOCK2; rr2++) {
buffer[rr2][k].x *= m_delta[rr2];
@@ -106,14 +105,11 @@ __global__ void attention_kernel(
}
#pragma unroll
for (int64_t rr = 0; rr < BLOCK; rr++) {
-float4 tmp2 = vi[k + K / 4 * rr];
+float4 tmp2 = vi[k + K / kVecSize * rr];

#pragma unroll
for (int64_t rr2 = 0; rr2 < BLOCK2; rr2++) {
-buffer[rr2][k].x += tmp2.x * s_delta[rr2][rr];
-buffer[rr2][k].y += tmp2.y * s_delta[rr2][rr];
-buffer[rr2][k].z += tmp2.z * s_delta[rr2][rr];
-buffer[rr2][k].w += tmp2.w * s_delta[rr2][rr];
+sputnik::VectorCompute<float4>::FMA(s_delta[rr2][rr], tmp2, &buffer[rr2][k]);
}
}
}
@@ -141,7 +137,7 @@ __global__ void attention_kernel(
s_delta += __shfl_xor_sync(0xffffffff, s_delta, stride, 4);
}
s_prime[rr] = s_delta;
-for (int64_t k = 0; k < K / 4; k+=1) {
+for (int64_t k = 0; k < K / kVecSize; k+=1) {
float4 tmp = buffer[rr][k];
tmp.x *= m_delta;
tmp.y *= m_delta;
@@ -158,7 +154,7 @@ __global__ void attention_kernel(
}
}

-for (int64_t k = threadIdx.x; k < K / 4; k+=blockDim.x) {
+for (int64_t k = threadIdx.x; k < K / kVecSize; k+=blockDim.x) {
float4 tmp;

#pragma unroll
@@ -231,8 +227,9 @@ at::Tensor attention(
dim3 block(4, 16);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

-AT_DISPATCH_FLOATING_TYPES(
-query.scalar_type(), "attention_kernel", [&] {
+using scalar_t = float;
+//AT_DISPATCH_FLOATING_TYPES(
+//query.scalar_type(), "attention_kernel", [&] {
attention_kernel<scalar_t><<<grid, block, 0, stream>>>(
res.packed_accessor<scalar_t, 3>(),
query.packed_accessor<scalar_t, 3>(),
Expand All @@ -241,7 +238,7 @@ at::Tensor attention(
//buffer.accessor<scalar_t, 3>()
//idxs.accessor<int64_t, 2>()
);
-});
+//});

return res;
}