diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 06b178798dcd9..69faa6d343eda 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -471,9 +471,19 @@ def build(self, seq_lens: List[int], query_lens: List[int], # The shape of graph_block_tables is # [max batch size, max context len // block size]. input_block_tables = self.runner.graph_block_tables[:batch_size] + max_blocks = input_block_tables.shape[1] for i, block_table in enumerate(self.block_tables): if block_table: - input_block_tables[i, :len(block_table)] = block_table + num_blocks = len(block_table) + if num_blocks <= max_blocks: + input_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + input_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + block_tables = torch.from_numpy(input_block_tables).to( device=device, non_blocking=True) else: