Commit

[Bugfix] Fix bug of xformer prefill for encoder-decoder (vllm-project…
xiangxu-google authored Oct 12, 2024
1 parent 89feb4c commit 00298e0
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions vllm/attention/backends/xformers.py
@@ -559,34 +559,41 @@ def forward(
                                                     self.kv_cache_dtype,
                                                     k_scale, v_scale)
 
-        if attn_type != AttentionType.ENCODER:
-            # Decoder self-attention supports chunked prefill.
-            # Encoder/decoder cross-attention requires no chunked
-            # prefill (100% prefill or 100% decode tokens, no mix)
-            num_prefill_tokens = attn_metadata.num_prefill_tokens
-            num_decode_tokens = attn_metadata.num_decode_tokens
-        else:
+        if attn_type == AttentionType.ENCODER:
             # Encoder attention - chunked prefill is not applicable;
             # derive token-count from query shape & and treat them
             # as 100% prefill tokens
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
+            num_encoder_tokens = attn_metadata.num_encoder_tokens
             num_decode_tokens = 0
-
-        if attn_type == AttentionType.DECODER:
+        elif attn_type == AttentionType.DECODER:
+            # Decoder self-attention supports chunked prefill.
+            num_prefill_tokens = attn_metadata.num_prefill_tokens
+            num_encoder_tokens = attn_metadata.num_prefill_tokens
+            num_decode_tokens = attn_metadata.num_decode_tokens
             # Only enforce this shape-constraint for decoder
             # self-attention
             assert key.shape[0] == num_prefill_tokens + num_decode_tokens
             assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+        else:  # attn_type == AttentionType.ENCODER_DECODER
+            # Encoder/decoder cross-attention requires no chunked
+            # prefill (100% prefill or 100% decode tokens, no mix)
+            num_prefill_tokens = attn_metadata.num_prefill_tokens
+            if attn_metadata.num_encoder_tokens is not None:
+                num_encoder_tokens = attn_metadata.num_encoder_tokens
+            else:
+                num_encoder_tokens = attn_metadata.num_prefill_tokens
+            num_decode_tokens = attn_metadata.num_decode_tokens
 
         output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.
         query = query[:num_prefill_tokens]
         if key is not None and value is not None:
-            key = key[:num_prefill_tokens]
-            value = value[:num_prefill_tokens]
+            key = key[:num_encoder_tokens]
+            value = value[:num_encoder_tokens]
 
         assert query.shape[0] == num_prefill_tokens
         assert decode_query.shape[0] == num_decode_tokens
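
Why the fix matters: in encoder/decoder cross-attention the queries come from the decoder while the keys and values come from the encoder, so during prefill key.shape[0] equals the encoder token count, not the decoder-side num_prefill_tokens. Slicing K/V with num_prefill_tokens therefore dropped (or over-indexed) encoder keys and values whenever the two lengths differed. A minimal sketch of the mismatch, with shapes and names that are illustrative rather than the vLLM API:

    import torch

    # Hypothetical prefill step for a T5-style model where the encoder
    # prompt is longer than the decoder-side prefill.
    num_prefill_tokens = 4   # decoder tokens producing queries this step
    num_encoder_tokens = 9   # encoder tokens supplying K/V for cross-attn
    num_heads, head_size = 8, 64

    query = torch.randn(num_prefill_tokens, num_heads, head_size)  # decoder
    key = torch.randn(num_encoder_tokens, num_heads, head_size)    # encoder
    value = torch.randn(num_encoder_tokens, num_heads, head_size)  # encoder

    # Pre-fix: K/V sliced by the decoder-side count, silently discarding
    # encoder tokens 4..8 from the attention computation.
    key_buggy = key[:num_prefill_tokens]    # shape (4, 8, 64) -- wrong
    # Post-fix: K/V sliced by the encoder-side count, keeping every row.
    key_fixed = key[:num_encoder_tokens]    # shape (9, 8, 64) -- correct

    print(key_buggy.shape, key_fixed.shape)

Note that for encoder attention and decoder self-attention the patch sets num_encoder_tokens equal to the prefill token count, so the new slice is a no-op on those paths; only the ENCODER_DECODER branch changes behavior.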
