Skip to content

Commit

Permalink
[bugfix] [AMD] add multi-step advance_step to ROCmFlashAttentionMetad…
Browse files Browse the repository at this point in the history
  • Loading branch information
SolitaryThinker authored Sep 20, 2024
1 parent 3dc5ce0 commit 7e177f4
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 2 deletions.
58 changes: 57 additions & 1 deletion vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Attention layer ROCm GPUs."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch

Expand All @@ -15,6 +15,9 @@
from vllm.logger import init_logger
from vllm.platforms import current_platform

if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

logger = init_logger(__name__)

_PARTITION_SIZE_ROCM = 512
Expand Down Expand Up @@ -180,6 +183,59 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
)
return self._cached_decode_metadata

def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int):
"""
Update metadata in-place to advance one decode step.
"""
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph

assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
assert self.slot_mapping.shape == (num_seqs, )

assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0
assert self.max_decode_seq_len == max(self.seq_lens)

assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )

assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )

assert self.block_tables is not None
assert self.block_tables.shape[0] == num_seqs

# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)

ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables)


class ROCmFlashAttentionMetadataBuilder(
CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
Expand Down
2 changes: 1 addition & 1 deletion vllm/worker/multi_step_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

logger = init_logger(__name__)

MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "flashinfer"]
MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"]


def seq_output_builder():
Expand Down

0 comments on commit 7e177f4

Please sign in to comment.