[Performance] Introducing Prefix-Cached Chunked Prefill
Juelianqvq committed Jul 26, 2024
1 parent 85ad7e2 commit 40a9f31
Showing 2 changed files with 32 additions and 8 deletions.
5 changes: 5 additions & 0 deletions vllm/core/block_manager_v1.py
@@ -669,6 +669,11 @@ def compute_full_blocks_in_seq(self, seq: Sequence):
         if max_full_block == -1:
             return
         for i in reversed(range(max_full_block)):
+            # [help wanted]
+            # max_full_block < len(block_table) should always hold, but prefix caching
+            # + chunked prefill can violate it (unknown why); the guard is cheap.
+            if i >= len(block_table):
+                continue
             if block_table[i].computed:
                 break
             block_table[i].computed = True
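For readers less familiar with the block manager, the guarded loop above boils down to the following standalone sketch. The Block dataclass, the mark_full_blocks_computed wrapper and the example numbers are illustrative stand-ins rather than vLLM's real types, and the derivation of max_full_block from the sequence length is assumed; only the loop body mirrors the hunk.

from dataclasses import dataclass
from typing import List


@dataclass
class Block:
    # Illustrative stand-in for a physical KV-cache block entry.
    computed: bool = False


def mark_full_blocks_computed(block_table: List[Block], seq_len: int,
                              block_size: int) -> None:
    # Mirrors compute_full_blocks_in_seq: walk the fully filled block
    # indices from the back, marking each as computed and stopping at the
    # first block that is already marked.
    max_full_block = seq_len // block_size - 1  # assumed derivation
    if max_full_block == -1:
        return
    for i in reversed(range(max_full_block)):
        # Guard added by this commit: with prefix caching + chunked prefill
        # the index can run past the allocated table, so skip it.
        if i >= len(block_table):
            continue
        if block_table[i].computed:
            break
        block_table[i].computed = True


# Example: the sequence length implies 3 full blocks, but only 2 blocks
# have been allocated so far for the chunked prompt.
table = [Block(), Block()]
mark_full_blocks_computed(table, seq_len=64, block_size=16)
print([b.computed for b in table])  # [True, True], no IndexError

Without the added guard, the call above would raise IndexError on block index 2, which is exactly the conflict the commit works around when prefix caching and chunked prefill are combined.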
35 changes: 27 additions & 8 deletions vllm/worker/model_runner.py
@@ -246,6 +246,9 @@ def __init__(

             self.multi_modal_inputs = multi_modal_inputs
             self.prefix_cache_hit = prefix_cache_hit
+
+            # maybe dirty hack: remembers this sequence's prefix-cache hit length
+            self.cached_len = 0
 
             self.__post_init__()

@@ -365,20 +368,36 @@ def _compute_for_prefix_cache_hit(
                             and self.sliding_window is None
                             and inter_data.is_prompt)
         inter_data.prefix_cache_hit = prefix_cache_hit
-        if self.chunked_prefill_enabled and prefix_cache_hit:
-            raise RuntimeError(
-                "chunked prefill cannot be used with prefix caching now.")
+        # if self.chunked_prefill_enabled and prefix_cache_hit:
+        #     raise RuntimeError(
+        #         "chunked prefill cannot be used with prefix caching now.")
 
         # If prefix cache is hit, advance context length to bypass
         # hit blocks. Accordingly, input tokens, position and query length
         # have to be updated.
         if prefix_cache_hit:
             assert computed_block_nums is not None
-            context_len = len(computed_block_nums) * self.block_size
-            inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
-                seq_idx][context_len:]
-            inter_data.input_positions[seq_idx] = inter_data.input_positions[
-                seq_idx][context_len:]
+
+            # [help wanted]
+            # inter_data.cached_len is a work-around: it pins the prefix-cache hit
+            # length for a sequence so later chunks of the same prompt reuse it.
+            context_len = inter_data.context_lens[seq_idx]
+            if context_len == 0:
+                inter_data.cached_len = len(computed_block_nums) * self.block_size
+                context_len = min(inter_data.cached_len, seq_group_metadata.token_chunk_size - 1)
+                inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
+                    seq_idx][context_len:]
+                inter_data.input_positions[seq_idx] = inter_data.input_positions[
+                    seq_idx][context_len:]
+            else:
+                if inter_data.cached_len > context_len:
+                    delta_len = min(inter_data.cached_len - context_len, seq_group_metadata.token_chunk_size - 1)
+                    context_len += delta_len
+                    inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
+                        seq_idx][delta_len:]
+                    inter_data.input_positions[seq_idx] = inter_data.input_positions[
+                        seq_idx][delta_len:]
+
             inter_data.context_lens[seq_idx] = context_len
             inter_data.query_lens[
                 seq_idx] = inter_data.seq_lens[seq_idx] - context_len
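To make the cached_len bookkeeping above concrete, here is a small self-contained trace of the skipping arithmetic. It is only a sketch: the prompt length, block size and chunk size are made-up numbers, the per-chunk accounting of how the remaining budget is consumed is simplified, and the skipped_tokens helper is a hypothetical name; it just mirrors the two branches in _compute_for_prefix_cache_hit, including the token_chunk_size - 1 cap that leaves at least one token to compute per chunk.

def skipped_tokens(context_len: int, cached_len: int,
                   token_chunk_size: int) -> int:
    # Mirrors the two branches above: on the first chunk the whole cache hit
    # may be skipped, on later chunks only the not-yet-consumed part of it;
    # both are capped at token_chunk_size - 1 so the chunk still computes
    # at least one token.
    if context_len == 0:
        return min(cached_len, token_chunk_size - 1)
    if cached_len > context_len:
        return min(cached_len - context_len, token_chunk_size - 1)
    return 0


# Made-up scenario: 48-token prompt, 32 tokens (two 16-token blocks) already
# in the prefix cache, prefilled in 16-token chunks.
block_size, prompt_len, chunk = 16, 48, 16
cached_len = 2 * block_size

context_len = 0
step = 0
while context_len < prompt_len:
    step += 1
    skip = skipped_tokens(context_len, cached_len, chunk)
    computed = min(chunk - skip, prompt_len - context_len - skip)
    context_len += skip + computed
    print(f"chunk {step}: skipped {skip}, computed {computed}, "
          f"context_len -> {context_len}")

# chunk 1: skipped 15, computed 1, context_len -> 16
# chunk 2: skipped 15, computed 1, context_len -> 32
# chunk 3: skipped 0, computed 16, context_len -> 48

With these numbers the first two chunks each skip 15 cached tokens and compute a single new one, so the rest of the cache hit carries over to the next chunk; that carried-over amount is what inter_data.cached_len pins down across chunks of the same sequence.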
