Skip to content

Commit

Permalink
[Bugfix] Fix block table for seqs that have prefix cache hits (#7018)
Browse files Browse the repository at this point in the history
  • Loading branch information
zachzzc authored Aug 3, 2024
1 parent 0c25435 commit fb2c1c8
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 3 deletions.
56 changes: 56 additions & 0 deletions tests/prefix_caching/test_prefix_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,17 @@

import pytest

from tests.kernels.utils import override_backend_env_variable
from vllm.block import PhysicalTokenBlock
from vllm.core.block_manager_v1 import CachedBlockAllocator
from vllm.utils import Device

from ..models.utils import check_outputs_equal

MODELS = [
"facebook/opt-125m",
]


@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [16])
Expand Down Expand Up @@ -76,3 +83,52 @@ def test_eviction(num_blocks: int, ):
assert (realloc_block != new_block)
assert (new_block.block_hash == new_block_hash)
assert (new_block.block_number == 2)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("cached_position", [0, 1])
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
def test_mixed_requests(
hf_runner,
vllm_runner,
example_prompts,
model: str,
backend: str,
dtype: str,
max_tokens: int,
cached_position: int,
use_v2_block_manager: bool,
monkeypatch,
) -> None:
"""
Test the case when some sequences have the prefix cache hit
and the others don't. The cached position determines where
the sequence is at among the batch of prefills.
"""
override_backend_env_variable(monkeypatch, backend)

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

cached_prompt = example_prompts[cached_position]
with vllm_runner(
model,
dtype=dtype,
enable_prefix_caching=True,
use_v2_block_manager=use_v2_block_manager,
) as vllm_model:
# Run the first prompt so the cache is populated
vllm_outputs = vllm_model.generate_greedy([cached_prompt], max_tokens)

# Run all the promopts
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
12 changes: 9 additions & 3 deletions vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.has_prefix_cache_hit = False

self.input_builder = input_builder
self.runner = input_builder.runner
Expand All @@ -219,7 +220,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"):

def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool):
chunked_prefill_enabled: bool, prefix_cache_hit: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
Expand Down Expand Up @@ -252,7 +253,7 @@ def _add_seq_group(
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if inter_data.prefix_cache_hit:
if prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
Expand Down Expand Up @@ -281,9 +282,14 @@ def build(self, seq_lens: List[int], query_lens: List[int],
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
prefix_cache_hit = any([
inter_data.prefix_cache_hit
for inter_data in self.input_builder.inter_data_list
])
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled)
self.input_builder.chunked_prefill_enabled,
prefix_cache_hit)

device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
Expand Down

0 comments on commit fb2c1c8

Please sign in to comment.