-
Notifications
You must be signed in to change notification settings - Fork 9.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ggml : add Flash Attention #5021
Conversation
Since we are doing this from scratch, wouldn't it be better to remove the custom attention mask entirely and pass a list of KV cells used in each sequence? Considering our implementation of batching, I think we should be looking at implementing something closer to paged attention rather than flash attention. I suppose it is possible to convert the mask to a list of sequences in the kernels, but it would be less efficient. |
Yes, we can pass list instead of mask. I am not sure of the format though - if each list has different length I feel it will hinder the GPU performance. Edit: I just got an idea - we can pass both the |
We could use a vector with dimension |
It seems that vLLM has added a new version of paged attention since it looked into the implementation (vllm-project/vllm#1348). I am not sure what are the changes, but I think it is worth looking into what they are doing. The kernel is in https://github.com/vllm-project/vllm/blob/main/csrc/attention/attention_kernels.cu |
Alibi could also be done in this kernel. |
Regarding the Alibi, I feel reinterpreting it as a It remains to be seen though if the Will take a look at the vLLM code and I've updated the description with some of the things from this discussion |
@ggerganov @slaren Together with @JohannesGaessler and @FSSRepo we are working on the same thing over at Pints-AI#1 which we intend to do a pull to llamacpp once work is done. However, I think we will converge into this one. Given the amount of work here, @ggerganov @slaren how do you want to organise this? The 3 of us are in a temporary discord group actually to work this out, perhaps we can use that? What are your thoughts? |
Discord is not an option for me - I prefer to communicate over Github issues / discussions / e-mail. Happy to see you have started work on the CUDA implementation. Please take into account the proposed API here - note that it is still a WIP and can change. I can review the implementation that you have when you think it is in a good state. Would prefer PR's that are compatible with this branch so we can verify correctness using |
@ggerganov Got it. Let us work on a plan to converge with this PR. |
|
Any performance numbers? |
e0ba0da
to
52ae085
Compare
With default settings on a RTX 4090 , + |
Am I missing something ? It only increases t/s ? Right ? Not VRAM usage per ctx size ? |
KQ doesn't need to be materialized in global memory with flash attention, and with large contexts that was often the biggest tensor in the compute buffer. So it should reduce the size of the compute buffer substantially with large contexts. |
@ggerganov F-AT is not enabled ROCM in general,right? |
Edit: actually now seemingly getting same results for all (jittery) llama-bench output for various -t values./llama-bench -m models/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf -fa 1,0 -p 512 -t 1,2,3,4,5,6,7,8
|
Metal and other GPU backends with full offload only uses one thread, however in Metal the number of threads is also used as the number of command buffers. |
@slaren ah ok, thanks for the explanation! I'm not seeing any effect of |
It's currently disabled, yes. |
Sadly I'm not seeing any benefit from this. No reduction in VRAM usage, no speedup, even when fully offloading. Infact, I'm only seeing slower speeds when using partial offloading. |
For me (Windows, CUDA, 24GB VRAM) the difference is definitely there, but it depends on the model and I have best results with a large amount of context data. The most pronounced for me is Mixtral-8x7B-Instruct-v0.1-requant-imat-IQ3_XS which I can fully offload. It Edit: I saw the below "old timings" across at least 4x runs each last night, but today w/o FA is hitting close to 39-40 t/s, so must have been an edge case there, but FA seemed to help with it. With FA:
Without FA: (updated)
old w/o timings
Other models are less remarkable, but I'm able to store a lot more context. New tests: Llamabench with -p 512,1024 is less dramatic but measurable, TG ~46 -> ~50:
The differences are more obvious at -p 8096, 16192, 32384: From PP 819 -> 1005 @ 16K, and OOM -> 879 @ 32K.
|
Performance on Macbook Air M2, 24GB using latest llama.cpp, before and after using the Without Flash Attention:
With Flash Attention:
TL;DR: Generation speed increases from 8.70 t/s to 9.69 t/s, memory usage decreases slightly, prompt processing is not tested in this case. |
Hi is server has flash attention yet ? Or is it automatically using flash attention ? edit: just add -fa too in server got it |
Hi, I am having issues building this on CUDA 11.4 now after this PR. Notably, I am getting This is not the first time this has happened, previously we added |
@LostRuins can you check whether this fix #7019 works? |
It seems this only applies to a low context like 4K. Testing a very small LLM on my system with a context size of 13.000 Tokens and no GQA, the difference is massive. VRAM savings from 2.8 to 1.2 GB, Text Generation from 37 to 71 token/s, pp from 1300 token/s to 2300 token/s. Great work! |
From the dialogue above, I think I understand that the support for -fa needs to be coded per backend. Can someone confirm that? Not having much luck using -fa for the vulkan backend. I do not expect said support to materialize either, just want to clarify. |
It does need to be implemented per backend. |
* ggml : add ggml_flash_attn_ext API * ggml : fix GQA support in ggml_flash_attn_ext * ggml : online attention (CPU) * metal : initial implementation * metal : f16 precision * metal : reduce branches * metal : specialize for head size * wip : 8 rows per simd group * wip : 4 rows per simd group * wip : template for rows per warp * metal : parallelize across KV size * metal : parallel reduce across heads * metal : efficient flash_attn_f16 implementation * metal : avoid redundant loads of the attention * metal : scale and mask in matrix form * metal : fix comment * llama : avoid ggml_cast, use F32 query * metal : add parallel reduce version (disabled) * metal : move output into local memory + optimize - the result from each simdgroup now stays in the registers - significantly reduced SRAM usage - more efficient skipping of -INF blocks - avoid simdgroup barrier in hot loop - add comments * metal : add tests, fix scaling, support C > 32 * metal : improve precision * ggml : fix f16 mad * metal : minor * metal : support Q > 8 * tests : add ATTN tests * metal : disable buffer allocation logs * tests : more * metal : faster inner loop for C == 32 * metal : fix array initialization * tests : ifdef * ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext * ggml : fix ggml_soft_max mask requirement * cuda : fix soft_max to use correct mask size * cuda : add flash_attn kernel (wip) * metal : optimize softmax for C > 32 * metal : optimize softmax * tests : minor fix * cuda : avoid zeroing fragments * tests : update dims * cuda : fix __hisinf() result check * cuda : avoid warp_reduce for smax * cuda : use int instead of int64_t Noticeably improves performance (thanks to Johannes) * cuda : make loops use the same loop values Thanks Johannes again for the tip * cuda : unroll some of the loops * cuda : avoid __hisinf branches * cuda : use half2 in softmax * cuda : switch to 1 warp for bs > 16 * cuda : speed-up reduce part of the kernel * cuda : unroll Q*K^T loop * cuda : fix -INF block check * cuda : simplify softmax * cuda : fix matrix names * cuda : minor * llama : adapt to F16 KQ_pos * llama : adapt new models to F16 KQ_mask * ggml : fix F16 store (ARM NEON) * llama : fix type of KQ_mask and KQ_pos * ggml : fix CPU soft_max * tests : add hs=256 * cuda : fix build * metal : improve perf via smaller int registers * cuda : adapt soft_max to F16 mask and pos * CUDA: faster FlashAttention, kernel for bs == 1 * 16 cols for Phi-2 * no vec for hs, no hs==256 ncols==32 for Volta * adjust kernel selection logic * 4 warps, 256 stride for all D * no ncols == 64 * Multiple parallel blocks for batch size 1 * fix compile warnings * fix excessive KQ_b loads * fix cmake build * fix KV cache padding, NaN from INFINITY (ggerganov#6438) * llama : flash_attn cparam + fix defrag * server: support flash_attn param * server: bench: enable flash_attn param * CUDA: refactor host code, dyn. par. blocks * fix flash_attn_vec_f16 race condition * flush softmax exp below threshold to 0 * store temp KQ in registers * Calculate KQ as FP32 if KQV has GGML_PREC_F32 * Add __hgt2_mask implementation for CUDA 11 * fix KQ FP32 precision fpr parallel_blocks > 1 * llama-bench : add -fa,--flash-attn arg * metal : add BS=1 kernel for flash attention (ggerganov#6508) * metal : add BS=1 kernel for flash attention (wip) * metal : support more than 1 warps * metal : opts * metal : opt * metal : switch to parallel reduce * metal : reduce registers * metal : simplify * metal : initial FA vec kernel * metal : use F32 attention accumulators * batched-bench : add fattn arg * llama : simplify llama_build_kv_store ggml-ci * llama : adapt build_olmo to changes * ggml : fix arm fp16 store on windows * metal : clean-up * metal : clean-up kernel code * metal : minor * tests : remove benchmarks ggml-ci * ggml : fix avx512 const correctness ggml-ci * ggml : fix soft_max with bias on CPU ggml-ci * common : print --flash-attn in help * ggml : fix num dimensions in ggml_flash_attn_ext * llama : force disable flash attention for incompatible models * ggml : ggml_soft_max support F16/F32 mask/pos ggml-ci * cuda : uint -> uint32_t * cuda : "constexpr dim3" -> "const dim3" ggml-ci * cuda : try to fix __hgt2_mask ggml-ci * ggml : add TODO's for F16/F32 mask/pos support in other backends * llama : replace bool need_kq_pos with use_alibi * llama : prep ALiBi support for BERT models ggml-ci * llama : fix n_batch requirements ggml-ci * cont * server : add help for --flash-attn arg * llama : disable FA for AMD * tests : remove TMP_ATTN_BENCH ggml-ci * llama : support save/load state with FA enabled ggml-ci * ci : add CUDA save-load-state tests ggml-ci * llama : llama_kv_cache_clear zeroes data + fix save-load seq ggml-ci * llama : fix copy-paste errors, add TODO * llama : disallow incompatible states * llama : update llama_state_get_size after v_trans field * metal : remove tmp log * llama : add static reminder for llama_state_get_size * metal : fix max nsg ggml-ci * ci : fix arg order ggml-ci --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de> Co-authored-by: Pierrick HYMBERT <pierrick.hymbert@gmail.com>
Why metal test |
@LukeLIN-web because you're compiling with LLAMA_CUBLAS (which is deprecated by the way, use LLAMA_CUDA). You can't use CUDA on a MacBook |
Any updates on context shift compability? |
ref #3365
Setting up what's needed for Flash Attention support in
ggml
andllama.cpp
The proposed operator performs:
Suggestions and comments for the API are welcome.
Looking for help in implementing efficient GPU kernels - please open PR to this branch if you have proposals
ggml
API:ggml_flash_attn_ext()
llama.cpp
use inllm_build_kqv()
test-backend-ops
testGGML_PREC_F32
support (CUDA) (CUDA: faster FlashAttention for batch sizes > 1 #6646)GGML_PREC_F32
support (Metal)Changes to
ggml
/llama
GGML_OP_FLASH_ATTN_EXT
andggml_flash_attn_ext()
call(before merging we can consider reusing the old
GGML_OP_FLASH_ATTN
and removing the legacy code)mask
type to F16 forggml_soft_max_ext()
and require that it is padded toGGML_KQ_MASK_PAD 32
n_kv
denoting the number of computed tokens from the KV cache is now padded to 128 (from 32) to support larger FA blocks without making out-of-bounds accessllama_context_params.n_batch
that can be used isGGML_KQ_MASK_PAD 32
to avoid out-of-bounds access in the FA kernels for small batch sizeV
tensor is no longer transposed when storing it in the KV cacheThings to consider
ggml_add()
? (low-prio)Testing
main
,server
: add-fa
llama-bench
: add-fa 1
Benchmark
Baseline:
FA kernel:
Text-generation after long prompt:
References