
[V1] Optimize the CPU overheads in FlashAttention custom op #10733

Merged: 1 commit merged into main on Nov 28, 2024

Conversation

@WoosukKwon (Collaborator) commented Nov 28, 2024

With piece-wise CUDA graphs, we have to make sure that the attention custom op incurs minimal CPU overhead. This PR makes two changes to reduce the CPU overhead of the FlashAttention custom op (a rough sketch of both changes follows the list):

  1. We call torch.ops.vllm_flash_attn_c.varlen_fwd directly rather than flash_attn_varlen_func, since FlashAttnFunc, which inherits from torch.autograd.Function, adds unnecessary overhead.
  2. We move the reshapes and shape checks outside of the custom op, so that they are done at CUDA graph capture time.
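
A rough sketch of what the two changes look like, for illustration only and not the actual vLLM code: the raw torch.ops.vllm_flash_attn_c.varlen_fwd binding is stood in for by a plain Python function wrapping torch.nn.functional.scaled_dot_product_attention so the snippet runs anywhere, and the variable-length metadata is omitted.

import torch
import torch.nn.functional as F

def varlen_fwd_stub(q, k, v, out):
    # Stand-in for the raw kernel binding. Calling the binding directly skips
    # FlashAttnFunc, the torch.autograd.Function wrapper behind
    # flash_attn_varlen_func, which only adds CPU overhead for inference.
    out.copy_(F.scaled_dot_product_attention(q, k, v))

def attention_forward(query, key, value, num_heads, head_size):
    # Reshapes and shape checks happen out here, outside the custom op, so with
    # piece-wise CUDA graphs they run once at capture time instead of being
    # replayed as Python/CPU work on every step.
    assert query.shape == key.shape == value.shape
    q = query.view(1, -1, num_heads, head_size).transpose(1, 2)
    k = key.view(1, -1, num_heads, head_size).transpose(1, 2)
    v = value.view(1, -1, num_heads, head_size).transpose(1, 2)
    out = torch.empty_like(q)
    varlen_fwd_stub(q, k, v, out)  # the custom-op body itself stays minimal
    return out.transpose(1, 2).reshape(query.shape)

num_heads, head_size, num_tokens = 8, 64, 16
x = torch.randn(num_tokens, num_heads * head_size)
print(attention_forward(x, x, x, num_heads, head_size).shape)  # torch.Size([16, 512])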

Results of python benchmarks/benchmark_latency.py (opt-125m) on a single H100 GPU:

  • V1 main: 227 ms
  • V1 this PR: 192 ms
  • V0 + 8-step: 130 ms

Next step: further reduce the unnecessary CPU ops inside the FlashAttention op.


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 28, 2024
@youkaichao (Member) commented Nov 28, 2024

Well, I think I forgot to update the v1 flash attention file. After #10558, you don't need the torch.ops.vllm.unified_v1_flash_attention call.

nvm

@@ -203,23 +209,31 @@ def unified_v1_flash_attention(
v_scale,
)

attn_output = flash_attn_varlen_func(
Member (review comment on the diff above):

can you also update the corresponding v0 code?

@tlrmchlsmth (Collaborator) left a comment

Looking at the profile results on #9856, this saves about 60 µs of CPU time per flash-attention call (approx. 300 µs -> 240 µs)

Thanks!

@mgoin (Member) left a comment

LGTM with Kaichao's comment, thanks for quickly improving this. The failing test is due to neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16 and is unrelated.

@WoosukKwon (Collaborator, Author)

@youkaichao @mgoin As we merged vllm-project/flash-attention#30, we don't have to use torch.ops.vllm_flash_attn_c.varlen_fwd directly. We can just use flash_attn_varlen_func as we currently do. Both V0 and V1 already get the benefits after vllm-project/flash-attention#30.

@WoosukKwon (Collaborator, Author)

One weird phenomenon I found is that V1 has a spike in latency:

Avg latency: 0.20093455887205589 seconds
10% percentile latency: 0.1931818482640665 seconds
25% percentile latency: 0.19354040725738741 seconds
50% percentile latency: 0.19391279752017 seconds
75% percentile latency: 0.19426249974640086 seconds
90% percentile latency: 0.1961068181961309 seconds
99% percentile latency: 0.3368887884780999 seconds

This is highly reproducible on my dev machine. Can this be because of Python gc or something like that?
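
(For reference, one quick way to rule Python GC in or out would be to disable automatic collection around the timed loop, roughly as sketched below; as it turns out further down, prefix caching rather than GC was the cause.)

import gc

gc.disable()       # no automatic collections while measuring
try:
    pass           # ... run the benchmark iterations here ...
finally:
    gc.enable()
    gc.collect()   # one manual collection after the measurement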

@WoosukKwon WoosukKwon merged commit 98f47f2 into main Nov 28, 2024
15 of 18 checks passed
@WoosukKwon WoosukKwon deleted the v1-flash-opt branch November 28, 2024 17:01
@robertgshaw2-redhat (Collaborator)

(Quoting the latency numbers and the question above.)

It’s probably the prefix caching …

@comaniac (Collaborator)

Hmm, but benchmark_latency.py does sample each prompt separately: https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_latency.py#L36

@comaniac (Collaborator) commented Nov 29, 2024

Just found that it has a warmup phase. Prefix caching is still a possible cause if all prompts are cached by then. I suggest explicitly disabling prefix caching to double-check.
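
One way to do that, as a sketch rather than the exact benchmark invocation used here (it assumes the enable_prefix_caching engine argument, whose CLI spelling may vary across vLLM versions):

from vllm import LLM, SamplingParams

# Turn prefix caching off explicitly so repeated warmup/benchmark prompts
# cannot be served from cached prefixes.
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=False)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))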

@WoosukKwon (Collaborator, Author)

@comaniac @robertgshaw2-neuralmagic You're right. The latency becomes stable when prefix caching is turned off.

Avg latency: 0.1945609479948568 seconds
10% percentile latency: 0.19310778125654907 seconds
25% percentile latency: 0.19390572598786093 seconds
50% percentile latency: 0.19475348049309105 seconds
75% percentile latency: 0.195164829317946 seconds
90% percentile latency: 0.19570096801035106 seconds
99% percentile latency: 0.1962820820847992 seconds

afeldman-nm pushed a commit to neuralmagic/vllm that referenced this pull request Dec 2, 2024
sleepwalker2017 pushed a commit to sleepwalker2017/vllm that referenced this pull request Dec 13, 2024
BKitor pushed a commit to BKitor/vllm that referenced this pull request Dec 30, 2024