Skip to content

Commit

Permalink
feat: add support for chunked prefill + prefix caching (#871)
Browse files Browse the repository at this point in the history
* feat: add support for chunked prefill + prefix caching

* update tests
  • Loading branch information
AlpinDale authored Dec 7, 2024
1 parent ef99a56 commit abfd446
Show file tree
Hide file tree
Showing 8 changed files with 223 additions and 26 deletions.
19 changes: 13 additions & 6 deletions aphrodite/processing/block_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,14 +680,20 @@ def access_all_blocks_in_seq(
for block in block_table:
block.last_accessed = access_time

def compute_full_blocks_in_seq(self, seq: Sequence):
def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int):
if seq.seq_id not in self.block_tables:
return
max_full_block = seq.get_len() // self.block_size - 1

# When chunked prefill is enabled, the computed full blocks
# should be calculated based on the number of computed tokens.
max_computed_tokens = (seq.data.get_num_computed_tokens() +
token_chunk_size)
computed_full_blocks = max_computed_tokens // self.block_size

block_table = self.block_tables[seq.seq_id]
if max_full_block == -1:
if computed_full_blocks == 0:
return
for i in reversed(range(max_full_block)):
for i in reversed(range(computed_full_blocks)):
if block_table[i].computed:
break
block_table[i].computed = True
Expand Down Expand Up @@ -717,10 +723,11 @@ def get_common_computed_block_ids(
ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
return commonprefix([ids for ids in ids_list if ids != []])

def mark_blocks_as_computed(self, seq_group: SequenceGroup):
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
if self.enable_caching:
for seq in seq_group.get_seqs():
self.compute_full_blocks_in_seq(seq)
self.compute_full_blocks_in_seq(seq, token_chunk_size)

def get_prefix_cache_hit_rate(self, device: Device) -> float:
if device == Device.GPU:
Expand Down
3 changes: 2 additions & 1 deletion aphrodite/processing/block_manager_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,8 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float):
self._last_access_blocks_tracker.update_last_access(
seq.seq_id, now)

def mark_blocks_as_computed(self, seq_group: SequenceGroup):
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
# The only need for mark block as computed is for prefix caching,
# while currently we could determine whether one block is computed
# or not by check whether it has content hash.
Expand Down
3 changes: 2 additions & 1 deletion aphrodite/processing/placeholder_block_space_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def get_common_computed_block_ids(self,
seq_group: SequenceGroup) -> List[int]:
return None # type: ignore

def mark_blocks_as_computed(self, seq_group: SequenceGroup):
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
pass

def get_prefix_cache_hit_rate(self, device: Device) -> float:
Expand Down
30 changes: 24 additions & 6 deletions aphrodite/processing/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,7 +1152,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# will crash the Aphrodite instance / will not retry.
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
self.block_manager.mark_blocks_as_computed(
scheduled_seq_group.seq_group)
scheduled_seq_group.seq_group,
scheduled_seq_group.token_chunk_size)

return seq_group_metadata_list, scheduler_outputs

Expand Down Expand Up @@ -1344,10 +1345,27 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup,
for seq in seqs:
num_new_tokens += seq.get_num_new_tokens()
assert num_new_tokens > 0
# Chunk if a running request cannot fit in.
# If number of seq > 1, it means it is doing beam search in a
# decode phase. Do not chunk in that case.
# Chunk if a running request cannot fit in the given budget.
# If number of seq > 1, it means it is doing beam search
# in a decode phase. Do not chunk.
if enable_chunking and len(seqs) == 1:
num_new_tokens = min(num_new_tokens,
budget.remaining_token_budget())
remaining_token_budget = budget.remaining_token_budget()
if self.cache_config.enable_prefix_caching:
# When prefix caching is enabled, we always allocate
# the number of new tokens that is dividable by the block size
# to avoid partial block matching.
block_size = self.cache_config.block_size
reminder = budget.token_budget % block_size
if reminder != 0:
raise ValueError("When enabling chunked prefill and "
"prefix caching, max_num_batched_tokens "
"(chunk size) must be dividable by "
"block size, but got chunk_size "
f"({budget.token_budget}) % block_size "
f"({block_size}) = {reminder}")
if remaining_token_budget < num_new_tokens:
num_new_tokens = (remaining_token_budget //
block_size) * block_size
else:
num_new_tokens = min(num_new_tokens, remaining_token_budget)
return num_new_tokens
49 changes: 37 additions & 12 deletions aphrodite/task_handler/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,23 +499,48 @@ def _compute_for_prefix_cache_hit(
and self.sliding_window is None
and inter_data.is_prompt)
inter_data.prefix_cache_hit = prefix_cache_hit
if self.chunked_prefill_enabled and prefix_cache_hit:
raise RuntimeError(
"chunked prefill cannot be used with prefix caching now.")

