Prefix Cache Aware Scheduling [1/n] #10128

Merged
merged 4 commits on Nov 23, 2024
Changes from all commits
181 changes: 175 additions & 6 deletions tests/core/block/test_prefix_caching_block.py
@@ -5,9 +5,14 @@

from unittest.mock import MagicMock

import pytest

from tests.core.utils import create_dummy_sequence
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
                                                  PrefixCachingBlock,
                                                  PrefixCachingBlockAllocator)
from vllm.sequence import Logprob
from vllm.utils import Device


class TestPrefixCachingBlock:
@@ -726,18 +731,71 @@ def test_touch_block():
            token_ids=common_token_ids,
            allocator=allocator,
        )
        block_hashes = [block.content_hash for block in blocks]
        # The allocated blocks should be marked as touched
        # but not computed.
        computed_block_ids = allocator.find_cached_blocks_prefix(block_hashes)
        assert len(computed_block_ids) == 0

        allocator.mark_blocks_as_computed([])
        computed_block_ids = allocator.find_cached_blocks_prefix(
            block_hashes=block_hashes)
        assert len(computed_block_ids) == common_blocks
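The two assertions above pin down the allocator's block lifecycle: freshly allocated blocks are only "touched", and become visible to prefix lookups once mark_blocks_as_computed() promotes them. Below is a minimal, hypothetical sketch of that state machine; the class name and structure are illustrative only, not the PR's implementation:

```python
class TouchedComputedSketch:
    """Illustrative touched -> computed bookkeeping for block hashes."""

    def __init__(self) -> None:
        self._touched: set = set()
        self._computed: set = set()

    def touch(self, block_hash: int) -> None:
        # Allocation marks a block as touched; it is not yet queryable.
        self._touched.add(block_hash)

    def mark_blocks_as_computed(self) -> None:
        # Promote all touched blocks so prefix lookups can see them.
        self._computed |= self._touched
        self._touched.clear()

    def is_computed(self, block_hash: int) -> bool:
        return block_hash in self._computed
```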

    @staticmethod
    def test_find_cached_blocks_prefix():
        """
        This test verifies the behavior of find_cached_blocks_prefix.
        """
        block_size = 4
        num_blocks = 8
        total_test_blocks = 12
        allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                                block_size=block_size)

        token_ids = list(range(total_test_blocks * block_size))
        block_tokens_seq1 = token_ids[:num_blocks * block_size]
        blocks_seq1 = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=block_tokens_seq1,
            allocator=allocator,
        )
        block_hashes_seq1 = [block.content_hash for block in blocks_seq1]
        allocator.mark_blocks_as_computed([])

        # All blocks should be cached.
        cached_blocks_seq1 = allocator.find_cached_blocks_prefix(
            block_hashes=block_hashes_seq1)
        assert len(cached_blocks_seq1) == num_blocks

        # Free the first sequence.
        for block in blocks_seq1:
            allocator.free(block)

        # All blocks should still be cached while no other allocation
        # needs to reuse them.
        cached_blocks = allocator.find_cached_blocks_prefix(
            block_hashes=block_hashes_seq1)
        assert len(cached_blocks) == num_blocks

        block_tokens_seq2 = token_ids[num_blocks * block_size:]
        blocks_seq2 = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=block_tokens_seq2,
            allocator=allocator,
        )
        block_hashes_seq2 = [block.content_hash for block in blocks_seq2]
        allocator.mark_blocks_as_computed([])
        cached_blocks = allocator.find_cached_blocks_prefix(
            block_hashes=block_hashes_seq2)
        assert len(cached_blocks) == len(blocks_seq2)

        # Allocating seq2 (4 blocks) evicts 4 of seq1's 8 freed blocks,
        # so half of the blocks from seq1 should still be cached.
        num_evicted_blocks = len(blocks_seq2)
        cached_blocks = allocator.find_cached_blocks_prefix(
            block_hashes=block_hashes_seq1)
        assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks
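For intuition, a lookup like find_cached_blocks_prefix can stop at the first cache miss: because each block's hash chains on its predecessor, a miss at position i implies every later block of the same sequence is also uncached. The helper below is a hypothetical sketch of that contract (longest cached prefix), not the PR's actual implementation:

```python
from typing import Dict, List


def find_cached_blocks_prefix_sketch(block_hashes: List[int],
                                     computed: Dict[int, int]) -> List[int]:
    """Return the longest prefix of block_hashes present in computed
    (a hash -> block id map). Stops at the first miss."""
    prefix: List[int] = []
    for block_hash in block_hashes:
        if block_hash not in computed:
            break
        prefix.append(block_hash)
    return prefix
```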

    @staticmethod
    def create_immutable_chain(
        block_size: int,
@@ -762,3 +820,114 @@
            blocks.append(prev_block)

        return blocks


class TestComputedBlocksTracker:

    @staticmethod
    def _get_mock_allocator():
        return MagicMock(spec=PrefixCachingBlockAllocator)

    @staticmethod
    def test_get_num_cached_tokens():
        """
        Test that the tracker correctly computes the number of cached tokens
        for a given sequence:

        - The cached token count is derived from the number of cached blocks.
        - The cached token count is updated when the allocator is updated.
        - When a sequence is removed, the cached token count should be
          updated accordingly.

        # TODO(rickyx): This behaviour for prefill sequences is a hack until
        # we fix the computed blocks tracking.
        - The cached token count for a prefill sequence doesn't change while
          the sequence is in continuous prefill (chunked prefill).
        """
        block_size = 4
        mock_allocator = TestComputedBlocksTracker._get_mock_allocator()
        tracker = ComputedBlocksTracker(
            allocator=mock_allocator,
            block_size=block_size,
            enable_caching=True,
        )

        # Not yet allocated.
        tokens = [0, 1, 2, 3, 4, 5]
        seq1 = create_dummy_sequence(request_id=0,
                                     token_ids=tokens,
                                     block_size=block_size)
        mock_allocator.find_cached_blocks_prefix.return_value = []
        assert tracker.get_num_cached_tokens(seq1) == 0

        mock_allocator.find_cached_blocks_prefix.return_value = [
            None
        ]  # 1 block cached.
        # The result is cached for a prefill sequence, so it stays 0.
        assert tracker.get_num_cached_tokens(seq1) == 0

        # Mark the sequence as non-prefill.
        seq1.data.update_num_computed_tokens(len(tokens))  # 6 tokens computed.
        assert not seq1.is_prefill()

        # Recomputed for a decode sequence: 1 cached block of 4 tokens.
        assert tracker.get_num_cached_tokens(seq1) == 4

        # Append new tokens to the sequence.
        num_new_tokens = 3
        for i in range(num_new_tokens):
            seq1.append_token_id(i, {i: Logprob(logprob=0.0)})

        assert tracker.get_num_cached_tokens(seq1) == 4

        # Update the allocator.
        mock_allocator.find_cached_blocks_prefix.return_value = [
            None
        ] * 2  # 2 blocks cached.
        assert tracker.get_num_cached_tokens(seq1) == 8

        # Remove the sequence.
        tracker.remove_seq(seq1.seq_id)

        # Re-create the sequence with the same request id to simulate
        # recomputation.
        seq1 = create_dummy_sequence(request_id=0,
                                     token_ids=tokens,
                                     block_size=block_size)
        mock_allocator.find_cached_blocks_prefix.return_value = []  # No cached blocks.
        assert tracker.get_num_cached_tokens(seq1) == 0
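The token arithmetic behind these assertions is that tokens only count as cached in whole blocks, so the cached token count is the cached prefix length times the block size. A minimal sketch, assuming a hypothetical helper (the real tracker additionally caches the result for sequences still in prefill, which is the hack the docstring's TODO refers to):

```python
def num_cached_tokens_sketch(num_cached_blocks: int, block_size: int) -> int:
    # Whole blocks only: a partially filled block contributes nothing.
    return num_cached_blocks * block_size


# Matches the assertions above with block_size=4:
assert num_cached_tokens_sketch(1, 4) == 4  # one cached block
assert num_cached_tokens_sketch(2, 4) == 8  # two cached blocks
```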

    @staticmethod
    def test_correct_block_hash():
        """
        Test that the block hash is correctly computed for a sequence (it
        should match the underlying block allocator's block hash), so that
        the number of cached tokens is correctly retrieved.
        """
        block_size = 4
        allocator = CpuGpuBlockAllocator.create(
            allocator_type="prefix_caching",
            num_gpu_blocks=16,
            num_cpu_blocks=16,
            block_size=block_size,
        )
        gpu_allocator = allocator._allocators[Device.GPU]

        tracker = ComputedBlocksTracker(
            allocator=allocator,
            block_size=block_size,
            enable_caching=True,
        )

        tokens = list(range(block_size * 4))  # 4 blocks.
        seq = create_dummy_sequence(request_id=0,
                                    token_ids=tokens,
                                    block_size=block_size)
        _ = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=tokens,
            allocator=gpu_allocator,
        )
        allocator.mark_blocks_as_computed([])

        assert tracker.get_num_cached_tokens(seq) == len(tokens)
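test_correct_block_hash only passes if the tracker derives hashes exactly as the allocator does. Conceptually, each block's content hash covers the previous block's hash plus its own token ids, so two sequences share cached blocks exactly as far as their token prefixes match. The function below is an illustrative stand-in for that chaining, not vLLM's actual hashing code:

```python
from typing import List, Optional


def chained_content_hash_sketch(prev_block_hash: Optional[int],
                                token_ids: List[int]) -> int:
    # Chaining on the previous hash makes a block's identity depend on
    # its entire prefix, not just its own tokens.
    return hash((prev_block_hash, tuple(token_ids)))
```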