Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[misc] use out argument for flash attention #9740

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions tests/kernels/test_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def test_flash_attn_with_paged_kv(
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = flash_attn_with_kvcache(
output = torch.empty_like(query)
flash_attn_with_kvcache(
q=query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
Expand All @@ -126,7 +127,8 @@ def test_flash_attn_with_paged_kv(
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
window_size=window_size,
).squeeze(1)
out=output.unsqueeze(1),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion - do the unsqueeze during creation (either by passing in the modified shape or just .empty_like(...).unsqueeze(1). I think that will be cleaner

)

ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
Expand Down Expand Up @@ -197,7 +199,8 @@ def test_varlen_with_paged_kv(
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = flash_attn_varlen_func(
output = torch.empty_like(query)
flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
Expand All @@ -210,6 +213,7 @@ def test_varlen_with_paged_kv(
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
out=output,
)

ref_output = ref_paged_attn(
Expand Down
50 changes: 23 additions & 27 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,10 +575,13 @@ def forward(
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")

output = torch.ops.vllm.unified_flash_attention(
output = torch.empty_like(query)

torch.ops.vllm.unified_flash_attention(
query,
key,
value,
output,
self.num_heads,
self.head_size,
self.num_kv_heads,
Expand All @@ -596,11 +599,12 @@ def forward(


@torch.library.custom_op("vllm::unified_flash_attention",
mutates_args=["kv_cache"])
mutates_args=["kv_cache", "output"])
def unified_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
Expand All @@ -612,7 +616,7 @@ def unified_flash_attention(
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
) -> None:

current_metadata = get_forward_context()
assert current_metadata is not None
Expand All @@ -625,6 +629,8 @@ def unified_flash_attention(
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)

output = output.view(-1, num_heads, head_size)

if kv_cache.numel() > 0:
key_cache = kv_cache[0]
value_cache = kv_cache[1]
Expand Down Expand Up @@ -652,25 +658,24 @@ def unified_flash_attention(

# Query for decode. KV is not needed because it is already cached.
decode_query = query[num_prefill_tokens:]
decode_output = output[num_prefill_tokens:]
# QKV for prefill.
query = query[:num_prefill_tokens]
prefill_output = output[:num_prefill_tokens]
key = key[:num_prefill_tokens]
value = value[:num_prefill_tokens]

assert query.shape[0] == num_prefill_tokens
assert decode_query.shape[0] == num_decode_tokens

prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None

if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
if (kv_cache.numel() == 0 or prefill_meta.block_tables is None
or prefill_meta.block_tables.numel() == 0):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
prefill_output = flash_attn_varlen_func(
flash_attn_varlen_func(
q=query,
k=key,
v=value,
Expand All @@ -683,12 +688,13 @@ def unified_flash_attention(
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
out=prefill_output,
)
else:
# prefix-enabled attention
assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens)
prefill_output = flash_attn_varlen_func( # noqa
flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
Expand All @@ -702,6 +708,7 @@ def unified_flash_attention(
alibi_slopes=alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
out=prefill_output,
)

if decode_meta := attn_metadata.decode_metadata:
Expand All @@ -710,7 +717,7 @@ def unified_flash_attention(
# because different queries might have different lengths.
assert decode_meta.max_decode_query_len is not None
if decode_meta.max_decode_query_len > 1:
decode_output = flash_attn_varlen_func(
flash_attn_varlen_func(
q=decode_query,
k=key_cache,
v=value_cache,
Expand All @@ -724,10 +731,11 @@ def unified_flash_attention(
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
out=decode_output,
)
else:
# Use flash_attn_with_kvcache for normal decoding.
decode_output = flash_attn_with_kvcache(
flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
Expand All @@ -738,28 +746,16 @@ def unified_flash_attention(
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
).squeeze(1)

if prefill_output is None:
assert decode_output is not None
return decode_output.view(num_decode_tokens, hidden_size)
if decode_output is None:
assert prefill_output is not None
return prefill_output.view(num_prefill_tokens, hidden_size)

# Chunked prefill does not work with speculative decoding.
# Therefore, the query length for decode should be 1 in chunked prefill.
assert decode_meta is not None
decode_output = decode_output.squeeze(1)
output = torch.cat([prefill_output, decode_output], dim=0)
return output.view(num_tokens, hidden_size)
out=decode_output.unsqueeze(1),
)


@unified_flash_attention.register_fake
def _(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
num_heads: int,
head_size: int,
num_kv_heads: int,
Expand All @@ -771,5 +767,5 @@ def _(
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query)
) -> None:
return