Commit 0507842

[Bugfix] lookahead block table with cuda graph max capture (vllm-project#8340)

[Bugfix] Ensure multistep lookahead allocation is compatible with cuda graph max capture (vllm-project#8340)
alexm-redhat authored and MengqingCao committed Sep 29, 2024
1 parent c74907f commit 0507842
Showing 1 changed file with 11 additions and 1 deletion.
vllm/attention/backends/flash_attn.py (12 changes: 11 additions & 1 deletion)
@@ -471,9 +471,19 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             # The shape of graph_block_tables is
             # [max batch size, max context len // block size].
             input_block_tables = self.runner.graph_block_tables[:batch_size]
+            max_blocks = input_block_tables.shape[1]
             for i, block_table in enumerate(self.block_tables):
                 if block_table:
-                    input_block_tables[i, :len(block_table)] = block_table
+                    num_blocks = len(block_table)
+                    if num_blocks <= max_blocks:
+                        input_block_tables[i, :num_blocks] = block_table
+                    else:
+                        # It may be possible to have more blocks allocated due
+                        # to lookahead slots of multi-step, however, they are
+                        # not used anyway, so can be safely ignored.
+                        input_block_tables[
+                            i, :max_blocks] = block_table[:max_blocks]

             block_tables = torch.from_numpy(input_block_tables).to(
                 device=device, non_blocking=True)
         else:
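For context, here is a minimal standalone sketch of the clamping behavior this patch introduces. The array shapes and block-table contents are made up for illustration, and graph_block_tables here is a plain NumPy array rather than the runner's preallocated one; only the copy loop mirrors the patched code. Pre-patch, the oversized case would presumably fail with a NumPy broadcast error, since the old code assigned all len(block_table) entries into a row of width max_blocks:

    import numpy as np

    # Hypothetical sizes for illustration only: pretend the CUDA graph was
    # captured with room for at most 4 blocks per sequence.
    max_batch_size, capture_width = 8, 4
    graph_block_tables = np.zeros((max_batch_size, capture_width),
                                  dtype=np.int32)

    # Made-up block tables: the second sequence has 6 blocks allocated
    # (e.g. via multi-step lookahead slots), more than the capture width.
    block_tables = [[1, 2], [3, 4, 5, 6, 7, 8]]

    batch_size = len(block_tables)
    input_block_tables = graph_block_tables[:batch_size]
    max_blocks = input_block_tables.shape[1]
    for i, block_table in enumerate(block_tables):
        if block_table:
            num_blocks = len(block_table)
            if num_blocks <= max_blocks:
                input_block_tables[i, :num_blocks] = block_table
            else:
                # Pre-patch, this case effectively did
                #   input_block_tables[i, :6] = block_table
                # and NumPy raised "could not broadcast input array from
                # shape (6,) into shape (4,)". The surplus lookahead blocks
                # are never read during graph replay, so clamping is safe.
                input_block_tables[i, :max_blocks] = block_table[:max_blocks]

    print(input_block_tables)
    # -> [[1 2 0 0]
    #     [3 4 5 6]]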
