From b3fa9d61407a9820c460af149dcc5e7c8dd7fd66 Mon Sep 17 00:00:00 2001 From: rickyx Date: Wed, 13 Nov 2024 22:05:19 +0000 Subject: [PATCH 1/4] Clean? Signed-off-by: rickyx --- tests/core/block/test_prefix_caching_block.py | 181 +++++++++- tests/core/test_scheduler.py | 179 +++++++++- tests/core/utils.py | 20 +- tests/prefix_caching/test_prefix_caching.py | 83 +++++ vllm/core/block/cpu_gpu_block_allocator.py | 15 +- vllm/core/block/interfaces.py | 36 +- vllm/core/block/naive_block.py | 11 +- vllm/core/block/prefix_caching_block.py | 255 +++++++++----- vllm/core/block_manager.py | 23 +- vllm/core/interfaces.py | 4 + vllm/core/placeholder_block_space_manager.py | 3 + vllm/core/scheduler.py | 332 +++++++++++++----- vllm/sequence.py | 3 + 13 files changed, 913 insertions(+), 232 deletions(-) diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index d325b9606843e..bbeb4b3a58f2a 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -5,9 +5,14 @@ import pytest +from tests.core.utils import create_dummy_sequence +from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator from vllm.core.block.interfaces import Block, BlockAllocator -from vllm.core.block.prefix_caching_block import (PrefixCachingBlock, +from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, + PrefixCachingBlock, PrefixCachingBlockAllocator) +from vllm.sequence import Logprob +from vllm.utils import Device class TestPrefixCachingBlock: @@ -726,18 +731,71 @@ def test_touch_block(): token_ids=common_token_ids, allocator=allocator, ) - block_ids = [block.block_id for block in blocks] + block_hashes = [block.content_hash for block in blocks] # The allocated blocks should be marked as touched # but not computed. - computed_block_ids = allocator.get_computed_block_ids( - [], block_ids, skip_last_block_id=False) + computed_block_ids = allocator.find_cached_blocks_prefix( + block_hashes) assert len(computed_block_ids) == 0 allocator.mark_blocks_as_computed([]) - computed_block_ids = allocator.get_computed_block_ids( - [], block_ids, skip_last_block_id=False) + computed_block_ids = allocator.find_cached_blocks_prefix( + block_hashes=block_hashes) assert len(computed_block_ids) == common_blocks + @staticmethod + def test_find_cached_blocks_prefix(): + """ + This test verifies the behavior of find_cached_blocks_prefix. + """ + block_size = 4 + num_blocks = 8 + total_test_blocks = 12 + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + + token_ids = list(range(total_test_blocks * block_size)) + block_tokens_seq1 = token_ids[:num_blocks * block_size] + blocks_seq1 = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=block_tokens_seq1, + allocator=allocator, + ) + block_hashes_seq1 = [block.content_hash for block in blocks_seq1] + allocator.mark_blocks_as_computed([]) + + # All blocks should be cached. + cached_blocks_seq1 = allocator.find_cached_blocks_prefix( + block_hashes=block_hashes_seq1) + assert len(cached_blocks_seq1) == num_blocks + + # Free the first sequence. + for block in blocks_seq1: + allocator.free(block) + + # All blocks should be still be cached if not required to be allocated. + cached_blocks = allocator.find_cached_blocks_prefix( + block_hashes=block_hashes_seq1) + assert len(cached_blocks) == num_blocks + + block_tokens_seq2 = token_ids[num_blocks * block_size:] + blocks_seq2 = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=block_tokens_seq2, + allocator=allocator, + ) + block_hashes_seq2 = [block.content_hash for block in blocks_seq2] + allocator.mark_blocks_as_computed([]) + cached_blocks = allocator.find_cached_blocks_prefix( + block_hashes=block_hashes_seq2) + assert len(cached_blocks) == len(blocks_seq2) + + # Half of the blocks from seq1 should still be cached. + num_evicted_blocks = len(blocks_seq2) + cached_blocks = allocator.find_cached_blocks_prefix( + block_hashes=block_hashes_seq1) + assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks + @staticmethod def create_immutable_chain( block_size: int, @@ -762,3 +820,114 @@ def create_immutable_chain( blocks.append(prev_block) return blocks + + +class TestComputedBlocksTracker: + + @staticmethod + def _get_mock_allocator(): + return MagicMock(spec=PrefixCachingBlockAllocator) + + @staticmethod + def test_get_num_cached_tokens(): + """ + Test it correctly computes the number of cached tokens for a given + sequence: + + - The cache token count is derived from the number of cached blocks. + - The cache token count is updated when the allocator is updated. + - When a sequence is removed, the cache token count should be updated + accordingly. + + # TODO(rickyx): This behaviour for prefill sequence is a hack until + we fix the computed blocks tracking. + - The cache token count for prefill sequence doesn't change while + the sequence is in continuous prefill (chunked prefill). + """ + block_size = 4 + mock_allocator = TestComputedBlocksTracker._get_mock_allocator() + tracker = ComputedBlocksTracker( + allocator=mock_allocator, + block_size=block_size, + enable_caching=True, + ) + + # Not yet allocated. + tokens = [0, 1, 2, 3, 4, 5] + seq1 = create_dummy_sequence(request_id=0, + token_ids=tokens, + block_size=block_size) + mock_allocator.find_cached_blocks_prefix.return_value = [] + assert tracker.get_num_cached_tokens(seq1) == 0 + + mock_allocator.find_cached_blocks_prefix.return_value = [ + None + ] # 1 block cached. + # Result is cached for prefill sequence. + assert tracker.get_num_cached_tokens(seq1) == 0 + + # Mark the sequence as non-prefill. + seq1.data.update_num_computed_tokens(len(tokens)) # 6 tokens computed. + assert not seq1.is_prefill() + + # Recomputes for decoding sequence. + assert tracker.get_num_cached_tokens(seq1) == 4 + + # Append new tokens to the sequence. + num_new_tokens = 3 + for i in range(num_new_tokens): + seq1.append_token_id(i, {i: Logprob(logprob=0.0)}) + + assert tracker.get_num_cached_tokens(seq1) == 4 + + # Update the allocator. + mock_allocator.find_cached_blocks_prefix.return_value = [ + None + ] * 2 # 2 blocks cached. + assert tracker.get_num_cached_tokens(seq1) == 8 + + # Remove the sequence. + tracker.remove_seq(seq1.seq_id) + + # Re-create the sequence with the same request id to simulate recompute. + seq1 = create_dummy_sequence(request_id=0, + token_ids=tokens, + block_size=block_size) + mock_allocator.find_cached_blocks_prefix.return_value = [ + ] # no cached block + assert tracker.get_num_cached_tokens(seq1) == 0 + + @staticmethod + def test_correct_block_hash(): + """ + Test that the block hash is correctly computed for a sequence (should + match the underlying block allocator's block hash). So the number of + cached tokens is correctly retrieved. + """ + block_size = 4 + allocator = CpuGpuBlockAllocator.create( + allocator_type="prefix_caching", + num_gpu_blocks=16, + num_cpu_blocks=16, + block_size=block_size, + ) + gpu_allocator = allocator._allocators[Device.GPU] + + tracker = ComputedBlocksTracker( + allocator=allocator, + block_size=block_size, + enable_caching=True, + ) + + tokens = list(range(block_size * 4)) # 4 blocks. + seq = create_dummy_sequence(request_id=0, + token_ids=tokens, + block_size=block_size) + _ = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=tokens, + allocator=gpu_allocator, + ) + allocator.mark_blocks_as_computed([]) + + assert tracker.get_num_cached_tokens(seq) == len(tokens) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 5ff32be611592..8f6de84e566e7 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -12,9 +12,9 @@ from vllm.lora.request import LoRARequest from vllm.sequence import SequenceGroup -from .utils import (append_new_token, append_new_token_seq_group, - create_dummy_prompt, get_sequence_groups, - schedule_and_update_computed_tokens) +from .utils import (append_new_token, append_new_token_seq, + append_new_token_seq_group, create_dummy_prompt, + get_sequence_groups, schedule_and_update_computed_tokens) def test_scheduler_add_seq_group(): @@ -305,6 +305,8 @@ def initialize_scheduler( block_size=4, num_cpu_blocks=8, num_gpu_blocks=8, + enable_prefix_caching=False, + enable_chunked_prefill=False, ): block_size = block_size scheduler_config = SchedulerConfig( @@ -312,8 +314,15 @@ def initialize_scheduler( max_num_batched_tokens=max_token_budget, max_num_seqs=max_num_seqs, max_model_len=max_model_len, + enable_chunked_prefill=enable_chunked_prefill, + ) + cache_config = CacheConfig( + block_size, + 1.0, + 1, + "auto", + enable_prefix_caching=enable_prefix_caching, ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = num_cpu_blocks cache_config.num_gpu_blocks = num_gpu_blocks scheduler = Scheduler(scheduler_config, cache_config, lora_config) @@ -800,3 +809,165 @@ def test_scheduling_budget(): assert budget.num_curr_seqs == 0 budget.subtract_num_seqs(seq_group.request_id, 2) assert budget.num_curr_seqs == 0 + + +@pytest.mark.parametrize("enable_prefix_caching", [True, False]) +def test_prefix_caching_aware_prefills(enable_prefix_caching): + """ + Test the below scenario: + + For 3 sequences, seqA, seqB, seqC, share the first block as prefix. + + The test verifies the below scenarios: + 1. SeqA is first scheduled. + 2. SeqB and SeqC can be prefilled together in a single schedule round + even though there are not enough token budgets to prefill both without + considering prefix caching. + """ + + block_size = 4 + max_num_batched_tokens = 12 + max_seq_group = 3 + scheduler = initialize_scheduler( + block_size=block_size, + num_cpu_blocks=16, + num_gpu_blocks=16, + max_token_budget=max_num_batched_tokens, + max_num_seqs=max_seq_group, + max_model_len=max_num_batched_tokens, + enable_prefix_caching=enable_prefix_caching, + ) + + seqA_tokens = list(range(8)) + num_shared_tokens = 4 + seqB_tokens = seqA_tokens[:num_shared_tokens] + list(range( + 12, 16)) # Shared prefix first 4. + seqC_tokens = seqA_tokens[:num_shared_tokens] + list(range( + 16, 20)) # Shared prefix first 4. + + seqA, seqA_group = create_dummy_prompt("0", + prompt_tokens=seqA_tokens, + block_size=block_size) + seqB, seqB_group = create_dummy_prompt("1", + prompt_tokens=seqB_tokens, + block_size=block_size) + seqC, seqC_group = create_dummy_prompt("2", + prompt_tokens=seqC_tokens, + block_size=block_size) + + # Schedule seqA prefill. + scheduler.add_seq_group(seqA_group) + metas, out, _ = scheduler.schedule() + assert (len(out.scheduled_seq_groups) == 1 + and out.scheduled_seq_groups[0].seq_group == seqA_group) + assert out.scheduled_seq_groups[0].token_chunk_size == len(seqA_tokens) + + # Schedule seqA decode. + append_new_token_seq_group(len(seqA_tokens), seqA_group, 999) + metas, out, _ = scheduler.schedule() + + assert len(out.scheduled_seq_groups) == 1 + assert out.scheduled_seq_groups[0].seq_group == seqA_group + assert out.scheduled_seq_groups[0].token_chunk_size == 1 + + # Schedule seqB and seqC prefills should work with prefix caching. + scheduler.add_seq_group(seqB_group) + scheduler.add_seq_group(seqC_group) + metas, out, _ = scheduler.schedule() + + if enable_prefix_caching: + assert len(out.scheduled_seq_groups) == 2 + assert set([ + out.scheduled_seq_groups[0].seq_group, + out.scheduled_seq_groups[1].seq_group, + ]) == set([seqB_group, seqC_group]) + assert len(metas) == 2 + for meta in metas: + assert meta.token_chunk_size == 8 + assert (len(meta.computed_block_nums) == num_shared_tokens // + block_size) # 1 Block for the 8 tokens. + else: + assert len(out.scheduled_seq_groups) == 1 + assert len(metas) == 1 + assert metas[0].token_chunk_size == 8 + assert len(metas[0].computed_block_nums) == 0 # No blocks computed. + + +def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching( +): + """ + This test verifies that we don't schedule new prefills if there's already + a continuous prefill in progress even though the new prefills with shared + prefix can fit in the token budget: + + - SeqA is being chunked prefill. + - SeqB with the same prompt shouldn't be scheduled for prefill even though + there's enough token budget to prefill the cached tokens. + - Neither should seqC be scheduled. + + - When seqA is in decoding phase, seqB and seqC can be scheduled. + - Entire seqB should be prefilled since it's a full prefix cache hit. + - SeqC would be partially prefilled with the prefix shared, and the + remaining unique tokens would be prefilled (rounded down to be + block-size aligned). + """ + + block_size = 2 + max_num_batched_tokens = 4 + max_seq_group = 3 + scheduler = initialize_scheduler( + block_size=block_size, + num_cpu_blocks=16, + num_gpu_blocks=16, + max_token_budget=max_num_batched_tokens, + max_num_seqs=max_seq_group, + max_model_len=100, + enable_prefix_caching=True, + enable_chunked_prefill=True, + ) + + seqA_tokens = list(range(8)) + seqB_tokens = seqA_tokens + seqC_shared_prefix_len = 4 + seqC_tokens = seqA_tokens[:seqC_shared_prefix_len] + list(range(12, 20)) + + seqA, seqA_group = create_dummy_prompt("0", + prompt_tokens=seqA_tokens, + block_size=block_size) + seqB, seqB_group = create_dummy_prompt("1", + prompt_tokens=seqB_tokens, + block_size=block_size) + + # Chunked prefill seqA. + scheduler.add_seq_group(seqA_group) + metas, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 1 + assert out.scheduled_seq_groups[0].seq_group == seqA_group + assert out.scheduled_seq_groups[0].token_chunk_size == 4 + + # seqB should not be scheduled with ongoing prefills. + scheduler.add_seq_group(seqB_group) + metas, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 1 + assert out.scheduled_seq_groups[0].seq_group == seqA_group + assert out.scheduled_seq_groups[0].token_chunk_size == 4 + + # both seqB and seqC can now be scheduled with seqA is over. + # seqA is in decoding phase. + append_new_token_seq(seqA, 999) + seqC, seqC_group = create_dummy_prompt("2", + prompt_tokens=seqC_tokens, + block_size=block_size) + scheduler.add_seq_group(seqC_group) + metas, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 3 + + metas = {meta.request_id: meta for meta in metas} + assert metas[seqA_group.request_id].token_chunk_size == 1 # Decode + assert (metas[seqB_group.request_id].token_chunk_size == 8 + ) # Fully cached prefill + assert ( + metas[seqC_group.request_id].token_chunk_size == 6 + ), "A partial prefix of C (4 tokens) should be prefilled, with the " + "remaining tokens fit into 3 token budget (4-1 from the seqA). It will " + "then be rounded down to 2 tokens on block size, thus 6 tokens in total." diff --git a/tests/core/utils.py b/tests/core/utils.py index cd0caa4704e11..b2e99d55110ff 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -11,7 +11,7 @@ def create_dummy_prompt( request_id: str, - prompt_length: int, + prompt_length: int = -1, block_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, best_of: int = 1, @@ -26,6 +26,7 @@ def create_dummy_prompt( # Create dummy prompt sequence with tokens 0...block_size-1 # and prompt "0 ... block_size". prompt_tokens = list(range(prompt_length)) + prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt = Sequence(int(request_id), inputs=token_inputs(prompt_tokens, prompt=prompt_str), @@ -42,6 +43,15 @@ def create_dummy_prompt( return prompt, seq_group +def create_dummy_sequence(request_id: int, token_ids: List[int], + block_size: int) -> Sequence: + return Sequence( + seq_id=request_id, + inputs=token_inputs(token_ids), + block_size=block_size, + ) + + def create_dummy_prompt_encoder_decoder( request_id: str, decoder_prompt_length: int, @@ -194,11 +204,15 @@ def append_new_token(out, token_id: int): def schedule_and_update_computed_tokens(scheduler): metas, out, _ = scheduler.schedule() - for s, meta in zip(out.scheduled_seq_groups, metas): - s.seq_group.update_num_computed_tokens(meta.token_chunk_size) + for s in out.scheduled_seq_groups: + s.seq_group.update_num_computed_tokens(s.token_chunk_size) return metas, out +def append_new_token_seq(seq: Sequence, token_id: int): + seq.append_token_id(token_id, {token_id: Logprob(token_id)}) + + def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int): seq_group.update_num_computed_tokens(token_chunk_size) for seq in seq_group.get_seqs(): diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 50723dbb610ac..6ed3305bb6665 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -2,10 +2,14 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`. """ +import random +from typing import List, Optional + import pytest from tests.kernels.utils import override_backend_env_variable from vllm import SamplingParams, TokensPrompt +from vllm.transformers_utils.tokenizer import get_tokenizer from ..models.utils import check_outputs_equal @@ -27,6 +31,7 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("cached_position", [0, 1]) +@pytest.mark.parametrize("enable_chunked_prefill", [True, False]) @pytest.mark.parametrize("block_size", [16]) def test_mixed_requests( hf_runner, @@ -37,6 +42,7 @@ def test_mixed_requests( dtype: str, max_tokens: int, cached_position: int, + enable_chunked_prefill: bool, block_size: int, monkeypatch, ) -> None: @@ -55,6 +61,7 @@ def test_mixed_requests( model, dtype=dtype, enable_prefix_caching=True, + enable_chunked_prefill=enable_chunked_prefill, block_size=block_size, ) as vllm_model: # Run the first prompt so the cache is populated @@ -88,6 +95,82 @@ def test_mixed_requests( ) +@pytest.mark.parametrize("chunk_size", [None, 64]) +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("prefix_ratio", [0.0, 0.5, 1.0]) +def test_partially_cached_prefills( + vllm_runner, + model: str, + chunk_size: Optional[int], + prefix_ratio: float, +) -> None: + """ """ + # Prompts with various lengths. + prompt_lens = [16, 32, 63, 64, 65, 127, 128, 256, 512] + + # NOTE(rickyx): We could potentially run against HF here and with a longer + # decode len, but seems there's quite some invariance in the generated + # tokens, i.e. decoded tokens might mismatch indeterministically across + # runs. There's even mismatch between non-prefix-cached enabled and HF + # ones. + max_tokens = 1 + max_model_len = 2 * max(prompt_lens) + if chunk_size is None: + max_num_batched_tokens = max_model_len + max_num_seqs = len(prompt_lens) + else: + max_num_batched_tokens = chunk_size + max_num_seqs = chunk_size + + tokenizer = get_tokenizer(model, trust_remote_code=True) + random.seed(0) + vocab = list(tokenizer.get_vocab().values()) + + batch_prompts: List[List[str]] = [] + for prompt_len in prompt_lens: + num_prefix_tokens = int(prefix_ratio * prompt_len) + num_unique_tokens = prompt_len - num_prefix_tokens + + prefix_token_ids = random.choices(vocab, k=num_prefix_tokens) + unique_token_ids = random.choices(vocab, k=num_unique_tokens) + + prompt_token_ids = prefix_token_ids + unique_token_ids + prompt_str = tokenizer.decode(prompt_token_ids) + + # First 2 prompts might not have prefix cache hit even with shared + # prefix cache since not yet computed. + batch_prompts.append([prompt_str] * 2) + + # Next prompts should be fully cached, guaranteed to have prefix cache + # hit. + batch_prompts.append([prompt_str]) + + dtype = "half" + outputs = {} # type: ignore + for enable_prefix_caching in (True, False): + with vllm_runner( + model, + dtype=dtype, + enable_prefix_caching=enable_prefix_caching, + enable_chunked_prefill=chunk_size is not None, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_seqs, + max_model_len=max_model_len, + enforce_eager=True, + ) as vllm_model: + outputs[enable_prefix_caching] = [] + for prompts in batch_prompts: + outputs[enable_prefix_caching] += vllm_model.generate_greedy( + prompts, max_tokens=max_tokens) + + check_outputs_equal( + outputs_0_lst=outputs[False], + outputs_1_lst=outputs[True], + name_0="vllm_no_apc", + name_1="vllm", + ) + + @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) def test_unstable_prompt_sequence( vllm_runner, diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 9727f6e19b84e..3197af3c2b7a4 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -306,14 +306,6 @@ def mark_blocks_as_computed(self, block_ids: List[int]) -> None: device = Device.GPU return self._allocators[device].mark_blocks_as_computed(block_ids) - def get_computed_block_ids(self, prev_computed_block_ids: List[int], - block_ids: List[int], - skip_last_block_id: bool) -> List[int]: - # Prefix caching only supported on GPU. - device = Device.GPU - return self._allocators[device].get_computed_block_ids( - prev_computed_block_ids, block_ids, skip_last_block_id) - def get_common_computed_block_ids( self, computed_seq_block_ids: List[List[int]]) -> List[int]: # Prefix caching only supported on GPU. @@ -342,6 +334,13 @@ def get_and_reset_swaps(self) -> List[Tuple[int, int]]: self._swap_mapping.clear() return list(mapping.items()) + def find_cached_blocks_prefix( + self, + block_hashes: List[int], + device: Device = Device.GPU, + ) -> List[int]: + return self._allocators[device].find_cached_blocks_prefix(block_hashes) + class NullBlock(Block): """ diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 72bbab1dcea5d..06f4851af3466 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -159,12 +159,6 @@ def mark_blocks_as_accessed(self, block_ids: List[int], def mark_blocks_as_computed(self, block_ids: List[int]) -> None: pass - @abstractmethod - def get_computed_block_ids(self, prev_computed_block_ids: List[int], - block_ids: List[int], - skip_last_block_id: bool) -> List[int]: - pass - @abstractmethod def get_common_computed_block_ids( self, computed_seq_block_ids: List[List[int]]) -> List[int]: @@ -192,6 +186,13 @@ def get_prefix_cache_hit_rate(self) -> float: class NoFreeBlocksError(ValueError): pass + @abstractmethod + def find_cached_blocks_prefix( + self, + block_hashes: List[int], + ) -> List[int]: + pass + class DeviceAwareBlockAllocator(ABC): @@ -207,9 +208,12 @@ def allocate_immutable_block(self, prev_block: Optional[Block], pass @abstractmethod - def allocate_immutable_blocks(self, prev_block: Optional[Block], - block_token_ids: List[List[int]], - device: Device) -> List[Block]: + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Device, + ) -> List[Block]: pass @abstractmethod @@ -246,12 +250,6 @@ def mark_blocks_as_accessed(self, block_ids: List[int], def mark_blocks_as_computed(self, block_ids: List[int]) -> None: pass - @abstractmethod - def get_computed_block_ids(self, prev_computed_block_ids: List[int], - block_ids: List[int], - skip_last_block_id: bool) -> List[int]: - pass - @abstractmethod def get_common_computed_block_ids( self, computed_seq_block_ids: List[List[int]]) -> List[int]: @@ -284,3 +282,11 @@ def allocate_or_get_null_block(self) -> Block: def get_prefix_cache_hit_rate(self, device: Device) -> float: """Prefix cache hit rate. -1 means not supported or disabled.""" pass + + @abstractmethod + def find_cached_blocks_prefix( + self, + block_hashes: List[int], + device: Device = Device.GPU, + ) -> List[int]: + pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 9341a518d11c6..a2af5ad6362c1 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -262,13 +262,6 @@ def mark_blocks_as_computed(self, block_ids: List[int]) -> None: """ pass - def get_computed_block_ids(self, prev_computed_block_ids: List[int], - block_ids: List[int], - skip_last_block_id: bool) -> List[int]: - """No prefix caching here => return empty list - """ - return [] - def get_common_computed_block_ids( self, computed_seq_block_ids: List[List[int]]) -> List[int]: """Determine blocks that can be skipped in prefill. @@ -329,6 +322,10 @@ def swap_in(self, blocks: List[Block]) -> None: def get_prefix_cache_hit_rate(self) -> float: return -1 + def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]: + # Not applicable for naive block allocator. + return [] + class NaiveBlock(Block): """An implementation of the Block class that does not support prefix diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 57527e39b9bdd..df3ef393136c8 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -4,10 +4,12 @@ from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, get_all_blocks_recursively) -from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device +from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device, + DeviceAwareBlockAllocator) from vllm.core.block.naive_block import (BlockPool, NaiveBlock, NaiveBlockAllocator) from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor +from vllm.sequence import Sequence PrefixHash = int @@ -534,26 +536,6 @@ def block_is_computed(self, block_id: int) -> bool: else: return block_id in self.evictor - def get_computed_block_ids(self, - prev_computed_block_ids: List[int], - block_ids: List[int], - skip_last_block_id: bool = True) -> List[int]: - prev_prefix_size = len(prev_computed_block_ids) - cur_size = len(block_ids) - if skip_last_block_id: - cur_size -= 1 - - # Sanity checks - assert cur_size >= 0 - assert prev_prefix_size <= cur_size - - ret = prev_computed_block_ids - for i in range(prev_prefix_size, cur_size): - block_id = block_ids[i] - if self.block_is_computed(block_id): - ret.append(block_id) - return ret - def get_common_computed_block_ids( self, computed_seq_block_ids: List[List[int]]) -> List[int]: """Return the block ids that are common for a given sequence group. @@ -634,6 +616,49 @@ def swap_in(self, blocks: List[Block]) -> None: block.block_id = block_id # Assign block_id + def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]: + """ + Given a list of block hashes, return the prefix of the block hashes that + are all cached. + + Since a block's block hash includes the hashes of all previous blocks, + and we only allocate/deallocate blocks in the entire sequence, so if a + block is cached, then all previous blocks are also cached. With this + property, we can use binary search to find the prefix of cached blocks. + + Args: + block_hashes (List[int]): The list of block hashes. + + Returns: + List[int]: The prefix of the `block_hashes` that are cached. + """ + + def _block_is_cached(block_hash: PrefixHash) -> bool: + if block_hash not in self._cached_blocks: + return False + + cached_block_id = self._cached_blocks[block_hash] + # We only consider the blocks that are marked as computed. + return self.block_is_computed(cached_block_id) + + def _bisect_left(a, x, key) -> int: + import sys + from bisect import bisect_left + + # python <= 3.10 don't have the key argument + if sys.version_info < (3, 10): + a = [_block_is_cached(x) for x in a] + return bisect_left(a, x) + else: + return bisect_left(a, x, key=key) + + # Look for the first block that's not cached, and returns the prefix + # i.e. blocks that are cached. + idx = _bisect_left(block_hashes, + True, + key=lambda x: not _block_is_cached(x)) + return block_hashes[:idx] + class PrefixCachingBlock(Block): """A block implementation that supports prefix caching. @@ -843,86 +868,126 @@ def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int], class ComputedBlocksTracker: - """Handles caching of per-sequence computed block ids. - When a sequence appears for the first time, it traverses all of the - blocks and detects the prefix of blocks that is computed. On the - subsequent times, it only traverses the new blocks that were added - and updates the already recorded prefix of blocks with the newly - computed blocks. - - To avoid redundant traversals, the algorithm also detects when there - is a "gap" in the computed prefix. For example, if we have blocks = - [1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then - we won't try to add more computed blocks to [1,2,3] in this sequence - iteration, and will add more computed blocks only after the sequence is - freed and reused again. - - Note that currently, for a given sequence, we also skip the last - block id for caching purposes, to avoid caching of a full sequence """ + Tracks the computed blocks for each sequence. - def __init__(self, allocator): - self._allocator = allocator - self._cached_computed_seq_blocks: Dict[int, Tuple[List[int], - bool]] = {} + Internally, it maintains a map from sequence id to the list of block hashes + for the sequence. We cache the hashes of the full blocks for each sequence, + and make sure the hash is calculated in the same way as the allocator. + When a sequence is being decoded, we also update the sequence's hash + accordingly and incrementally. - def add_seq(self, seq_id: int) -> None: - """Start tracking seq_id - """ - assert seq_id not in self._cached_computed_seq_blocks - self._cached_computed_seq_blocks[seq_id] = ([], False) - - def remove_seq(self, seq_id: int) -> None: - """Stop tracking seq_id - """ - assert seq_id in self._cached_computed_seq_blocks - del self._cached_computed_seq_blocks[seq_id] - - def get_cached_computed_blocks_and_update( - self, seq_id: int, block_ids: List[int]) -> List[int]: - """ Look at the class documentation for details - """ - # Ensure seq_id is already tracked - assert seq_id in self._cached_computed_seq_blocks - - # Get cached data (may be empty on the first time) - prev_computed_block_ids, has_gap = self._cached_computed_seq_blocks[ - seq_id] - - if has_gap: - # When gap is detected, we do not add more computed blocks at this - # sequence iteration - return prev_computed_block_ids - - # We do not consider the last block id for caching purposes. - num_cur_blocks = len(block_ids) - 1 - assert num_cur_blocks >= 0 - - if len(prev_computed_block_ids) >= num_cur_blocks: - # Cache HIT - assert len(prev_computed_block_ids) == num_cur_blocks - return prev_computed_block_ids - - # If here, then we may possibly add more computed blocks. As a result, - # traverse the additional blocks after prev_computed_block_ids to - # detect more computed blocks and add them. - - # Incremental init for seq_id => Look only at the new blocks - computed_block_ids = self._allocator.get_computed_block_ids( # noqa: E501 - prev_computed_block_ids, - block_ids, - skip_last_block_id= - True, # We skip last block id to avoid caching of full seq - ) + From the sequence hash, with prefix caching enabled, we could also calculate + the number of cached tokens for the sequence by looking up the number of + cached block hashes in the allocator. + """ - # Detect if there is a "gap" - has_gap = len(computed_block_ids) < num_cur_blocks + def __init__( + self, + allocator: DeviceAwareBlockAllocator, + block_size: int, + enable_caching: bool, + ): + self._allocator = allocator + self._block_size = block_size + self._enable_caching = enable_caching + + # A map from seq_id to the list of block hashes for the + # sequence. This is so that we don't have to recompute the block hashes + # for the sequence when we need to check if the sequence is cached. + # Note a block that's not full will not have its hash calculated and + # recorded. + self._seq_id_to_blocks_hashes: Dict[int, List[int]] = {} + + # A map from seq_id to the number of tokens that are cached for the + # sequence. + # We need this so that a sequence in continuous prefill doesn't + # accidentally see its cached token count change. See comments in + # `get_num_cached_tokens` for more details. + self._seq_id_to_num_tokens_computed: Dict[int, int] = {} + + def _update_seq_hashes(self, seq: Sequence) -> None: + """Incrementally update the sequence's block hashes and record them.""" + assert self._enable_caching + + block_hashes_recorded = self._seq_id_to_blocks_hashes.get( + seq.seq_id, []) + cur_num_blocks_recorded = len(block_hashes_recorded) + token_ids = seq.get_token_ids() + assert len(token_ids) >= cur_num_blocks_recorded * self._block_size, ( + f"The sequence has {len(token_ids)} tokens, but" + f" already recorded {cur_num_blocks_recorded} blocks. " + "This should not happen since we assume blocks are " + "only appended other than recomputation. When the sequence is " + "recomputed, we should have removed the info of the old blocks.") + # Update the computed block hashes for the sequence. Since only full + # blocks are considered as "computed", we take floor here. + num_computed_blocks = len(token_ids) // self._block_size + + # We need to know the hash of the previous block to compute the hash of + # the current block so that blocks could be uniquely identified across + # sequences of prefixes. + prev_block_hash = (None if cur_num_blocks_recorded == 0 else + block_hashes_recorded[-1]) + # Only update the computed block hashes for the new blocks + for i in range(cur_num_blocks_recorded, num_computed_blocks): + assert len(token_ids) >= (i + 1) * self._block_size + block_token_ids = token_ids[i * self._block_size:(i + 1) * + self._block_size] + # This has to be kept in sync with the allocator's hash + # calculation. + block_hash = PrefixCachingBlock.hash_block_tokens( + is_first_block=prev_block_hash is None, + prev_block_hash=prev_block_hash, + cur_block_token_ids=block_token_ids, + ) + block_hashes_recorded.append(block_hash) + prev_block_hash = block_hash + + self._seq_id_to_blocks_hashes[seq.seq_id] = block_hashes_recorded + + def get_num_cached_tokens(self, seq: Sequence) -> int: + if not self._enable_caching: + return 0 + + # We always try to update the sequence hashes on the fly. + # This is to ensure that we don't miss any cached tokens for the + # sequence during decode. + # This routine should only update hash for any new blocks too. + self._update_seq_hashes(seq) + + num_computed_tokens_prev = self._seq_id_to_num_tokens_computed.get( + seq.seq_id, None) + + # TODO(rickyx): This hack could be removed once we mark blocks as + # computed correctly with chunked prefills. + if num_computed_tokens_prev is not None and seq.is_prefill(): + # For a sequence that is still in prefill, we don't + # recompute the number of cached tokens. + # This also handles correctly chunked prefill since currently + # we mark blocks as computed even if the sequence is still partially + # prefilled. So a continuously prefilled sequence should not + # see its cached token count change while running. + return num_computed_tokens_prev + + block_hashes = self._seq_id_to_blocks_hashes[seq.seq_id] + + # This is O(logN), where N is the number of blocks. + num_cached_blocks = len( + self._allocator.find_cached_blocks_prefix(block_hashes)) + num_cached_tokens = num_cached_blocks * self._block_size + self._seq_id_to_num_tokens_computed[seq.seq_id] = num_cached_tokens + return num_cached_tokens - # Record - self._cached_computed_seq_blocks[seq_id] = (computed_block_ids, - has_gap) + def remove_seq(self, seq_id: int) -> None: + """Stop tracking the sequence.""" + if not self._enable_caching: + return + assert seq_id in self._seq_id_to_blocks_hashes + del self._seq_id_to_blocks_hashes[seq_id] - return computed_block_ids + assert seq_id in self._seq_id_to_num_tokens_computed + del self._seq_id_to_num_tokens_computed[seq_id] class LastAccessBlocksTracker: diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 21f4c63b6572d..209487c6b4f9e 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -101,7 +101,7 @@ def __init__( self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} self._computed_blocks_tracker = ComputedBlocksTracker( - self.block_allocator) + self.block_allocator, self.block_size, self.enable_caching) self._last_access_blocks_tracker = LastAccessBlocksTracker( self.block_allocator) @@ -170,7 +170,6 @@ def allocate(self, seq_group: SequenceGroup) -> None: self.block_tables[seq.seq_id] = block_table # Track seq - self._computed_blocks_tracker.add_seq(seq.seq_id) self._last_access_blocks_tracker.add_seq(seq.seq_id) # Assign the block table for each sequence. @@ -178,7 +177,6 @@ def allocate(self, seq_group: SequenceGroup) -> None: self.block_tables[seq.seq_id] = block_table.fork() # Track seq - self._computed_blocks_tracker.add_seq(seq.seq_id) self._last_access_blocks_tracker.add_seq(seq.seq_id) # Allocate cross-attention block table for encoder sequence @@ -314,11 +312,13 @@ def get_common_computed_block_ids( """ computed_seq_block_ids = [] for seq in seqs: - computed_seq_block_ids.append( - self._computed_blocks_tracker. - get_cached_computed_blocks_and_update( - seq.seq_id, - self.block_tables[seq.seq_id].physical_block_ids)) + all_blocks = self.block_tables[seq.seq_id].physical_block_ids + num_cached_tokens = ( + self._computed_blocks_tracker.get_num_cached_tokens(seq)) + assert num_cached_tokens % self.block_size == 0 + num_cached_blocks = num_cached_tokens // self.block_size + computed_block_ids = all_blocks[:num_cached_blocks] + computed_seq_block_ids.append(computed_block_ids) # NOTE(sang): This assumes seq_block_ids doesn't contain any None. return self.block_allocator.get_common_computed_block_ids( @@ -332,7 +332,6 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: self.block_tables[child_seq.seq_id] = src_block_table.fork() # Track child seq - self._computed_blocks_tracker.add_seq(child_seq.seq_id) self._last_access_blocks_tracker.add_seq(child_seq.seq_id) def can_swap_in(self, seq_group: SequenceGroup, @@ -503,3 +502,9 @@ def _can_swap(self, return AllocStatus.OK else: return AllocStatus.LATER + + def get_num_cached_tokens(self, seq: Sequence) -> int: + """Get the number of tokens in blocks that are already computed and + cached in the block manager for the sequence. + """ + return self._computed_blocks_tracker.get_num_cached_tokens(seq) diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 9501a516bf020..b10b8d3f4a5bf 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -121,3 +121,7 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup, def get_prefix_cache_hit_rate(self, device: Device) -> float: """Prefix cache hit rate. -1 means not supported or disabled.""" pass + + @abstractmethod + def get_num_cached_tokens(self, seq: Sequence) -> int: + pass diff --git a/vllm/core/placeholder_block_space_manager.py b/vllm/core/placeholder_block_space_manager.py index a337392bbed53..26d42b7f1790e 100644 --- a/vllm/core/placeholder_block_space_manager.py +++ b/vllm/core/placeholder_block_space_manager.py @@ -89,3 +89,6 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup, def get_prefix_cache_hit_rate(self, device: Device) -> float: return -1 + + def get_num_cached_tokens(self, seq: Sequence) -> int: + return 0 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index af4671ec29be9..f1264e60daa43 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -56,11 +56,16 @@ class SchedulingBudget: max_num_seqs: int _request_ids_num_batched_tokens: Set[str] = field(default_factory=set) _request_ids_num_curr_seqs: Set[str] = field(default_factory=set) + # Number of cached tokens in the batch. + _num_cached_tokens: int = 0 + # Number of actual non-cached tokens in the batch. _num_batched_tokens: int = 0 _num_curr_seqs: int = 0 def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): - assert num_new_tokens != 0 + # We allow num_new_tokens to be 0 when the entire sequence has + # been cached. + assert num_new_tokens >= 0 assert num_new_seqs != 0 return (self.num_batched_tokens + num_new_tokens <= self.token_budget and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs) @@ -68,12 +73,18 @@ def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): def remaining_token_budget(self): return self.token_budget - self.num_batched_tokens - def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int): + def add_num_batched_tokens(self, + req_id: str, + num_batched_tokens: int, + num_cached_tokens: int = 0): if req_id in self._request_ids_num_batched_tokens: return + assert num_cached_tokens >= 0 + assert num_batched_tokens >= 0 self._request_ids_num_batched_tokens.add(req_id) self._num_batched_tokens += num_batched_tokens + self._num_cached_tokens += num_cached_tokens def subtract_num_batched_tokens(self, req_id: str, num_batched_tokens: int): @@ -101,6 +112,10 @@ def num_batched_tokens(self): def num_curr_seqs(self): return self._num_curr_seqs + @property + def num_cached_tokens(self): + return self._num_cached_tokens + @dataclass class ScheduledSequenceGroup: @@ -541,9 +556,19 @@ def _schedule_running( assert len(self._async_stopped) == 0 while running_queue: seq_group = running_queue[0] - num_running_tokens = self._get_num_new_tokens( - seq_group, SequenceStatus.RUNNING, enable_chunking, budget) - + # We discard the cached tokens info here because we don't need it + # for running sequence: + # 1. If a sequence is running with chunked prefill, the cached + # tokens info was already used for the first prefill. + # 2. If a sequence is running with non-chunked prefill, then + # there it's a decoding sequence, and the cached tokens info is + # irrelevant. + num_uncached_new_tokens, _ = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.RUNNING, enable_chunking, + budget)) + + num_running_tokens = num_uncached_new_tokens if num_running_tokens == 0: # No budget => Stop break @@ -715,13 +740,16 @@ def _schedule_swapped( # The total number of sequences in the RUNNING state should not # exceed the maximum number of sequences. num_new_seqs = seq_group.get_max_num_running_seqs() - num_new_tokens = self._get_num_new_tokens(seq_group, - SequenceStatus.SWAPPED, - enable_chunking, budget) - - if (num_new_tokens == 0 - or not budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): + num_new_tokens_uncached, num_new_tokens_cached = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.SWAPPED, enable_chunking, + budget)) + + if (num_new_tokens_uncached + num_new_tokens_cached == 0 + or not budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + )): break if lora_int_id > 0 and curr_loras is not None: @@ -732,12 +760,19 @@ def _schedule_swapped( is_prefill = seq_group.is_prefill() if is_prefill: prefill_seq_groups.append( - ScheduledSequenceGroup(seq_group, - token_chunk_size=num_new_tokens)) + ScheduledSequenceGroup( + seq_group, + token_chunk_size=num_new_tokens_uncached + + num_new_tokens_cached, + )) else: decode_seq_groups.append( ScheduledSequenceGroup(seq_group, token_chunk_size=1)) - budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) + budget.add_num_batched_tokens( + seq_group.request_id, + num_batched_tokens=num_new_tokens_uncached, + num_cached_tokens=num_new_tokens_cached, + ) budget.add_num_seqs(seq_group.request_id, num_new_seqs) swapped_queue.extendleft(leftover_swapped) @@ -803,26 +838,30 @@ def _schedule_priority_preemption( if waiting_queue: seq_group = waiting_queue.popleft() num_new_seqs = seq_group.get_max_num_running_seqs() - num_new_tokens = self._get_num_new_tokens(seq_group, - SequenceStatus.WAITING, - False, budget) + num_new_tokens_uncached, _ = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.WAITING, False, budget)) #Only preempt if priority inversion exists while running_queue and self._get_priority( running_queue[-1]) > self._get_priority(seq_group): #Only preempt if waiting sequence cannot be allocated can_allocate = self.block_manager.can_allocate(seq_group) - if (num_new_tokens and can_allocate == AllocStatus.OK - and budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): + if (num_new_tokens_uncached > 0 + and can_allocate == AllocStatus.OK + and budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + )): break #Adjust budget to remove the victim sequence group vseq_group = running_queue.pop() - num_running_tokens = self._get_num_new_tokens( - vseq_group, SequenceStatus.RUNNING, False, budget) - budget.subtract_num_batched_tokens(vseq_group.request_id, - num_running_tokens) + num_running_tokens_uncached, _ = ( + self._get_num_new_uncached_and_cached_tokens( + vseq_group, SequenceStatus.RUNNING, False, budget)) + budget.subtract_num_batched_tokens( + vseq_group.request_id, num_running_tokens_uncached) num_running_seqs = vseq_group.get_max_num_running_seqs() budget.subtract_num_seqs(vseq_group.request_id, num_running_seqs) @@ -882,9 +921,12 @@ def _schedule_prefills( assert len(waiting_seqs) == 1, ( "Waiting sequence group should have only one prompt " "sequence.") - num_new_tokens = self._get_num_new_tokens(seq_group, - SequenceStatus.WAITING, - enable_chunking, budget) + num_new_tokens_uncached, num_new_tokens_cached = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.WAITING, enable_chunking, + budget)) + num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached + if not enable_chunking: num_prompt_tokens = waiting_seqs[0].get_len() assert num_new_tokens == num_prompt_tokens @@ -935,10 +977,18 @@ def _schedule_prefills( waiting_queue.popleft() continue + if (budget.num_batched_tokens >= + self.scheduler_config.max_num_batched_tokens): + # We've reached the budget limit - since there might be + # continuous prefills in the running queue, we should break + # to avoid scheduling any new prefills. + break + num_new_seqs = seq_group.get_max_num_running_seqs() - if (num_new_tokens == 0 - or not budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): + if num_new_tokens == 0 or not budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + ): break # Can schedule this request. @@ -967,7 +1017,11 @@ def _schedule_prefills( seq_groups.append( ScheduledSequenceGroup(seq_group=seq_group, token_chunk_size=num_new_tokens)) - budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) + budget.add_num_batched_tokens( + seq_group.request_id, + num_batched_tokens=num_new_tokens_uncached, + num_cached_tokens=num_new_tokens_cached, + ) budget.add_num_seqs(seq_group.request_id, num_new_seqs) # Queue requests that couldn't be scheduled. @@ -1075,7 +1129,8 @@ def _schedule_default(self) -> SchedulerOutputs: return SchedulerOutputs( scheduled_seq_groups=scheduled_seq_groups, num_prefill_groups=num_prefill_groups, - num_batched_tokens=budget.num_batched_tokens, + num_batched_tokens=budget.num_batched_tokens + + budget.num_cached_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, blocks_to_copy=blocks_to_copy, @@ -1119,7 +1174,6 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: running_scheduled.swapped_out) == 0: swapped_in = self._schedule_swapped(budget, curr_loras) - # Schedule new prefills. prefills = self._schedule_prefills(budget, curr_loras, enable_chunking=True) @@ -1157,7 +1211,8 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: num_prefill_groups=(len(prefills.seq_groups) + len(swapped_in.prefill_seq_groups) + len(running_scheduled.prefill_seq_groups)), - num_batched_tokens=budget.num_batched_tokens, + num_batched_tokens=budget.num_batched_tokens + + budget.num_cached_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, blocks_to_copy=running_scheduled.blocks_to_copy + @@ -1584,64 +1639,171 @@ def _get_num_lookahead_slots(self, is_prefill: bool, return self.scheduler_config.num_lookahead_slots - def _get_num_new_tokens(self, seq_group: SequenceGroup, - status: SequenceStatus, enable_chunking: bool, - budget: SchedulingBudget) -> int: - """Get the next new tokens to compute for a given sequence group - that's in a given `status`. + def _get_num_new_uncached_and_cached_tokens( + self, + seq_group: SequenceGroup, + status: SequenceStatus, + enable_chunking: bool, + budget: SchedulingBudget, + ) -> Tuple[int, int]: + """ + Returns the number of new uncached and cached tokens to schedule for a + given sequence group that's in a given `status`. The API could chunk the number of tokens to compute based on `budget` if `enable_chunking` is True. If a sequence group has multiple sequences (e.g., running beam search), it means it is in decoding phase, so chunking doesn't happen. - Returns 0 if the new token cannot be computed due to token budget. + Returns (0, 0) if the new token cannot be computed due to token budget. + + The cached tokens's blocks are already computed, and the attention + backend will reuse the cached blocks rather than recomputing them. So + the scheduler could schedule these cached tokens "for free". + + Args: + seq_group: The sequence group to get the number of new tokens to + schedule. + status: The status of the sequences to get the number of new tokens + to schedule. + enable_chunking: Whether to chunk the number of tokens to compute. + budget: The budget to chunk the number of tokens to compute. + + + Returns: + A tuple of two ints. The first int is the number of new uncached + tokens to schedule. The second int is the number of cached tokens. + If no more new tokens can be scheduled, returns (0, 0). """ - num_new_tokens = 0 + num_cached_new_tokens = 0 + num_uncached_new_tokens = 0 + seqs = seq_group.get_seqs(status=status) + # Compute the number of new uncached and cached tokens for + # each sequence. for seq in seqs: - num_new_tokens += seq.get_num_new_tokens() - assert num_new_tokens > 0 - # Chunk if a running request cannot fit in the given budget. - # If number of seq > 1, it means it is doing beam search - # in a decode phase. Do not chunk. + if not seq.is_prefill(): + # Decode sequences should always just have 1 uncached token + # TODO(rickyx): Actually is this still correct for multi-step? + num_uncached_new_tokens += 1 + continue + + num_computed_tokens_seq = seq.get_num_computed_tokens() + all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq + if not self.cache_config.enable_prefix_caching: + # If prefix caching is not enabled, all new tokens are uncached. + num_uncached_new_tokens += all_num_new_tokens_seq + continue + + # NOTE: the cache token might be currently in a block that's in an + # evictor meaning that it's not yet allocated. However, we don't + # exclude such tokens in the cache count because it will be + # guaranteed to be allocated later if the sequence can be allocated. + num_cached_tokens_seq = self.block_manager.get_num_cached_tokens( + seq) + + # Sanity check. + if num_cached_tokens_seq < num_computed_tokens_seq: + # This should only happen with chunked prefill, and + # the seq is still in prefill. The `num_cached_tokens_seq` + # is the value we calculated on scheduling the first prefill. + # For subsequent continuous prefill steps, we cached the + # number of cache tokens for the sequence so the cached token + # count could be less than the number of computed tokens. + # See comments on `ComputedBlocksTracker` for more details. + assert ( + seq.is_prefill() and seq.status == SequenceStatus.RUNNING + and self.scheduler_config.chunked_prefill_enabled + ), ("Number of cached tokens should not be less than the " + "number of computed tokens for a sequence that's still " + f"in prefill. But there are {num_cached_tokens_seq} cached " + f"tokens and {num_computed_tokens_seq} computed tokens " + f"for sequence {seq.seq_id}.") + + num_cached_new_tokens_seq = max( + 0, num_cached_tokens_seq - num_computed_tokens_seq) + num_uncached_new_tokens_seq = (all_num_new_tokens_seq - + num_cached_new_tokens_seq) + + num_uncached_new_tokens += num_uncached_new_tokens_seq + num_cached_new_tokens += num_cached_new_tokens_seq + if enable_chunking and len(seqs) == 1: - remaining_token_budget = budget.remaining_token_budget() - if self.scheduler_config.is_multi_step: - # The current multi-step + chunked prefill capability does - # not actually support chunking prompts. - # - # Therefore, `num_new_tokens` is computed in the same fashion - # for both multi-step+chunked-prefill & - # multi-step+chunked-prefill+APC - # - # Prompts with more tokens than the current remaining budget - # are postponed to future scheduler steps - if num_new_tokens > self._get_prompt_limit(seq_group): - # If the seq_group is in prompt-stage, pass the - # num_new_tokens as-is so the caller can ignore - # the sequence. - pass - else: - num_new_tokens = 0 \ - if num_new_tokens > remaining_token_budget \ - else num_new_tokens - elif self.cache_config.enable_prefix_caching: - # When prefix caching is enabled, we always allocate - # the number of new tokens that is dividable by the block - # size to avoid partial block matching. - block_size = self.cache_config.block_size - remainder = budget.token_budget % block_size - if remainder != 0: - raise ValueError("When enabling chunked prefill and " - "prefix caching, max_num_batched_tokens " - "(chunk size) must be dividable by " - "block size, but got chunk_size " - f"({budget.token_budget}) % block_size " - f"({block_size}) = {remainder}") - if remaining_token_budget < num_new_tokens: - num_new_tokens = (remaining_token_budget // - block_size) * block_size - else: - num_new_tokens = min(num_new_tokens, remaining_token_budget) + # Chunk if a running request cannot fit in the given budget. + # If number of seq > 1, it means it is doing beam search + # in a decode phase. Do not chunk. + num_uncached_new_tokens = self._chunk_new_tokens_to_schedule( + self.scheduler_config, + self.cache_config, + budget, + self._get_prompt_limit(seq_group), + num_uncached_new_tokens, + ) + + return num_uncached_new_tokens, num_cached_new_tokens + + @staticmethod + def _chunk_new_tokens_to_schedule( + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + budget: SchedulingBudget, + prompt_limit: int, + num_new_tokens: int, + ) -> int: + """ + Chunks the number of new tokens to schedule based on the budget when + chunked prefill is enabled. + + Args: + scheduler_config: The scheduler config. + cache_config: The cache config. + budget: The budget to chunk the number of tokens to compute. + prompt_limit: The maximum number of tokens allowed in a prompt. + num_new_tokens: The number of new tokens to schedule. + + Returns: + The number of new tokens to schedule after chunking. + """ + remaining_token_budget = budget.remaining_token_budget() + if scheduler_config.is_multi_step: + # The current multi-step + chunked prefill capability does + # not actually support chunking prompts. + # + # Therefore, `num_new_tokens` is computed in the same fashion + # for both multi-step+chunked-prefill & + # multi-step+chunked-prefill+APC + # + # Prompts with more tokens than the current remaining budget + # are postponed to future scheduler steps + if num_new_tokens > prompt_limit: + # If the seq_group is in prompt-stage, pass the + # num_new_tokens as-is so the caller can ignore + # the sequence. + return num_new_tokens + + return (0 if num_new_tokens > remaining_token_budget else + num_new_tokens) + + if cache_config.enable_prefix_caching: + # Adjust the remaining token budget to be divisible by the block + # size when prefix caching is enabled. + + # When prefix caching is enabled, we always allocate + # the number of new tokens that is dividable by the block + # size to avoid partial block matching. + block_size = cache_config.block_size + remainder = budget.token_budget % block_size + if remainder != 0: + raise ValueError("When enabling chunked prefill and " + "prefix caching, max_num_batched_tokens " + "(chunk size) must be dividable by " + "block size, but got chunk_size " + f"({budget.token_budget}) % block_size " + f"({block_size}) = {remainder}") + # Round down to block size. + remaining_token_budget = (remaining_token_budget // block_size * + block_size) + + num_new_tokens = min(num_new_tokens, remaining_token_budget) + return num_new_tokens diff --git a/vllm/sequence.py b/vllm/sequence.py index 3b41d25a2fe42..a1cc8fc3b09de 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -579,6 +579,9 @@ def get_num_new_tokens(self) -> int: return 1 return self.data.get_num_uncomputed_tokens() + def get_num_computed_tokens(self) -> int: + return self.data.get_num_computed_tokens() + def is_prefill(self) -> bool: return self.data.stage == SequenceStage.PREFILL From b232a45ebc97710acec6a0c8f5f2f46b3284827f Mon Sep 17 00:00:00 2001 From: rickyx Date: Fri, 15 Nov 2024 07:31:35 +0000 Subject: [PATCH 2/4] fix Signed-off-by: rickyx --- tests/core/utils.py | 31 ++++++- tests/prefix_caching/test_prefix_caching.py | 90 +++++++++++++++++++++ vllm/core/scheduler.py | 18 +++-- 3 files changed, 131 insertions(+), 8 deletions(-) diff --git a/tests/core/utils.py b/tests/core/utils.py index b2e99d55110ff..277368b57b938 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -1,12 +1,15 @@ import time -from typing import List, Optional +from collections import defaultdict +from typing import Any, Dict, List, Optional from typing import Sequence as GenericSequence from typing import Tuple from vllm import SamplingParams +from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.inputs import EncoderDecoderInputs, token_inputs from vllm.lora.request import LoRARequest -from vllm.sequence import Logprob, Sequence, SequenceGroup +from vllm.sequence import (Logprob, Sequence, SequenceGroup, + SequenceGroupMetadata) def create_dummy_prompt( @@ -217,3 +220,27 @@ def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int): seq_group.update_num_computed_tokens(token_chunk_size) for seq in seq_group.get_seqs(): seq.append_token_id(token_id, {token_id: Logprob(token_id)}) + + +class SchedulerProxy: + """ + A proxy class to forward calls to the scheduler. + """ + + def __init__(self, scheduler: Scheduler): + self.scheduler_ = scheduler + self.call_history: Dict[str, List[Any]] = defaultdict(list) + + def __getattr__(self, name: str) -> Any: + + def wrapper(*args, **kwargs): + result = getattr(self.scheduler_, name)(*args, **kwargs) + self.call_history[name].append((args, kwargs, result)) + return result + + return wrapper + + def last_schedule_ret( + self, ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, Any]: + _, _, ret = self.call_history["schedule"][-1] + return ret diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 6ed3305bb6665..84eee146df821 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -7,8 +7,12 @@ import pytest +from tests.conftest import VllmRunner +from tests.core.utils import SchedulerProxy, create_dummy_prompt from tests.kernels.utils import override_backend_env_variable from vllm import SamplingParams, TokensPrompt +from vllm.core.scheduler import Scheduler +from vllm.engine.llm_engine import LLMEngine from vllm.transformers_utils.tokenizer import get_tokenizer from ..models.utils import check_outputs_equal @@ -188,3 +192,89 @@ def test_unstable_prompt_sequence( for prompt in UNSTABLE_PROMPT_SEQUENCE: vllm_model.generate(TokensPrompt(prompt_token_ids=prompt), SamplingParams(max_tokens=1)) + + +@pytest.mark.parametrize("model", MODELS) +def test_fully_cached_prefill_needs_uncached_token(model): + block_size = 16 + max_num_batched_tokens = 16 + num_output_tokens = 5 + # Make a vllm engine + runner = VllmRunner( + model_name=model, + gpu_memory_utilization=0.7, + enable_chunked_prefill=True, + enforce_eager=True, + enable_prefix_caching=True, + block_size=block_size, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_batched_tokens, + ) + engine: LLMEngine = runner.model.llm_engine + + scheduler: Scheduler = SchedulerProxy(engine.scheduler[0]) # type: ignore + engine.scheduler[0] = scheduler + + # SeqA + seqA_tokens = list(range(2 * block_size)) + seqA, seq_groupA = create_dummy_prompt( + request_id="0", + prompt_tokens=seqA_tokens, + max_tokens=num_output_tokens, + block_size=block_size, + ) + + scheduler.add_seq_group(seq_groupA) + + assert seqA.data.get_num_computed_tokens() == 0 + + # Prefill seqA + while not seqA.is_finished(): + engine.step() + + # seqB + seqB_tokens = [t + 1 for t in seqA_tokens] # shift by 1 + seqB, seq_groupB = create_dummy_prompt( + request_id="1", + prompt_tokens=seqB_tokens, + max_tokens=num_output_tokens, + block_size=block_size, + ) + + # seqC is the same as seqA + seqC, seq_groupC = create_dummy_prompt( + request_id="2", + prompt_tokens=seqA_tokens, + max_tokens=num_output_tokens, + block_size=block_size, + ) + + scheduler.add_seq_group(seq_groupB) + scheduler.add_seq_group(seq_groupC) + + # Even seqC is fully cached, it should not be prefilled since we + # require at least 1 uncached token. + engine.step() + + sched_metas, sched_out, _ = scheduler.last_schedule_ret() + assert len(sched_out.scheduled_seq_groups) == 1 + assert (sched_out.scheduled_seq_groups[0].seq_group.request_id == + seq_groupB.request_id) + assert (sched_out.scheduled_seq_groups[0].token_chunk_size == + max_num_batched_tokens) + + # When seqB is finished, seqC could be prefilled. + while not seqB.is_finished(): + engine.step() + sched_metas, sched_out, _ = scheduler.last_schedule_ret() + assert len(sched_out.scheduled_seq_groups) == 1 + assert (sched_out.scheduled_seq_groups[0].seq_group.request_id == + seq_groupB.request_id) + + engine.step() + sched_metas, sched_out, _ = scheduler.last_schedule_ret() + assert len(sched_out.scheduled_seq_groups) == 1 + assert (sched_out.scheduled_seq_groups[0].seq_group.request_id == + seq_groupC.request_id) + assert sched_out.scheduled_seq_groups[0].token_chunk_size == len( + seqA_tokens) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index f1264e60daa43..841e65c488fc6 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -745,11 +745,10 @@ def _schedule_swapped( seq_group, SequenceStatus.SWAPPED, enable_chunking, budget)) - if (num_new_tokens_uncached + num_new_tokens_cached == 0 - or not budget.can_schedule( - num_new_tokens=num_new_tokens_uncached, - num_new_seqs=num_new_seqs, - )): + if num_new_tokens_uncached == 0 or not budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + ): break if lora_int_id > 0 and curr_loras is not None: @@ -985,7 +984,7 @@ def _schedule_prefills( break num_new_seqs = seq_group.get_max_num_running_seqs() - if num_new_tokens == 0 or not budget.can_schedule( + if num_new_tokens_uncached == 0 or not budget.can_schedule( num_new_tokens=num_new_tokens_uncached, num_new_seqs=num_new_seqs, ): @@ -1728,6 +1727,13 @@ def _get_num_new_uncached_and_cached_tokens( num_uncached_new_tokens += num_uncached_new_tokens_seq num_cached_new_tokens += num_cached_new_tokens_seq + if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0: + # For a fully cached hit sequence, we actually need to recompute the + # last token. So we need at least 1 uncached token to schedule. + # See ModelRunner._compute_for_prefix_cache_hit for more details. + num_uncached_new_tokens = 1 + num_cached_new_tokens -= 1 + if enable_chunking and len(seqs) == 1: # Chunk if a running request cannot fit in the given budget. # If number of seq > 1, it means it is doing beam search From 8c38100a910eb75b0330781a7aa186f340ca1f5b Mon Sep 17 00:00:00 2001 From: rickyx Date: Fri, 15 Nov 2024 19:38:40 +0000 Subject: [PATCH 3/4] fix Signed-off-by: rickyx --- vllm/core/block/prefix_caching_block.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index df3ef393136c8..b736167f6ceb4 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -1,6 +1,9 @@ """Token blocks.""" +import sys +from bisect import bisect_left from os.path import commonprefix -from typing import Dict, FrozenSet, Iterable, List, Optional, Set, Tuple +from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set, + Tuple) from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, get_all_blocks_recursively) @@ -641,13 +644,11 @@ def _block_is_cached(block_hash: PrefixHash) -> bool: # We only consider the blocks that are marked as computed. return self.block_is_computed(cached_block_id) - def _bisect_left(a, x, key) -> int: - import sys - from bisect import bisect_left + def _bisect_left(a, x, key: Callable[[PrefixHash], bool]) -> int: # python <= 3.10 don't have the key argument if sys.version_info < (3, 10): - a = [_block_is_cached(x) for x in a] + a = [key(e) for e in a] return bisect_left(a, x) else: return bisect_left(a, x, key=key) From 949cb0f138df1e16031c2d25f3312b4e3243b5ac Mon Sep 17 00:00:00 2001 From: rickyx Date: Fri, 15 Nov 2024 21:59:29 +0000 Subject: [PATCH 4/4] remove flaky tests Signed-off-by: rickyx --- tests/prefix_caching/test_prefix_caching.py | 91 ++------------------- 1 file changed, 6 insertions(+), 85 deletions(-) diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 84eee146df821..8d16710f14585 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -2,8 +2,6 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`. """ -import random -from typing import List, Optional import pytest @@ -13,7 +11,6 @@ from vllm import SamplingParams, TokensPrompt from vllm.core.scheduler import Scheduler from vllm.engine.llm_engine import LLMEngine -from vllm.transformers_utils.tokenizer import get_tokenizer from ..models.utils import check_outputs_equal @@ -83,13 +80,13 @@ def test_mixed_requests( block_size) * block_size else: expected_num_cached_tokens = 0 - assert req_outputs[ - i].num_cached_tokens == expected_num_cached_tokens + assert ( + req_outputs[i].num_cached_tokens == expected_num_cached_tokens) - vllm_outputs = [ - (output.prompt_token_ids + list(output.outputs[0].token_ids), - output.prompt + output.outputs[0].text) for output in req_outputs - ] + vllm_outputs = [( + output.prompt_token_ids + list(output.outputs[0].token_ids), + output.prompt + output.outputs[0].text, + ) for output in req_outputs] check_outputs_equal( outputs_0_lst=hf_outputs, @@ -99,82 +96,6 @@ def test_mixed_requests( ) -@pytest.mark.parametrize("chunk_size", [None, 64]) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("prefix_ratio", [0.0, 0.5, 1.0]) -def test_partially_cached_prefills( - vllm_runner, - model: str, - chunk_size: Optional[int], - prefix_ratio: float, -) -> None: - """ """ - # Prompts with various lengths. - prompt_lens = [16, 32, 63, 64, 65, 127, 128, 256, 512] - - # NOTE(rickyx): We could potentially run against HF here and with a longer - # decode len, but seems there's quite some invariance in the generated - # tokens, i.e. decoded tokens might mismatch indeterministically across - # runs. There's even mismatch between non-prefix-cached enabled and HF - # ones. - max_tokens = 1 - max_model_len = 2 * max(prompt_lens) - if chunk_size is None: - max_num_batched_tokens = max_model_len - max_num_seqs = len(prompt_lens) - else: - max_num_batched_tokens = chunk_size - max_num_seqs = chunk_size - - tokenizer = get_tokenizer(model, trust_remote_code=True) - random.seed(0) - vocab = list(tokenizer.get_vocab().values()) - - batch_prompts: List[List[str]] = [] - for prompt_len in prompt_lens: - num_prefix_tokens = int(prefix_ratio * prompt_len) - num_unique_tokens = prompt_len - num_prefix_tokens - - prefix_token_ids = random.choices(vocab, k=num_prefix_tokens) - unique_token_ids = random.choices(vocab, k=num_unique_tokens) - - prompt_token_ids = prefix_token_ids + unique_token_ids - prompt_str = tokenizer.decode(prompt_token_ids) - - # First 2 prompts might not have prefix cache hit even with shared - # prefix cache since not yet computed. - batch_prompts.append([prompt_str] * 2) - - # Next prompts should be fully cached, guaranteed to have prefix cache - # hit. - batch_prompts.append([prompt_str]) - - dtype = "half" - outputs = {} # type: ignore - for enable_prefix_caching in (True, False): - with vllm_runner( - model, - dtype=dtype, - enable_prefix_caching=enable_prefix_caching, - enable_chunked_prefill=chunk_size is not None, - max_num_batched_tokens=max_num_batched_tokens, - max_num_seqs=max_num_seqs, - max_model_len=max_model_len, - enforce_eager=True, - ) as vllm_model: - outputs[enable_prefix_caching] = [] - for prompts in batch_prompts: - outputs[enable_prefix_caching] += vllm_model.generate_greedy( - prompts, max_tokens=max_tokens) - - check_outputs_equal( - outputs_0_lst=outputs[False], - outputs_1_lst=outputs[True], - name_0="vllm_no_apc", - name_1="vllm", - ) - - @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) def test_unstable_prompt_sequence( vllm_runner,