Flash Attention V2 #485
I used benchmarks/benchmark_throughput.py to test Flash Attention V2, but it doesn't seem to have any effect. My test steps are as follows,
and the test times are as follows. On further performance analysis, I found that the replaced part (Flash Attention V2) takes very little time and only runs at the beginning of execution. I am confused: for Flash Attention V2, what can we do in vLLM?
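(For reference, a minimal invocation of that benchmark might look like the sketch below; the model name and flag values are placeholders, and exact flag names can differ between vLLM versions.)

```bash
# Sketch of a throughput run (placeholder model and lengths; flag names may
# differ between vLLM releases). A longer --input-len stresses the prompt phase,
# which is the only phase where FlashAttention is used.
python benchmarks/benchmark_throughput.py \
  --model facebook/opt-13b \
  --input-len 1024 \
  --output-len 128 \
  --num-prompts 200
```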
Hi, which version specifically? I don't think FlashAttention V2 support has been released yet, so you would have to install from git. Also, there are still some open PRs to bump xformers to the flash-attn v2.0.4 bugfix release (facebookresearch/xformers#816).
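(For anyone trying that route, a from-source xformers install is roughly the sketch below; it compiles the bundled flash-attn kernels, so it can take a while. This is a generic sketch, not a command confirmed in this thread.)

```bash
# Build xformers from the main branch so it picks up the newer flash-attn kernels.
# Requires a matching PyTorch + CUDA toolkit; ninja speeds up the compile, and
# MAX_JOBS limits parallel compile jobs to avoid running out of RAM.
pip install ninja
MAX_JOBS=4 pip install -v -U "git+https://github.com/facebookresearch/xformers.git@main#egg=xformers"
```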
I tried this as well, and there was no improvement in the benchmarks after switching to flash-attn v2. I will try to profile the benchmark script.
I don't think this one really works here, because another important feature of flash-attn is reducing the high GPU memory usage for very long contexts, e.g. more than 5k tokens.
Hi @nivibilla, thanks for submitting the issue. The latest version of

@tmm1 @Zhuqln To my understanding, the overall speedup depends on your workload. At inference time, FlashAttention is only used for the prompt inputs and never for the decoding inputs. For many workloads the decoding stage takes the majority of the total execution time, so changing to FlashAttention V2 may not give a notable speedup. However, for other workloads like text summarization, where the prompts are very long, I believe computing attention for the prompt inputs will take a significant portion of the execution time, and thus FlashAttention V2 will have a large impact on overall performance.
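(One way to see this in practice is to compare a prompt-heavy run with a decode-heavy run of the throughput benchmark; the sketch below uses illustrative values, and flag names may differ across vLLM versions.)

```bash
# Prompt-heavy: long inputs, short outputs -> prefill dominates, so FlashAttention V2
# should show its largest end-to-end gain here.
python benchmarks/benchmark_throughput.py --model facebook/opt-13b \
  --input-len 1024 --output-len 16 --num-prompts 100

# Decode-heavy: short inputs, long outputs -> decoding dominates, so little or no
# end-to-end gain from FlashAttention V2 is expected.
python benchmarks/benchmark_throughput.py --model facebook/opt-13b \
  --input-len 64 --output-len 512 --num-prompts 100
```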
@WoosukKwon thanks for the explanation!
Hi, this is inaccurate, since the code is still forcing
See benchmark results in facebookresearch/xformers#832.
Thanks for the details @WoosukKwon. I just have a question: why can't FlashAttention be used for the decoding phase?
Its tiling strategy is not optimized for Q with seqlen=1; see Dao-AILab/flash-attention#427 (comment).
I'm delighted to engage in this discussion. Your report has been immensely helpful, but I do have some questions. For instance, I'm curious to know if there's a performance comparison available between trtLLM and vLLM. Such information would be greatly beneficial in guiding my decision on which framework to choose.
You assume that in a summarization task most of the work comes from processing the input. In my experiments, the generation side was much bigger. If you generate only 1-5 tokens, then most of the work is processing the input, the runtime depends on input length, and FlashAttention 2 is advantageous (its memory use is linear in input length, whereas the naive implementation is quadratic). But if you generate a considerable number of tokens, that factor dominates, the input processing becomes negligible, and FlashAttention 2 has little effect here. https://github.com/matanhol/summarization_with_flash_attn_2_simulation
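(To make that scaling argument concrete, here is a rough back-of-the-envelope cost sketch of my own, with L the prompt length and T the number of generated tokens; it ignores constants and the non-attention layers.)

```latex
% Rough attention-cost sketch: prefill attends L prompt tokens to each other,
% while decode step t attends one new token to the L + t tokens seen so far.
\[
  C_{\text{prefill}} \;\propto\; L^{2},
  \qquad
  C_{\text{decode}} \;\propto\; \sum_{t=1}^{T} (L + t) \;=\; LT + \tfrac{T(T+1)}{2}.
\]
% FlashAttention V2 only accelerates the prefill term (and naive attention is
% quadratic, not exponential, in L), so the end-to-end speedup shrinks as T
% grows relative to L.
```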
I tried installing vLLM with flash-attn, but it didn't work. My attempts (install flash attention):
```bash
# my current vllm setup without flash
# pip install --upgrade pip
# pip install torch==2.2.1
# pip install vllm==0.4.1
# flash attn https://amzn-aws.slack.com/archives/C06Q26TNN8G/p1724182667464149
# flash-attn>=2.5.8
# pip install flash-attn
# Collabs's setup with flash
# vllm 0.5.4
# vllm-flash-attn 2.6.1
# flash-attn 2.6.3
# torch 2.4.0
# Python 3.10.8
# try to install flash attn in a new py env
python3.11 -m venv ~/.virtualenvs/flash_attn_test
source ~/.virtualenvs/flash_attn_test/bin/activate
pip install --upgrade pip
pip install -e ~/snap-cluster-setup
pip list | grep vllm
pip list | grep torch
pip list | grep flash-attn
pip list | grep vllm-flash-attn
# # didn't work
# pip install torch==2.2.1
# pip install vllm==0.4.1
# MAX_JOBS=4 pip install flash-attn --no-build-isolation --force
# this installed flash but vllm didn't say in it's output it was using it
pip install torch==2.4.0
pip install vllm==0.5.4
pip install flash-attn==2.6.3
pip install vllm-flash-attn==2.6.1
python ~/snap-cluster-setup/py_src/evals/boxed_acc_eval.py --model internlm/internlm2_5-1_8b --hf_gen_type vllm --path_2_eval_dataset ~/snap-cluster-setup/data/MATH/test --max_tokens 2048 --batch_size 100 --end 100 -n 1 --shuffle True --mode dryrun 2>&1 | tee $LOG_FILE && echo "Log file created at: $LOG_FILE"
# later try with py 3.10
# python3xxx -m venv ~/.virtualenvs/flash_attn_test_py10
# source ~/.virtualenvs/flash_attn_test_py10/bin/activate
# pip install --upgrade pip
# pip install -e ~/snap-cluster-setup
# pip install torch==2.4.0
# pip install vllm==0.5.4
# pip install flash-attn==2.6.3
# pip install vllm-flash-attn==2.6.1
```
My setup is Python 3.11; that is what I really want/need.
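(For what it's worth, as far as I understand, recent vLLM releases ship their FlashAttention kernels via the vllm-flash-attn package rather than the standalone flash-attn wheel, so installing flash-attn separately may not change which backend vLLM uses. A rough way to check what vLLM actually selected is sketched below; the environment variable and log wording may vary by version.)

```bash
# Ask vLLM to prefer the FlashAttention backend, then look for the backend line
# it prints at startup (small placeholder model; log wording varies by version).
export VLLM_ATTENTION_BACKEND=FLASH_ATTN
python -c "from vllm import LLM; LLM(model='facebook/opt-125m')" 2>&1 | grep -i backend
```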
Related general vLLM issue for this vLLM version: #2747
https://github.com/Dao-AILab/flash-attention
Flash Attention V2 was released claiming 2x speedups. Making an issue to remind myself to have a look at it, and also in case anyone else wants to try implementing it.