WIP: Flash Attention implementation (forward + backward) #1
Conversation
Currently I don't think llama.cpp has good support for training at all. I want to work on it at some point, but implementing a backward pass doesn't make sense right now, I think.
In terms of optimization it will rather need to be done for different GPU + batch size combinations. Fundamentally you want to use large SRAM tiles because that increases arithmetic intensity: you can do more computations per load from global memory. However, at the same time you want to split a computing task into as many blocks as possible, since this reduces so-called tail effects where the last wave of a kernel does not have enough blocks to keep all of the streaming multiprocessors on a GPU busy. So for optimizing performance you will need to find a balance between these two problems, and that balance is affected by batch size because the number of blocks that the GPU can work on should be proportional to batch size. What I think would be the best solution is to compile the kernel multiple times with different template parameters and to then choose the best implementation at runtime based on the number of streaming multiprocessors on the GPU and the shapes of the matrices. I am currently working on essentially this exact problem in ggerganov#4801 (but for regular matrix multiplication).
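As a rough illustration of that runtime-selection idea (hypothetical kernel and tile sizes, not the code in ggerganov#4801):

#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Hypothetical sketch: compile the same tiled kernel for several tile sizes
// and pick the variant with the better block/SM balance at runtime.
template <int tile_m, int tile_n>
static __global__ void tiled_kernel(const half * A, const half * B, float * dst, int m, int n) {
    // ... the actual tile computation would go here ...
}

static void launch_best_variant(const half * A, const half * B, float * dst, int m, int n) {
    int device = 0, num_sms = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device);

    // number of blocks each variant would launch for these matrix shapes
    const int blocks_small = ((m +  63) /  64) * ((n +  63) /  64); //  64x64  tiles
    const int blocks_large = ((m + 127) / 128) * ((n + 127) / 128); // 128x128 tiles

    // prefer large tiles (higher arithmetic intensity) unless they would leave
    // too many SMs idle in the last wave (illustrative threshold of 2 waves)
    if (blocks_large >= 2*num_sms) {
        tiled_kernel<128, 128><<<blocks_large, 256>>>(A, B, dst, m, n);
    } else {
        tiled_kernel< 64,  64><<<blocks_small, 256>>>(A, B, dst, m, n);
    }
}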
Since the target architectures are Ampere and Ada Lovelace, consider that copying data from VRAM to SRAM is more efficient with asynchronous data copies. This is both faster and reduces register pressure. It also allows you to load data and do computations in parallel. This optimization can be added later, though. I will soon push a prototype to ggerganov#4801 which (I think) is close to optimal.
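For reference, a minimal hedged sketch of such an asynchronous copy using the CUDA pipeline primitives available on Ampere/Ada (names and sizes are illustrative, not taken from this PR):

#include <cuda_pipeline.h>

// Illustrative only: each thread issues an asynchronous 16-byte copy from
// global memory into shared memory (cp.async on Ampere/Ada), then overlaps
// it with independent work. Assumes blockDim.x <= 256.
__global__ void async_copy_example(const float4 * __restrict__ src, float4 * __restrict__ dst) {
    __shared__ float4 tile[256];

    const int i = blockIdx.x*blockDim.x + threadIdx.x;

    // the data bypasses registers on its way to shared memory,
    // which reduces register pressure compared to a plain load + store
    __pipeline_memcpy_async(&tile[threadIdx.x], &src[i], sizeof(float4));
    __pipeline_commit();

    // ... independent computation could be overlapped here ...

    __pipeline_wait_prior(0); // wait until the copy has arrived
    __syncthreads();

    dst[i] = tile[threadIdx.x];
}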
The installation instructions seem to be missing.

As hinted at by the error message, it can be fixed with
…into flash-attn-cuda
…into flash-attn-cuda
@JohannesGaessler I'm trying to port the implementation that Georgi did, but I'm running into issues where CUDA doesn't support the fp16 operators; I need to use the functions from

C:\proyects\llama.cpp\ggml-cuda.cu(6252): error : more than one conversion function from "half" to a built-in type applies: [C:\proyects\llama.cpp\build\ggml.vcxproj]
            function "__half::operator float() const" (declared at line 217 of C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\include\cuda_fp16.hpp)
            function "__half::operator short() const" (declared at line 235 of C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\include\cuda_fp16.hpp)
            function "__half::operator unsigned short() const" (declared at line 238 of C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\include\cuda_fp16.hpp)
            function "__half::operator int() const" (declared at line 241 of C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\include\cuda_fp16.hpp)
            function "__half::operator unsigned int() const" (declared at line 244 of C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\include\cuda_fp16.hpp)
            function "__half::operator long long() const" (declared at line 247 of C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\include\cuda_fp16.hpp)
            function "__half::operator unsigned long long() const" (declared at line 250 of C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\include\cuda_fp16.hpp)
            function "__half::operator __nv_bool() const" (declared at line 254 of C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\include\cuda_fp16.hpp)
          ss[hiisg*tph + tiih] = (s4.x + s4.y + s4.z + s4.w);
          ^
          detected during instantiation of "void flash_attn_ext_f16<D,R>(const char *, const char *, const char *, const char *, float *, float, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int) [with D=80, R=2]" at line 10370

C:\proyects\llama.cpp\ggml-cuda.cu(6261): error : no operator "+=" matches these operands [C:\proyects\llama.cpp\build\ggml.vcxproj]
            operand types are: half += half
          s += ss[hiisg*tph + i];
          ^
          detected during instantiation of "void flash_attn_ext_f16<D,R>(...) [with D=80, R=2]" at line 10370

C:\proyects\llama.cpp\ggml-cuda.cu(6264): error : identifier "__hfma" is undefined [C:\proyects\llama.cpp\build\ggml.vcxproj]
          s = __hfma(s, __float2half(scale), mv);
          ^
          detected during instantiation of "void flash_attn_ext_f16<D,R>(...) [with D=80, R=2]" at line 10370

C:\proyects\llama.cpp\ggml-cuda.cu(6270): error : more than one conversion function from "half" to a built-in type applies: [C:\proyects\llama.cpp\build\ggml.vcxproj]
            [same list of __half conversion-function candidates as above]
          half ms = hexp(m - M);
          ^
          detected during instantiation of "void flash_attn_ext_f16<D,R>(...) [with D=80, R=2]" at line 10370

C:\proyects\llama.cpp\ggml-cuda.cu(6270): error : identifier "hexp" is undefined [C:\proyects\llama.cpp\build\ggml.vcxproj]
          half ms = hexp(m - M);
          ^
          detected during instantiation of "void flash_attn_ext_f16<D,R>(...) [with D=80, R=2]" at line 10370

C:\proyects\llama.cpp\ggml-cuda.cu(6271): error : more than one conversion function from "half" to a built-in type applies: [C:\proyects\llama.cpp\build\ggml.vcxproj]
            [same list of __half conversion-function candidates as above]
          half vs = hexp(s - M);
          ^
          detected during instantiation of "void flash_attn_ext_f16<D,R>(...) [with D=80, R=2]" at line 10370

C:\proyects\llama.cpp\ggml-cuda.cu(6308): error : more than one conversion function from "half" to a built-in type applies: [C:\proyects\llama.cpp\build\ggml.vcxproj]
            [same list of __half conversion-function candidates as above]
          half ms0 = hexp(M0 - M);
          ^
          detected during instantiation of "void flash_attn_ext_f16<D,R>(...) [with D=80, R=2]" at line 10370

C:\proyects\llama.cpp\ggml-cuda.cu(6308): error : more than one conversion function from "half" to a built-in type applies: [C:\proyects\llama.cpp\build\ggml.vcxproj]
The issue is that CMake by default compiles device code for compute capabilities 5.2, 6.1, and 7.0. You can fix it for now by compiling with
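The exact command is cut off above; as a hedged example (an assumption, not the author's exact flag), restricting the build to architectures with full fp16 intrinsic support could look something like this:

cmake .. -DLLAMA_CUBLAS=ON -DCMAKE_CUDA_ARCHITECTURES="70;75;80;86"
cmake --build . --config Release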
Can I assume this is unsupported on the P40?
@sorasoras To use flash attention, the library requires Tensor Cores, which are available starting with the Volta architecture (NVIDIA V100) and newer.
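A hedged sketch of how that requirement could be checked at runtime (standard CUDA runtime calls, not code from this PR):

#include <cuda_runtime.h>

// The wmma/Tensor Core path needs compute capability >= 7.0 (Volta).
// A Tesla P40 is Pascal (compute capability 6.1), so it would be rejected here.
static bool device_supports_tensor_cores(int device) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    return prop.major >= 7;
}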
cuda : fix flash_attn kernel to produce same results as CPU
@ggerganov Have you tried running perplexity tests with the flash attention kernel? I've been experimenting with phi-2 and it's generating nonsensical text. Unfortunately, I don't have a LLaMA model on my computer for testing, and I understand that the phi-2 model has precision issues.
For now, test with LLaMA / Mistral models. When and if we get the performance, we will look into
On my hardware (1x RTX 3090) the FlashAttention kernel suffers from tail effects for perplexity calculations with a batch size of 512. You can see this in Nsight Compute under "PM Sampling": the number of active warps for the last ~30% of the runtime is very low, and this in turn reduces GPU utilization. For more information you can look at the "Launch Statistics": the number of waves per SM is 2.08. The actual runtime of a kernel will be proportional to this value rounded up to the next integer. Ideally you want this number to be just below a whole integer so that GPU utilization is high for all waves. The worst-case scenario is something like 1.01, because then you have to do 2 waves when you could probably do a slightly modified version of the kernel in a single wave, which would make it ~2x faster. However, this is something that you should optimize at the end. The number of waves will depend strongly on implementation details, the specific GPU, and the batch size. The number of waves will be proportional to
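To make the wave arithmetic concrete, here is a small illustrative sketch; the block/SM numbers below are assumptions for illustration, not measurements from this PR:

// Illustrative sketch of the tail-effect arithmetic described above.
// The runtime is roughly proportional to ceil(waves), so 1.01 waves cost
// almost as much as 2 full waves.
static double waves(int num_blocks, int num_sms, int blocks_per_sm) {
    return (double) num_blocks / ((double) num_sms * blocks_per_sm);
}

// Example (assumed numbers): an RTX 3090 has 82 SMs. With one resident block
// per SM, launching 171 blocks gives waves(171, 82, 1) ~= 2.09, and the
// runtime is roughly proportional to ceil(2.09) = 3 waves, so a large part
// of the last wave leaves SMs idle.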
I made a PR with some very simple performance optimizations because it took me about the same time as writing GitHub comments would have: #4
This code has a lot of if(...){...} statements. Try to avoid these if possible since they are slow on GPUs. If you cannot avoid them, try to provide the compiler with enough information via template parameters to resolve the branching at compile time so that the runtime is unaffected.
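As an illustration of that suggestion, here is a minimal hedged sketch (hypothetical kernel and parameter, not code from this PR) of resolving a branch via a template parameter:

#include <cuda_fp16.h>

// Hypothetical example: when the mask is either always present or always
// absent for a given launch, the branch can be resolved at compile time
// instead of being evaluated per thread.
template <bool has_mask>
static __global__ void attn_bias_kernel(const half * kq, const half * mask, half * out, int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }

    half v = kq[i];
    if (has_mask) { // dead code is eliminated when has_mask == false
        v = __hadd(v, mask[i]);
    }
    out[i] = v;
}

// host side: pick the instantiation once instead of branching inside the kernel
// if (mask != nullptr) attn_bias_kernel<true ><<<grid, block>>>(kq, mask,    out, n);
// else                 attn_bias_kernel<false><<<grid, block>>>(kq, nullptr, out, n);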
    const int warp_id = threadIdx.y;
    const int lane_id = threadIdx.x;

    const int num_warps = blockDim.y; // number of warps
    const int iq3 = blockIdx.z;
    const int iq2 = blockIdx.y;
    const int iq1 = blockIdx.x * Q;
This code copies the values from threadIdx and blockIdx to registers. It should be very slightly faster to access them this way, but if register pressure is an issue it would make sense to remove these local variables and instead use threadIdx and blockIdx directly.

If possible, make num_warps a template parameter so the compiler gets more information and you don't have to use a register for it.
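A minimal sketch of what that could look like (hypothetical kernel and launcher names, not the PR's actual code):

// Hypothetical sketch: with the warp count as a template parameter the
// compiler can fully unroll loops over warps and does not need a register
// to hold num_warps.
template <int num_warps, int D>
static __global__ void flash_attn_sketch(const char * Q, const char * K, const char * V, float * dst) {
    // the launcher below guarantees blockDim.y == num_warps, so loops such as
    // `for (int sg = 1; sg < num_warps; ++sg)` can be unrolled at compile time
}

static void launch_flash_attn_sketch(int nwarps, dim3 blocks_num,
                                     const char * Q, const char * K, const char * V, float * dst) {
    switch (nwarps) {
        case 2: flash_attn_sketch<2, 80><<<blocks_num, dim3(32, 2, 1)>>>(Q, K, V, dst); break;
        case 4: flash_attn_sketch<4, 80><<<blocks_num, dim3(32, 4, 1)>>>(Q, K, V, dst); break;
        case 8: flash_attn_sketch<8, 80><<<blocks_num, dim3(32, 8, 1)>>>(Q, K, V, dst); break;
        default: break; // warp count not supported in this sketch
    }
}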
I don't see the PM Sampling section in my Nsight Compute.
unroll 2 loops, int64_t -> int, 309 µs
Under Connect -> Metrics I just ticked "full" and then I got it.
I don't see that section; it must be because my toolkit is outdated.
At a batch size of 1 the size of the CUDA grid for the kernel is very small, so both the memory and compute throughput are low. At a batch size of 512 the grid is larger but still very small: there are only 1.25 waves per SM, and as a consequence ~40% of the runtime is lost to tail effects. If at all possible, the kernel should be rewritten in such a way that the work is distributed among more CUDA blocks or more warps per CUDA block.
ggml-cuda.cu
    // reduce the warps sequentially
    for (int sg = 1; sg < num_warps; ++sg) {
        __syncthreads();

        // each simdgroup stores its output to shared memory, reusing sq
        if (warp_id == sg) {
            for (int j = 0; j < Q16; ++j) {
                for (int i = 0; i < D16; ++i) {
                    nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
                }
            }
        }

        __syncthreads();

        // the first simdgroup accumulates the results from the other simdgroups
        if (warp_id == 0) {
            for (int j = lane_id; j < Q; j += NW) {
                const half S0 = ss[j*T + 0];
                const half S1 = ss[j*T + sg*SH + 0];

                const half M0 = ss[j*T + 1];
                const half M1 = ss[j*T + sg*SH + 1];

                const half M = __hmax(M0, M1);

                const half ms0 = hexp(M0 - M);
                const half ms1 = hexp(M1 - M);

                const half S = S0*ms0 + S1*ms1;

                ss[j*T + 0] = S;
                ss[j*T + 1] = M;

                ss[j*T + C + j        ] = ms0;
                ss[j*T + C + j + sg*SH] = ms1;
            }

            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
            for (int j = 0; j < Q16; ++j) {
                half16x16_a ms0;
                half16x16_a ms1;
                half16x16_b t;
                half16x16_acc t2;

                nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T);
                nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);

                for (int i = 0; i < D16; ++i) {
                    nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
                    nvcuda::wmma::mma_sync(t2, ms1, t, zr);

                    // convert accumulator to matrix_b
                    nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
                    nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T);

                    nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
                }
            }
        }
    }

    // store result to shared memory (reuse sq)
    if (warp_id == 0) {
        for (int j = 0; j < Q16; ++j) {
            for (int i = 0; i < D16; ++i) {
                nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
            }
        }
    }

    // final rescale with 1/S and store to global memory
    if (warp_id == 0) {
        for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
            const half S = ss[j*T + 0];

            for (int i0 = 0; i0 < D; i0 += NW) {
                const int i = i0 + lane_id;
                if (i >= D) {
                    break;
                }

                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
            }
        }
    }
My guess based on just static code analysis is that this is the part of the code that is causing issues when you increase the number of warps. But since I am not able to increase the number of warps (and still get correct results) this could be very wrong.
I think it works correctly for up to 8 warps - after that it runs out of shared memory
Okay, but how do I set the number of warps? If I just set e.g. nwarps = 4 on line 10940 I get NaN with
export model_name=llama_2-7b && export quantization=q4_0
./perplexity --n-gpu-layers 99 --model models/nvme/${model_name}-${quantization}.gguf --file wikitext-2-raw/wiki.test.raw --mlock --chunks 1
after that it runs out of shared memory
It's implemented in a convoluted and tedious way but it is possible to raise the shared memory limit for CUDA kernels so that a single block can use 100% of the shared memory of an SM. For example:
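A hedged sketch of that opt-in mechanism (the runtime calls are standard CUDA API; the kernel instantiation and limits are illustrative):

#include <cuda_runtime.h>

// Sketch: opt in to more than the default 48 KiB of dynamic shared memory per
// block, up to the per-SM opt-in limit of the device.
static void raise_smem_limit(int device) {
    int max_optin = 0; // per-block opt-in shared memory limit in bytes
    cudaDeviceGetAttribute(&max_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

    // allow e.g. the <D=80, R=2> instantiation to use the full opt-in amount
    cudaFuncSetAttribute(flash_attn_ext_f16<80, 2>,
                         cudaFuncAttributeMaxDynamicSharedMemorySize, max_optin);
}

// the kernel then has to be launched with that amount of *dynamic* shared memory:
// flash_attn_ext_f16<80, 2><<<blocks_num, block_dim, smem_bytes>>>(...);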
I'm not sure about the state of this branch. It works on my branch:
https://github.com/ggerganov/llama.cpp/tree/gg/flash-attn
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 713a6a89..99925259 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -10932,7 +10932,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
// TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
- const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1;
+ const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 4;
dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
dim3 block_dim(32, nwarps, 1);
LLAMA_CUBLAS=1 make -j && ./perplexity -m ./models/llama-7b-v2/ggml-model-q4_0.gguf -f ./build-cublas/wikitext-2-raw/wiki.test.raw -ngl 99
perplexity: calculating perplexity over 655 chunks, batch_size=512
perplexity: 0.36 seconds per pass - ETA 3.90 minutes
[1]4.3401,[2]4.8393,[3]5.4619,[4]6.0629,[5]6.1965,[6]6.0945,[7]6.2706,^C
Previous work: llama.cpp#778
There was a previous initiative to implement Flash Attention in llama.cpp to improve inference performance. However, it was evaluated on the CPU, where it did not yield the expected results; for that reason it was discarded and no further follow-up was given.
Flash Attention is actually designed to enhance GPU resource utilization through the use of Tensor Cores and shared memory, which is roughly 30 times faster than global memory (VRAM), reducing unnecessary reads and writes.
Implementing this algorithm is particularly challenging because it requires taking hardware limitations into account, as well as conflicts that can degrade performance even when everything appears to be fine (e.g. shared memory bank conflicts).
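As an illustration of the bank-conflict issue mentioned above, here is the classic padding trick, shown as a hedged sketch unrelated to the PR's actual kernels:

// Illustrative only: shared memory has 32 banks of 4 bytes. With a 32x32 tile
// of floats, column accesses from a warp hit the same bank; padding each row
// by one element shifts consecutive rows into different banks.
// Assumes a 32x32 thread block and matrix dimensions divisible by 32.
__global__ void transpose_tile(const float * __restrict__ in, float * __restrict__ out, int n) {
    __shared__ float tile[32][32 + 1]; // +1 column of padding avoids bank conflicts

    int x = blockIdx.x*32 + threadIdx.x;
    int y = blockIdx.y*32 + threadIdx.y;
    tile[threadIdx.y][threadIdx.x] = in[y*n + x];

    __syncthreads();

    x = blockIdx.y*32 + threadIdx.x; // transposed block coordinates
    y = blockIdx.x*32 + threadIdx.y;
    out[y*n + x] = tile[threadIdx.x][threadIdx.y];
}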
Tasks to be carried out during the execution of this project:
Run tests: