diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 9ad7c41e48b68..25b86176f630e 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -559,25 +559,32 @@ def forward( self.kv_cache_dtype, k_scale, v_scale) - if attn_type != AttentionType.ENCODER: - # Decoder self-attention supports chunked prefill. - # Encoder/decoder cross-attention requires no chunked - # prefill (100% prefill or 100% decode tokens, no mix) - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - else: + if attn_type == AttentionType.ENCODER: # Encoder attention - chunked prefill is not applicable; # derive token-count from query shape & and treat them # as 100% prefill tokens assert attn_metadata.num_encoder_tokens is not None num_prefill_tokens = attn_metadata.num_encoder_tokens + num_encoder_tokens = attn_metadata.num_encoder_tokens num_decode_tokens = 0 - - if attn_type == AttentionType.DECODER: + elif attn_type == AttentionType.DECODER: + # Decoder self-attention supports chunked prefill. + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_encoder_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens # Only enforce this shape-constraint for decoder # self-attention assert key.shape[0] == num_prefill_tokens + num_decode_tokens assert value.shape[0] == num_prefill_tokens + num_decode_tokens + else: # attn_type == AttentionType.ENCODER_DECODER + # Encoder/decoder cross-attention requires no chunked + # prefill (100% prefill or 100% decode tokens, no mix) + num_prefill_tokens = attn_metadata.num_prefill_tokens + if attn_metadata.num_encoder_tokens is not None: + num_encoder_tokens = attn_metadata.num_encoder_tokens + else: + num_encoder_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. @@ -585,8 +592,8 @@ def forward( # QKV for prefill. query = query[:num_prefill_tokens] if key is not None and value is not None: - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] + key = key[:num_encoder_tokens] + value = value[:num_encoder_tokens] assert query.shape[0] == num_prefill_tokens assert decode_query.shape[0] == num_decode_tokens