# If prefix cache is hit, advance context length to bypass
# hit blocks. Accordingly, input tokens, position and query length
# have to be updated.
if prefix_cache_hit:
assert computed_block_nums is not None
context_len = len(computed_block_nums) * self.block_size

if not prefix_cache_hit:
return

assert computed_block_nums is not None
# The cache hit prompt tokens in this sequence. Note that
# this may be larger than the sequence length if chunked
# prefill is enabled.
prefix_cache_len = len(computed_block_nums) * self.block_size
# The number of so far computed prompt tokens in this sequence.
context_len = inter_data.context_lens[seq_idx]
# The total number of prompt tokens in this sequence.
# When chunked prefill is enabled, this is the token number of
# computed chunks + current chunk.
seq_len = inter_data.seq_lens[seq_idx]
if prefix_cache_len <= context_len:
# We already passed the cache hit region,
# so do normal computation.
pass
elif context_len < prefix_cache_len < seq_len:
# Partial hit. Compute the missing part.
uncomputed_start = prefix_cache_len - context_len
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][context_len:]
seq_idx][uncomputed_start:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][context_len:]
seq_idx][uncomputed_start:]
context_len = prefix_cache_len

inter_data.context_lens[seq_idx] = context_len
inter_data.query_lens[
seq_idx] = inter_data.seq_lens[seq_idx] - context_len
elif seq_len <= prefix_cache_len:
# Full hit. Only compute the last token to avoid
# erroneous behavior. FIXME: Ideally we should directly
# mark all tokens as computed in the scheduler and do not
# schedule this sequence, so this case should not happen.
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][-1:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][-1:]
inter_data.query_lens[seq_idx] = 1
inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1

def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup,
seq_idx: int,
Expand Down
66 changes: 66 additions & 0 deletions tests/basic_correctness/test_chunked_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Run `pytest tests/models/test_chunked_prefill.py`.
"""
from contextlib import nullcontext

import pytest

Expand Down Expand Up @@ -151,3 +152,68 @@ def test_models_with_fp8_kv_cache(
name_0="no_chunked_prefill",
name_1="chunked_prefill",
)


@pytest.mark.parametrize("max_tokens", [16])
@pytest.mark.parametrize("enforce_eager", [False])
@pytest.mark.parametrize("chunk_size", [30, 32])
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@pytest.mark.parametrize("tensor_parallel_size", [1])
def test_with_prefix_caching(
aphrodite_runner,
max_tokens: int,
enforce_eager: bool,
chunk_size: int,
use_v2_block_manager: bool,
tensor_parallel_size: int,
) -> None:
"""
Checks exact match decode with and without prefix caching
with chunked prefill enabled.
"""
model = "meta-llama/Llama-2-7b-chat-hf"
# The common prompt has 142 tokens with Llama-2 tokenizer.
common_prompt = "You are a helpful AI assistant " * 20
unique_prompts = [
"Question", # Warmup
"Question", # Fully cached
"Another question", # Partial cached
]
full_prompts = [f"{common_prompt}\n{p}" for p in unique_prompts]

max_num_batched_tokens = max_num_seqs = chunk_size
outputs = {} # type: ignore
check_result = True
for enable in (True, False):
with aphrodite_runner(
model,
dtype="half",
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=True,
enable_prefix_caching=enable,
tensor_parallel_size=tensor_parallel_size,
use_v2_block_manager=use_v2_block_manager,
enforce_eager=enforce_eager,
max_num_seqs=max_num_seqs,
) as aphrodite_model:
# It should fail when prefix caching is enable and chunk
# size is not a multiple of block size (16).
should_fail = chunk_size % 16 != 0 and enable
check_result &= not should_fail
outputs[enable] = []
# Send the request one-by-one to ensure the cache is populated.
with pytest.raises(ValueError) if should_fail else nullcontext():
for prompt in full_prompts:
outputs[enable] += aphrodite_model.generate_greedy(
[prompt], max_tokens)

# Check results only if we did not expect a failure.
if check_result:
check_outputs_equal(
outputs_0_lst=outputs[False],
outputs_1_lst=outputs[True],
name_0="w/o prefix caching",
name_1="with prefix caching",
)
40 changes: 40 additions & 0 deletions tests/core/test_block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,3 +596,43 @@ def test_sliding_window_multi_seq():

# assert all blocks are free now
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks


def test_mark_blocks_as_computed_with_prefix_cache_and_chunked_prefill():
"""When prefix cache and chunked prefill are enabled, the block manager
should only mark a chunk of blocks as computed instead of all blocks.
"""

block_size = 4
num_cpu_blocks = 0
num_gpu_blocks = 16
block_manager = BlockSpaceManagerV1(block_size,
num_gpu_blocks,
num_cpu_blocks,
watermark=0,
enable_caching=True)

# Set prompt size to have num_gpu_blocks - 1 full blocks.
prompt_length = block_size * num_gpu_blocks - 1

# Allocate (reserve) all blocks.
_, seq_group = create_dummy_prompt("0",
prompt_length,
block_size=block_size)
block_manager.allocate(seq_group)
assert seq_group.seqs[0].n_blocks == num_gpu_blocks

# 1st chunk: Compute 2 and half blocks. Should mark 2 blocks as computed.
token_chunk_size = int(block_size * 2.5)
block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
assert len(computed_blocks) == 2

# Actual computed tokens.
seq_group.seqs[0].data.update_num_computed_tokens(token_chunk_size)

# 2nd chunk: Complete 3rd block and additional 4 blocks.
token_chunk_size = int(block_size * 4.5)
block_manager.mark_blocks_as_computed(seq_group, token_chunk_size)
computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0])
assert len(computed_blocks) == 7
39 changes: 39 additions & 0 deletions tests/core/test_chunked_prefill_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,3 +581,42 @@ def test_chunked_prefill_max_seqs():
assert len(get_sequence_groups(out)) == max_seqs
assert not running[0].is_prefill()
assert not running[1].is_prefill()


def test_perfix_caching():
"""Verify allocating full blocks when prefix caching is enabled."""
block_size = 4
max_seqs = 10
max_model_len = 80
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True)
cache_config = CacheConfig(block_size,
1.0,
1,
"auto",
enable_prefix_caching=True)
cache_config.num_cpu_blocks = 0
cache_config.num_gpu_blocks = 32
scheduler = Scheduler(scheduler_config, cache_config, None)
running: List[SequenceGroup] = []

# Add seq groups to scheduler.
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
block_size=block_size,
prompt_length=50)
scheduler.add_seq_group(seq_group)
running.append(seq_group)

seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert set(get_sequence_groups(out)) == set(running)
assert seq_group_meta[0].token_chunk_size == 50
# Verify it is chunked. Note that although the budget is 64-50=14,
# we only allocate full blocks for prefix caching, so only 4*(14//4)=12
# tokens are allocated.
assert seq_group_meta[1].token_chunk_size == 12
assert out.num_prefill_groups == 2
assert out.num_batched_tokens == 62

0 comments on commit abfd446

Please sign in to comment.