From 40a9f312ec5de0fea39062f19346d96ee2648324 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pastel=EF=BC=81?= <1627301104@qq.com> Date: Fri, 26 Jul 2024 05:09:05 +0000 Subject: [PATCH] [Performance] Introducing Prefix-Cached Chunked Prefill --- vllm/core/block_manager_v1.py | 5 +++++ vllm/worker/model_runner.py | 35 +++++++++++++++++++++++++++-------- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index e29eba375f4dd..d5af2cbfefea7 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -669,6 +669,11 @@ def compute_full_blocks_in_seq(self, seq: Sequence): if max_full_block == -1: return for i in reversed(range(max_full_block)): + # [help wanted] + # max_full_block < block_table makes sense, but combining pc + cp may produce a conflict, + # (do not know why) so following 'if' statement is needed, little hurt for performance + if i >= len(block_table): + continue if block_table[i].computed: break block_table[i].computed = True diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 86d26b4a84c36..6e0d33ce28190 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -246,6 +246,9 @@ def __init__( self.multi_modal_inputs = multi_modal_inputs self.prefix_cache_hit = prefix_cache_hit + + # maybe dirty hack + self.cached_len = 0 self.__post_init__() @@ -365,20 +368,36 @@ def _compute_for_prefix_cache_hit( and self.sliding_window is None and inter_data.is_prompt) inter_data.prefix_cache_hit = prefix_cache_hit - if self.chunked_prefill_enabled and prefix_cache_hit: - raise RuntimeError( - "chunked prefill cannot be used with prefix caching now.") + # if self.chunked_prefill_enabled and prefix_cache_hit: + # raise RuntimeError( + # "chunked prefill cannot be used with prefix caching now.") # If prefix cache is hit, advance context length to bypass # hit blocks. Accordingly, input tokens, position and query length # have to be updated. if prefix_cache_hit: assert computed_block_nums is not None - context_len = len(computed_block_nums) * self.block_size - inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ - seq_idx][context_len:] - inter_data.input_positions[seq_idx] = inter_data.input_positions[ - seq_idx][context_len:] + + # [help wanted] + # inter_data.cached_len is just a work-around for mutable cached length when chunked prefill + # enabled with prefix caching, in order to fix the cached length for the same seq being chunked + context_len = inter_data.context_lens[seq_idx] + if context_len == 0: + inter_data.cached_len = len(computed_block_nums) * self.block_size + context_len = min(inter_data.cached_len, seq_group_metadata.token_chunk_size - 1) + inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ + seq_idx][context_len:] + inter_data.input_positions[seq_idx] = inter_data.input_positions[ + seq_idx][context_len:] + else: + if inter_data.cached_len > context_len: + delta_len = min(inter_data.cached_len - context_len, seq_group_metadata.token_chunk_size - 1) + context_len += delta_len + inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ + seq_idx][delta_len:] + inter_data.input_positions[seq_idx] = inter_data.input_positions[ + seq_idx][delta_len:] + inter_data.context_lens[seq_idx] = context_len inter_data.query_lens[ seq_idx] = inter_data.seq_lens[seq_idx] - context_len