From 7214d0617b9aba77e44b9fc4a81a28f64be10a42 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 27 Oct 2024 13:41:46 -0700
Subject: [PATCH 1/4] use out for flash attention

Signed-off-by: youkaichao
---
 tests/kernels/test_flash_attn.py      | 10 ++++--
 vllm/attention/backends/flash_attn.py | 50 ++++++++++++---------------
 2 files changed, 30 insertions(+), 30 deletions(-)

diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py
index 35c29c5bd1028..b468e7bd293c5 100644
--- a/tests/kernels/test_flash_attn.py
+++ b/tests/kernels/test_flash_attn.py
@@ -116,7 +116,8 @@ def test_flash_attn_with_paged_kv(
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)
 
-    output = flash_attn_with_kvcache(
+    output = torch.empty_like(query)
+    flash_attn_with_kvcache(
         q=query.unsqueeze(1),
         k_cache=key_cache,
         v_cache=value_cache,
@@ -126,7 +127,8 @@ def test_flash_attn_with_paged_kv(
         cache_seqlens=kv_lens_tensor,
         softcap=soft_cap if soft_cap is not None else 0,
         window_size=window_size,
-    ).squeeze(1)
+        out=output,
+    )
 
     ref_output = ref_paged_attn(query=query,
                                 key_cache=key_cache,
@@ -197,7 +199,8 @@ def test_varlen_with_paged_kv(
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)
 
-    output = flash_attn_varlen_func(
+    output = torch.empty_like(query)
+    flash_attn_varlen_func(
         q=query,
         k=key_cache,
         v=value_cache,
@@ -210,6 +213,7 @@ def test_varlen_with_paged_kv(
         window_size=window_size,
         block_table=block_tables,
         softcap=soft_cap if soft_cap is not None else 0,
+        out=output,
     )
 
     ref_output = ref_paged_attn(
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index ffa05e80623ac..8d57f2a447210 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -575,7 +575,9 @@ def forward(
         assert k_scale == 1.0 and v_scale == 1.0, (
             "key/v_scale is not supported in FlashAttention.")
 
-        output = torch.ops.vllm.unified_flash_attention(
+        output = torch.empty_like(query)
+
+        torch.ops.vllm.unified_flash_attention(
             query,
             key,
             value,
@@ -590,17 +592,19 @@ def forward(
             self.sliding_window,
             self.alibi_slopes,
             self.logits_soft_cap,
+            output=output,
         )
 
         return output
 
 
 @torch.library.custom_op("vllm::unified_flash_attention",
-                         mutates_args=["kv_cache"])
+                         mutates_args=["kv_cache", "output"])
 def unified_flash_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    output: torch.Tensor,
     num_heads: int,
     head_size: int,
     num_kv_heads: int,
@@ -612,7 +616,7 @@ def unified_flash_attention(
     window_size: Optional[List[int]] = None,
     alibi_slopes: Optional[torch.Tensor] = None,
     logits_soft_cap: Optional[float] = None,
-) -> torch.Tensor:
+):
 
     current_metadata = get_forward_context()
     assert current_metadata is not None
@@ -625,6 +629,8 @@ def unified_flash_attention(
     key = key.view(-1, num_kv_heads, head_size)
     value = value.view(-1, num_kv_heads, head_size)
 
+    output = output.view(-1, num_heads, head_size)
+
     if kv_cache.numel() > 0:
         key_cache = kv_cache[0]
         value_cache = kv_cache[1]
@@ -652,17 +658,16 @@ def unified_flash_attention(
 
     # Query for decode. KV is not needed because it is already cached.
     decode_query = query[num_prefill_tokens:]
+    decode_output = output[num_prefill_tokens:]
     # QKV for prefill.
     query = query[:num_prefill_tokens]
+    prefill_output = output[:num_prefill_tokens]
     key = key[:num_prefill_tokens]
     value = value[:num_prefill_tokens]
 
     assert query.shape[0] == num_prefill_tokens
     assert decode_query.shape[0] == num_decode_tokens
 
-    prefill_output: Optional[torch.Tensor] = None
-    decode_output: Optional[torch.Tensor] = None
-
     if prefill_meta := attn_metadata.prefill_metadata:
         # Prompt run.
         if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
@@ -670,7 +675,7 @@ def unified_flash_attention(
             # normal attention
             # When block_tables are not filled, it means q and k are the
             # prompt, and they have the same length.
-            prefill_output = flash_attn_varlen_func(
+            flash_attn_varlen_func(
                 q=query,
                 k=key,
                 v=value,
@@ -683,12 +688,13 @@ def unified_flash_attention(
                 window_size=window_size,
                 alibi_slopes=alibi_slopes,
                 softcap=logits_soft_cap,
+                out=prefill_output,
             )
         else:
             # prefix-enabled attention
             assert prefill_meta.seq_lens is not None
             max_seq_len = max(prefill_meta.seq_lens)
-            prefill_output = flash_attn_varlen_func(  # noqa
+            flash_attn_varlen_func(  # noqa
                 q=query,
                 k=key_cache,
                 v=value_cache,
@@ -702,6 +708,7 @@ def unified_flash_attention(
                 alibi_slopes=alibi_slopes,
                 block_table=prefill_meta.block_tables,
                 softcap=logits_soft_cap,
+                out=prefill_output,
             )
 
     if decode_meta := attn_metadata.decode_metadata:
@@ -710,7 +717,7 @@ def unified_flash_attention(
         # because different queries might have different lengths.
         assert decode_meta.max_decode_query_len is not None
         if decode_meta.max_decode_query_len > 1:
-            decode_output = flash_attn_varlen_func(
+            flash_attn_varlen_func(
                 q=decode_query,
                 k=key_cache,
                 v=value_cache,
@@ -724,10 +731,11 @@ def unified_flash_attention(
                 alibi_slopes=alibi_slopes,
                 softcap=logits_soft_cap,
                 block_table=decode_meta.block_tables,
+                out=decode_output,
             )
         else:
             # Use flash_attn_with_kvcache for normal decoding.
-            decode_output = flash_attn_with_kvcache(
+            flash_attn_with_kvcache(
                 q=decode_query.unsqueeze(1),
                 k_cache=key_cache,
                 v_cache=value_cache,
@@ -738,21 +746,8 @@ def unified_flash_attention(
                 window_size=window_size,
                 alibi_slopes=alibi_slopes,
                 softcap=logits_soft_cap,
-            ).squeeze(1)
-
-    if prefill_output is None:
-        assert decode_output is not None
-        return decode_output.view(num_decode_tokens, hidden_size)
-    if decode_output is None:
-        assert prefill_output is not None
-        return prefill_output.view(num_prefill_tokens, hidden_size)
-
-    # Chunked prefill does not work with speculative decoding.
-    # Therefore, the query length for decode should be 1 in chunked prefill.
-    assert decode_meta is not None
-    decode_output = decode_output.squeeze(1)
-    output = torch.cat([prefill_output, decode_output], dim=0)
-    return output.view(num_tokens, hidden_size)
+                out=decode_output.unsqueeze(1),
+            )
 
 
 @unified_flash_attention.register_fake
@@ -760,6 +755,7 @@ def _(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
+    output: torch.Tensor,
     num_heads: int,
     head_size: int,
     num_kv_heads: int,
@@ -771,5 +767,5 @@ def _(
     window_size: Optional[List[int]] = None,
     alibi_slopes: Optional[torch.Tensor] = None,
     logits_soft_cap: Optional[float] = None,
-) -> torch.Tensor:
-    return torch.empty_like(query)
+):
+    return

From 315ea6f9ea4c915d5c65c756126349d7fe10d516 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 27 Oct 2024 13:43:28 -0700
Subject: [PATCH 2/4] fix

Signed-off-by: youkaichao
---
 tests/kernels/test_flash_attn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py
index b468e7bd293c5..8f70ad686e5a5 100644
--- a/tests/kernels/test_flash_attn.py
+++ b/tests/kernels/test_flash_attn.py
@@ -127,7 +127,7 @@ def test_flash_attn_with_paged_kv(
         cache_seqlens=kv_lens_tensor,
         softcap=soft_cap if soft_cap is not None else 0,
         window_size=window_size,
-        out=output,
+        out=output.unsqueeze(1),
     )
 
     ref_output = ref_paged_attn(query=query,

From e3687f203f56dfa10aa3c6bc4755380f19f24022 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 27 Oct 2024 13:46:38 -0700
Subject: [PATCH 3/4] pos arg

Signed-off-by: youkaichao
---
 vllm/attention/backends/flash_attn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 8d57f2a447210..523c77d77bed8 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -581,6 +581,7 @@ def forward(
             query,
             key,
             value,
+            output,
             self.num_heads,
             self.head_size,
             self.num_kv_heads,
@@ -592,7 +593,6 @@ def forward(
             self.sliding_window,
             self.alibi_slopes,
             self.logits_soft_cap,
-            output=output,
         )
 
         return output

From 969225a355b5dd2088880ed48b7193c74a6dd606 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 27 Oct 2024 13:51:13 -0700
Subject: [PATCH 4/4] schema

Signed-off-by: youkaichao
---
 vllm/attention/backends/flash_attn.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 523c77d77bed8..940da242135e0 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -616,7 +616,7 @@ def unified_flash_attention(
     window_size: Optional[List[int]] = None,
     alibi_slopes: Optional[torch.Tensor] = None,
     logits_soft_cap: Optional[float] = None,
-):
+) -> None:
 
     current_metadata = get_forward_context()
     assert current_metadata is not None
@@ -767,5 +767,5 @@ def _(
     window_size: Optional[List[int]] = None,
     alibi_slopes: Optional[torch.Tensor] = None,
     logits_soft_cap: Optional[float] = None,
-):
+) -> None:
     return
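A note on the pattern the series converges on: the caller allocates the result once with torch.empty_like(query), hands that buffer to the custom op (made a positional argument in PATCH 3/4), the op lists the buffer in mutates_args, and the registered fake implementation returns nothing because the op no longer allocates or returns a tensor. Below is a minimal, self-contained sketch of the same out-parameter idiom with torch.library.custom_op; the op name demo::scaled_copy and its toy body are illustrative stand-ins, not vLLM or flash-attn code.

import torch


@torch.library.custom_op("demo::scaled_copy", mutates_args=["out"])
def scaled_copy(x: torch.Tensor, scale: float, out: torch.Tensor) -> None:
    # Write into the caller-provided buffer instead of allocating and
    # returning a new tensor (the role ``out=...`` plays in the patches above).
    out.copy_(x * scale)


@scaled_copy.register_fake
def _(x: torch.Tensor, scale: float, out: torch.Tensor) -> None:
    # The op mutates ``out`` and returns None, so the fake (meta)
    # implementation has nothing to allocate and simply returns.
    return


x = torch.randn(4, 8)
out = torch.empty_like(x)  # caller preallocates, as forward() does for `output`
torch.ops.demo.scaled_copy(x, 2.0, out)
assert torch.allclose(out, 2.0 * x)

Listing the buffer in mutates_args tells the library that the op writes to it in place, and annotating the op to return None (which PATCH 4/4 does) keeps the inferred schema consistent with that.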