[spec decode] [4/N] Move update_flash_attn_metadata to attn backend #7571

Merged
merged 7 commits on Aug 16, 2024
3 changes: 3 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -75,6 +75,9 @@ def copy_blocks(
) -> None:
raise NotImplementedError

def advance_step(self, num_seqs: int, num_queries: int):
raise NotImplementedError


@dataclass
class AttentionMetadata:
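
The base class deliberately raises `NotImplementedError`, so only backends that support multi-step draft decoding opt in to the new hook. A minimal sketch of that opt-in pattern, using hypothetical stand-in classes rather than the real vLLM types:

```python
# Illustrative only: hypothetical stand-in classes, not the real vLLM types.
from typing import List


class ToyMetadataBase:
    def advance_step(self, num_seqs: int, num_queries: int) -> None:
        # Backends that do not support multi-step draft decode inherit this
        # default and fail loudly if the draft runner tries to use them.
        raise NotImplementedError


class ToyMultiStepMetadata(ToyMetadataBase):
    def __init__(self, seq_lens: List[int]) -> None:
        self.seq_lens = seq_lens

    def advance_step(self, num_seqs: int, num_queries: int) -> None:
        # Each live request gains exactly one decoded token per step.
        for i in range(num_queries):
            self.seq_lens[i] += 1
```
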
45 changes: 45 additions & 0 deletions vllm/attention/backends/flash_attn.py
@@ -196,6 +196,51 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
)
return self._cached_decode_metadata

def advance_step(self, num_seqs: int, num_queries: int):
"""
Update metadata in-place to advance one decode step.
"""
# GPU in-place update is currently called separately through
# custom_ops.advance_step(). See draft_model_runner. TODO(will): Move
# this logic to the backend.

# When using cudagraph, num_seqs is padded to the next captured
# batch size, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries.
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph

assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
assert self.slot_mapping.shape == (num_seqs, )

assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0
assert self.max_decode_seq_len == max(self.seq_lens)

assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )

assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )

assert self.block_tables is not None
assert self.block_tables.shape[0] == num_seqs

# Update sequence lengths for the live queries only (not all num_seqs
# entries), since the tensors may be padded to the captured cuda graph
# batch size.
for i in range(num_queries):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)


class FlashAttentionMetadataBuilder(
AttentionMetadataBuilder[FlashAttentionMetadata]):
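
To make the cudagraph padding invariant concrete, here is a small worked example with made-up numbers (not taken from the PR): three live requests padded up to a captured batch size of four. Only the first `num_queries` entries of `seq_lens` advance; the padding entry is left alone.

```python
# Made-up numbers: 3 live requests padded to a captured batch size of 4.
seq_lens = [17, 5, 9, 1]   # index 3 is cuda graph padding
num_queries = 3            # actual requests in the batch
num_seqs = 4               # padded batch size, so num_seqs > num_queries

# The CPU-side update performed by advance_step:
for i in range(num_queries):
    seq_lens[i] += 1
max_decode_seq_len = max(seq_lens)

assert seq_lens == [18, 6, 10, 1]   # padding entry untouched
assert max_decode_seq_len == 18
```
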
34 changes: 1 addition & 33 deletions vllm/spec_decode/draft_model_runner.py
@@ -97,38 +97,6 @@ def __init__(
self.flashinfer_prefill_workspace_buffer = None
self.flashinfer_prefill_wrapper = None

def _update_flash_attn_metadata(self, attn_metadata, num_seqs,
num_queries):
assert isinstance(attn_metadata, FlashAttentionMetadata)

if num_seqs != num_queries:
assert num_seqs > num_queries
assert attn_metadata.use_cuda_graph

assert attn_metadata.num_prefills == 0
assert attn_metadata.num_prefill_tokens == 0
assert attn_metadata.num_decode_tokens == num_seqs
assert attn_metadata.slot_mapping.shape == (num_seqs, )

assert len(attn_metadata.seq_lens) == num_seqs
assert attn_metadata.seq_lens_tensor.shape == (num_seqs, )
assert attn_metadata.max_query_len == 1
assert attn_metadata.max_prefill_seq_len == 0
assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens)

assert attn_metadata.query_start_loc.shape == (num_queries + 1, )
assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, )

assert attn_metadata.context_lens_tensor.shape == (num_queries, )

assert attn_metadata.block_tables.shape[0] == num_seqs

# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
attn_metadata.seq_lens[i] += 1
attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens)

def _update_sampling_metadata(self, sampling_metadata, num_seqs,
num_queries):

@@ -166,7 +134,7 @@ def _gpu_advance_step(
# Update attn_metadata
attn_metadata = model_input.attn_metadata
assert isinstance(attn_metadata, FlashAttentionMetadata)
self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries)
attn_metadata.advance_step(num_seqs, num_queries)

# Update GPU tensors
ops.advance_step(num_seqs=num_seqs,
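
The net effect at the call site is that the draft runner only keeps the `isinstance` check and then delegates; the FlashAttention-specific assertions it used to carry now live with the backend. A condensed sketch of that shape (the helper name `update_attn_metadata` is made up here, and the GPU-side `ops.advance_step` call that follows in the real `_gpu_advance_step` is omitted):

```python
from vllm.attention.backends.flash_attn import FlashAttentionMetadata


def update_attn_metadata(model_input, num_seqs: int, num_queries: int) -> None:
    # Hypothetical helper mirroring the updated hunk in _gpu_advance_step.
    attn_metadata = model_input.attn_metadata
    assert isinstance(attn_metadata, FlashAttentionMetadata)
    # CPU-side bookkeeping is now owned by the backend metadata object;
    # the GPU tensors are still advanced by the custom op afterwards.
    attn_metadata.advance_step(num_seqs, num_queries)
```
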