[V1] Optimize the CPU overheads in FlashAttention custom op #10733
Conversation
@@ -203,23 +209,31 @@ def unified_v1_flash_attention(
        v_scale,
    )

    attn_output = flash_attn_varlen_func(
can you also update the corresponding v0 code?
Looking at profile results on #9856, this saves about 60µs off of the CPU time spent in each flash attention call (approx 300µs -> 240µs)
Thanks!
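For reference, below is a minimal sketch of how a per-call CPU cost like the one quoted above could be estimated with torch.profiler. This is an assumption about the methodology, not the exact profiling setup behind the numbers in #9856, and `run_attention_once` is a hypothetical stand-in for one pass through the attention op.

```python
# Sketch: estimate CPU time per call of an op with torch.profiler.
# `run_attention_once` is a placeholder workload, not vLLM code.
import torch
from torch.profiler import ProfilerActivity, profile, record_function


def run_attention_once(q, k, v):
    # Placeholder for the real attention call (e.g. the FlashAttention custom op).
    return torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v


q = k = v = torch.randn(8, 128, 64)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(100):
        with record_function("attention_call"):
            run_attention_once(q, k, v)

# Dividing the aggregated CPU time of "attention_call" by the number of
# iterations gives an approximate per-call CPU cost.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```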
LGTM with Kaichao's comment, thanks for quickly improving this. The failing test is due to neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16 and is unrelated.
Force-pushed from e4f8b06 to 456980b
@youkaichao @mgoin As we merged vllm-project/flash-attention#30, we don't have to directly use torch.ops.vllm_flash_attn_c.varlen_fwd anymore.
One weird phenomenon I found is that V1 has a spike in latency:
This is highly reproducible on my dev machine. Can this be because of Python GC or something like that?
It’s probably the prefix caching …
Hmm, but benchmark_latency.py does sample each prompt separately: https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_latency.py#L36
Just found that it has a warmup phase. It could still be due to prefix caching if all prompts are cached by then. I suggest explicitly disabling prefix caching to double-check.
@comaniac @robertgshaw2-neuralmagic You're right. The latency becomes stable when prefix caching is turned off.
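A minimal sketch of that double-check, assuming vLLM's offline `LLM` API and its `enable_prefix_caching` engine argument (the argument name is an assumption here; the benchmark script may expose an equivalent option):

```python
# Sketch: repeatedly time the same request with prefix caching disabled
# to see whether the latency spike disappears.
import time

from vllm import LLM, SamplingParams

# enable_prefix_caching=False turns off the suspected source of the spike.
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=False)
params = SamplingParams(temperature=0.0, max_tokens=128)
prompt = "Hello, my name is"

for i in range(5):
    start = time.perf_counter()
    llm.generate([prompt], params)
    print(f"iteration {i}: {time.perf_counter() - start:.3f}s")
```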
With piece-wise CUDA graphs, we have to make sure that the attention custom op causes minimal CPU overheads. This PR made a few changes to optimize the CPU overheads in the FlashAttention custom op:

We directly use `torch.ops.vllm_flash_attn_c.varlen_fwd` rather than `flash_attn_varlen_func`, since `FlashAttnFunc`, which inherits from `torch.autograd.Function`, causes unnecessary overheads.

Results of `python benchmarks/benchmark_latency.py` (opt-125m) on a single H100 GPU:

Next step: further reduce the unnecessary CPU ops inside the FlashAttention op.
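The description attributes the saved CPU time to bypassing the `torch.autograd.Function` wrapper. The toy benchmark below (not vLLM code; `WrappedOp` and the ReLU workload are stand-ins, and the real `varlen_fwd` argument list is intentionally not reproduced) sketches why calling the underlying computation directly can shave CPU time off each inference call: the `Function.apply` path does per-call autograd bookkeeping even when no gradients are needed.

```python
# Toy illustration of wrapper overhead: torch.autograd.Function.apply vs. a
# direct call to the underlying op, timed under no_grad (inference-style).
import time

import torch


class WrappedOp(torch.autograd.Function):
    """Stand-in for a FlashAttnFunc-style wrapper around a forward kernel."""

    @staticmethod
    def forward(ctx, x):
        return torch.nn.functional.relu(x)

    @staticmethod
    def backward(ctx, grad_out):
        raise NotImplementedError  # never reached in this inference-only sketch


def bench(fn, x, iters=10_000):
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    return (time.perf_counter() - start) / iters * 1e6  # microseconds per call


if __name__ == "__main__":
    x = torch.randn(64, 64)
    with torch.no_grad():
        wrapped_us = bench(lambda t: WrappedOp.apply(t), x)
        direct_us = bench(torch.nn.functional.relu, x)
    print(f"via autograd.Function: {wrapped_us:.1f} us/call")
    print(f"direct call:           {direct_us:.1f} us/call")
```

In the PR itself, the same idea is applied by calling the registered custom op `torch.ops.vllm_flash_attn_c.varlen_fwd` directly instead of going through `flash_attn_varlen_func`.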