diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py
index a20aa37bcc1e2..f8e6b82e5f2cf 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -4,13 +4,14 @@
 import pytest
 import torch
 
+from tests.kernels.utils import override_backend_env_variable
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                              init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
                            SequenceData, SequenceGroupMetadata)
-from vllm.utils import get_open_port
+from vllm.utils import STR_FLASH_ATTN_VAL, get_open_port
 from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
 
 
@@ -387,3 +388,108 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
     for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata),
                                           vars(decode_meta_actual)):
         assert attr_expected[1] == attr_actual[1]
+
+
+@pytest.mark.parametrize("enforce_eager", [True, False])
+def test_hybrid_batches_with_prefix_caching(enforce_eager, monkeypatch):
+    override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
+
+    model_runner = _create_model_runner(
+        "facebook/opt-125m",
+        seed=0,
+        dtype="float16",
+        enforce_eager=enforce_eager,
+        max_num_batched_tokens=100000,
+        max_num_seqs=100000,
+        enable_chunked_prefill=True,
+        enable_prefix_caching=True,
+    )
+
+    seq_lens: List[int] = []
+    seq_group_metadata_list: List[SequenceGroupMetadata] = []
+    # Use a large number of blocks to test longer sequences
+    # with chunked prefill and prefix caching.
+    block_tables = {0: list(range(128))}
+
+    # case 1: prefix_cache_len <= context_len
+    seq_len = 1000
+    seq_lens.append(seq_len)
+    seq_data = SequenceData(list(range(seq_len)))
+    seq_group_metadata = SequenceGroupMetadata(
+        request_id="test_0",
+        is_prompt=True,
+        seq_data={0: seq_data},
+        sampling_params=SamplingParams(temperature=0),
+        block_tables=block_tables,
+        token_chunk_size=200,
+        computed_block_nums=range(10),
+    )
+    seq_data.update_num_computed_tokens(200)
+    seq_group_metadata_list.append(seq_group_metadata)
+
+    # case 2: context_len < prefix_cache_len < seq_len
+    seq_len = 1000
+    seq_lens.append(seq_len)
+    seq_data = SequenceData(list(range(seq_len)))
+    seq_group_metadata = SequenceGroupMetadata(
+        request_id="test_0",
+        is_prompt=True,
+        seq_data={0: seq_data},
+        sampling_params=SamplingParams(temperature=0),
+        block_tables=block_tables,
+        token_chunk_size=100,
+        computed_block_nums=range(10),
+    )
+    seq_data.update_num_computed_tokens(80)
+    seq_group_metadata_list.append(seq_group_metadata)
+
+    # case 3: prefix_cache_len >= seq_len
+    seq_len = 1000
+    seq_lens.append(seq_len)
+    seq_data = SequenceData(list(range(seq_len)))
+    seq_group_metadata = SequenceGroupMetadata(
+        request_id="test_0",
+        is_prompt=True,
+        seq_data={0: seq_data},
+        sampling_params=SamplingParams(temperature=0),
+        block_tables=block_tables,
+        token_chunk_size=100,
+        computed_block_nums=range(10),
+    )
+    seq_data.update_num_computed_tokens(50)
+    seq_group_metadata_list.append(seq_group_metadata)
+
+    model_input = model_runner.prepare_model_input(seq_group_metadata_list)
+    (input_tokens, input_positions, attn_metadata) = (
+        model_input.input_tokens,
+        model_input.input_positions,
+        model_input.attn_metadata,
+    )
+
+    # The numbers of tokens to be computed are:
+    # - for the first sequence: the cached prefix is already covered by the
+    #   computed context, so all 200 tokens are left to be recomputed
+    # - for the second sequence: partially cached, and 20 tokens are left
+    #   to be recomputed
+    # - for the third sequence: fully cached, and only 1 token is left
+    #   to be recomputed
+    assert len(input_tokens) == 221
+    assert len(input_positions) == 221
+
+    torch.testing.assert_close(
+        attn_metadata.query_start_loc,
+        torch.tensor([0, 200, 220, 221],
+                     dtype=torch.int32,
+                     device=attn_metadata.query_start_loc.device))
+
+    torch.testing.assert_close(
+        attn_metadata.seq_start_loc,
+        torch.tensor([0, 400, 580, 730],
+                     dtype=torch.int32,
+                     device=attn_metadata.seq_start_loc.device))
+
+    torch.testing.assert_close(
+        attn_metadata.context_lens_tensor,
+        torch.tensor([200, 160, 149],
+                     dtype=torch.int32,
+                     device=attn_metadata.context_lens_tensor.device))
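
Note on the expected tensors above: the arithmetic behind them is easier to follow in isolation. The sketch below assumes the default block size of 16 (so `computed_block_nums=range(10)` corresponds to a 160-token cached prefix); it is illustrative only and not part of the test.

```python
# Sanity-check of the expected tensors, assuming the default block size of 16
# (computed_block_nums=range(10) -> a 160-token cached prefix).
BLOCK_SIZE = 16
prefix_cache_len = 10 * BLOCK_SIZE  # 160

cases = [
    # (num_computed_tokens, token_chunk_size)
    (200, 200),  # case 1: prefix_cache_len <= context_len
    (80, 100),   # case 2: context_len < prefix_cache_len < seq_len
    (50, 100),   # case 3: prefix_cache_len >= seq_len
]

query_start_loc, seq_start_loc, context_lens = [0], [0], []
for context_len, chunk_size in cases:
    seq_len = context_len + chunk_size  # tokens processed after this chunk
    if prefix_cache_len <= context_len:
        pass                            # cache adds nothing new
    elif context_len < prefix_cache_len < seq_len:
        context_len = prefix_cache_len  # skip the cached part of the chunk
    else:
        context_len = seq_len - 1       # fully cached: keep 1 token
    query_start_loc.append(query_start_loc[-1] + (seq_len - context_len))
    seq_start_loc.append(seq_start_loc[-1] + seq_len)
    context_lens.append(context_len)

assert query_start_loc == [0, 200, 220, 221]
assert seq_start_loc == [0, 400, 580, 730]
assert context_lens == [200, 160, 149]
```
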
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index ce7a7198dc400..319f8f6b53d4b 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -425,6 +425,13 @@ def _add_seq_group(
         block_tables = inter_data.block_tables
         computed_block_nums = inter_data.computed_block_nums
 
+        if chunked_prefill_enabled and inter_data.prefix_cache_hit:
+            raise RuntimeError(
+                "Chunked prefill and prefix caching cannot be used "
+                "simultaneously with the flashinfer backend; try switching "
+                "to the flash-attn backend by setting the environment "
+                "variable \"VLLM_ATTENTION_BACKEND=FLASH_ATTN\".")
+
         for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
              curr_sliding_window_block) in zip(
                  inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py
index 0375d3488eb15..1e4e28760d7bd 100644
--- a/vllm/attention/backends/utils.py
+++ b/vllm/attention/backends/utils.py
@@ -150,6 +150,13 @@ def _add_seq_group(
         block_tables = inter_data.block_tables
         computed_block_nums = inter_data.computed_block_nums
 
+        if chunked_prefill_enabled and inter_data.prefix_cache_hit:
+            raise RuntimeError(
+                "Chunked prefill and prefix caching are currently only "
+                "supported together with the flash-attn backend; try "
+                "switching to it by setting the environment variable "
+                "\"VLLM_ATTENTION_BACKEND=FLASH_ATTN\".")
+
         for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
              curr_sliding_window_block) in zip(
                  inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
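
For reference, the test above pins the backend through `override_backend_env_variable`; outside of tests, the same effect comes from setting the environment variable named in the error message before the engine is created. A minimal sketch (the commented-out `LLM(...)` call is only an assumed usage example):

```python
import os

# Must be set before vLLM selects an attention backend.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

# Assumed usage example: with the flash-attn backend pinned, chunked
# prefill and prefix caching can be enabled together.
# from vllm import LLM
# llm = LLM(model="facebook/opt-125m",
#           enable_chunked_prefill=True,
#           enable_prefix_caching=True)
```
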
diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py
index 0af04399a4b31..b692ce4db32e6 100644
--- a/vllm/core/block_manager_v1.py
+++ b/vllm/core/block_manager_v1.py
@@ -680,10 +680,15 @@ def access_all_blocks_in_seq(
             for block in block_table:
                 block.last_accessed = access_time
 
-    def compute_full_blocks_in_seq(self, seq: Sequence):
+    def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int):
         if seq.seq_id not in self.block_tables:
             return
-        max_full_block = seq.get_len() // self.block_size - 1
+        # The model runner guarantees at least 1 token is left to prefill,
+        # even when fully cached, so marking the computed blocks here is safe.
+        max_full_block = min(
+            seq.get_prompt_len(),
+            seq.data.get_num_computed_tokens() +
+            token_chunk_size) // self.block_size
         block_table = self.block_tables[seq.seq_id]
         if max_full_block == -1:
             return
@@ -717,10 +722,11 @@ def get_common_computed_block_ids(
         ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
         return commonprefix([ids for ids in ids_list if ids != []])
 
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         if self.enable_caching:
             for seq in seq_group.get_seqs():
-                self.compute_full_blocks_in_seq(seq)
+                self.compute_full_blocks_in_seq(seq, token_chunk_size)
 
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         if device == Device.GPU:
diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py
index b7d9451f18067..8e92ebeb1f9eb 100644
--- a/vllm/core/block_manager_v2.py
+++ b/vllm/core/block_manager_v2.py
@@ -286,7 +286,8 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float):
             self._last_access_blocks_tracker.update_last_access(
                 seq.seq_id, now)
 
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         # The only need for mark block as computed is for prefix caching,
         # while currently we could determine whether one block is computed
         # or not by check whether it has content hash.
diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py
index 3d864a73f91d0..4188a5c71eaeb 100644
--- a/vllm/core/embedding_model_block_manager.py
+++ b/vllm/core/embedding_model_block_manager.py
@@ -80,7 +80,8 @@ def get_common_computed_block_ids(self,
                                       seq_group: SequenceGroup) -> List[int]:
         return None  # type: ignore
 
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         pass
 
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py
index becd0d2e7f849..96f8dd851b2f4 100644
--- a/vllm/core/interfaces.py
+++ b/vllm/core/interfaces.py
@@ -115,7 +115,8 @@ def get_common_computed_block_ids(
         pass
 
     @abstractmethod
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         pass
 
     @abstractmethod
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 3b716e32032c1..56145ba7f6e01 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -1145,7 +1145,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
         # will crash the vLLM instance / will not retry.
         for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
             self.block_manager.mark_blocks_as_computed(
-                scheduled_seq_group.seq_group)
+                scheduled_seq_group.seq_group,
+                scheduled_seq_group.token_chunk_size)
 
         scheduler_time = time.perf_counter() - scheduler_start_time
         # Add this to scheduler time to all the sequences that are currently
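
The updated `compute_full_blocks_in_seq()` above can be read as "mark every block that is completely filled once the current chunk finishes". A standalone sketch of that computation, using plain integers and an assumed block size of 16:

```python
def num_full_blocks(prompt_len: int, num_computed_tokens: int,
                    token_chunk_size: int, block_size: int = 16) -> int:
    """Number of fully filled prompt blocks after the current chunk.

    Mirrors the min(prompt_len, computed + chunk) // block_size formula:
    the chunk may run past the end of the prompt, so the token count is
    capped at prompt_len before dividing by the block size.
    """
    computed_after_chunk = min(prompt_len,
                               num_computed_tokens + token_chunk_size)
    return computed_after_chunk // block_size


# e.g. a 1000-token prompt with 80 tokens already computed and a
# 100-token chunk: 180 // 16 == 11 blocks can be marked as computed.
assert num_full_blocks(1000, 80, 100) == 11
```
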
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 793f03456e997..30967f373dd89 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -498,20 +498,41 @@ def _compute_for_prefix_cache_hit(
                             and self.sliding_window is None
                             and inter_data.is_prompt)
         inter_data.prefix_cache_hit = prefix_cache_hit
-        if self.chunked_prefill_enabled and prefix_cache_hit:
-            raise RuntimeError(
-                "chunked prefill cannot be used with prefix caching now.")
 
         # If prefix cache is hit, advance context length to bypass
         # hit blocks. Accordingly, input tokens, position and query length
         # have to be updated.
         if prefix_cache_hit:
             assert computed_block_nums is not None
-            context_len = len(computed_block_nums) * self.block_size
-            inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
-                seq_idx][context_len:]
-            inter_data.input_positions[seq_idx] = inter_data.input_positions[
-                seq_idx][context_len:]
+            prefix_cache_len = len(computed_block_nums) * self.block_size
+            # When prefix caching meets chunked prefill, we fall into one of
+            # the following three cases:
+            context_len = inter_data.context_lens[seq_idx]
+            seq_len = inter_data.seq_lens[seq_idx]
+            if prefix_cache_len <= context_len:
+                # Do normal chunked prefill.
+                pass
+            elif context_len < prefix_cache_len < seq_len:
+                # Advance context_len to prefix_cache_len so that only the
+                # non-cached part of this chunk is prefilled.
+                inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
+                    seq_idx][(prefix_cache_len - context_len):]
+                inter_data.input_positions[
+                    seq_idx] = inter_data.input_positions[seq_idx][(
+                        prefix_cache_len - context_len):]
+                context_len = prefix_cache_len
+            elif seq_len <= prefix_cache_len:
+                # The current chunk of the sequence is fully cached and no
+                # further computation is needed. In this case, we leave at
+                # least 1 token for chunked prefill to prevent empty
+                # sequences in the attention computation.
+                inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
+                    seq_idx][(seq_len - 1 - context_len):]
+                inter_data.input_positions[
+                    seq_idx] = inter_data.input_positions[seq_idx][(
+                        seq_len - 1 - context_len):]
+                context_len = seq_len - 1
+            inter_data.context_lens[seq_idx] = context_len
             inter_data.query_lens[
                 seq_idx] = inter_data.seq_lens[seq_idx] - context_len
 
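
Stripped of the builder plumbing, the three branches above reduce to picking a new context length from `(context_len, seq_len, prefix_cache_len)`. The sketch below is not the `ModelInputForGPUBuilder` code itself, just the same case analysis applied to the values exercised by the new test:

```python
def advance_context_for_prefix_cache(context_len: int, seq_len: int,
                                     prefix_cache_len: int) -> int:
    """Return the context length after accounting for a prefix cache hit.

    - cache hit already inside the computed context: nothing changes
    - cache hit ends inside this chunk: jump ahead to the cached boundary
    - chunk fully cached: keep exactly 1 token so the attention kernel
      never sees an empty query
    """
    if prefix_cache_len <= context_len:
        return context_len
    if prefix_cache_len < seq_len:
        return prefix_cache_len
    return seq_len - 1


# The three cases from the new test (block size 16, 10 cached blocks -> 160):
assert advance_context_for_prefix_cache(200, 400, 160) == 200  # case 1
assert advance_context_for_prefix_cache(80, 180, 160) == 160   # case 2
assert advance_context_for_prefix_cache(50, 150, 160) == 149   # case 3
# The number of tokens computed this step is then seq_len - new_context_len.
```
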