Skip to content

Commit

Permalink
[Core] support LoRA and prompt adapter in content-based hashing for Block Manager v2 prefix caching (vllm-project#8240)
Browse files Browse the repository at this point in the history
  • Loading branch information
llsj14 authored Dec 13, 2024
1 parent d1fa714 commit c31d4a5
Show file tree
Hide file tree
Showing 10 changed files with 246 additions and 55 deletions.
65 changes: 63 additions & 2 deletions tests/core/block/test_prefix_caching_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from tests.core.utils import create_dummy_sequence
from tests.core.utils import create_dummy_lora_sequence, create_dummy_sequence
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
Expand Down Expand Up @@ -801,6 +801,7 @@ def create_immutable_chain(
block_size: int,
token_ids: List[int],
allocator: PrefixCachingBlockAllocator,
extra_hash: Optional[int] = None,
) -> List[PrefixCachingBlock]:
"""Helper method which creates a chain of blocks.
"""
Expand All @@ -816,7 +817,9 @@ def create_immutable_chain(
block_size:(block_number + 1) *
block_size]
prev_block = allocator.allocate_immutable_block(
prev_block=prev_block, token_ids=block_token_ids)
prev_block=prev_block,
token_ids=block_token_ids,
extra_hash=extra_hash)
blocks.append(prev_block)

return blocks
Expand Down Expand Up @@ -931,3 +934,61 @@ def test_correct_block_hash():
allocator.mark_blocks_as_computed([])

assert tracker.get_num_cached_tokens(seq) == len(tokens)

@staticmethod
def test_correct_extra_hash():
"""
Test that the block hash is correctly computed based on the extra hash,
ensuring it matches the allocator's block hash, specifically for the
LoRA case, and that the correct number of cached tokens is retrieved.
"""
block_size = 4
allocator = CpuGpuBlockAllocator.create(
allocator_type="prefix_caching",
num_gpu_blocks=16,
num_cpu_blocks=16,
block_size=block_size,
)
gpu_allocator = allocator._allocators[Device.GPU]

tracker = ComputedBlocksTracker(
allocator=allocator,
block_size=block_size,
enable_caching=True,
)

tokens = list(range(block_size * 4))

# Create a dummy LoRA sequence with a specific LoRA ID.
lora_seq = create_dummy_lora_sequence(request_id=0,
token_ids=tokens,
block_size=block_size,
lora_int_id=1)

_ = TestPrefixCachingBlockAllocator.create_immutable_chain(
block_size=block_size,
token_ids=tokens,
allocator=gpu_allocator,
extra_hash=lora_seq.extra_hash(),
)

allocator.mark_blocks_as_computed([])

# Create different dummy sequences that have the same token IDs
# but different LoRA IDs.
seq = create_dummy_sequence(request_id=1,
token_ids=tokens,
block_size=block_size)

different_lora_seq = create_dummy_lora_sequence(request_id=2,
token_ids=tokens,
block_size=block_size,
lora_int_id=2)

# Due to the different LoRA IDs, corresponding blocks are not cached.
assert tracker.get_num_cached_tokens(seq) == 0
assert tracker.get_num_cached_tokens(different_lora_seq) == 0

# The number of cached tokens matches the length of the tokens
# for the cached LoRA sequence.
assert tracker.get_num_cached_tokens(lora_seq) == len(tokens)
10 changes: 10 additions & 0 deletions tests/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ def create_dummy_prompt(
return prompt, seq_group


def create_dummy_lora_sequence(request_id: int, token_ids: List[int],
                               block_size: int, lora_int_id: int) -> Sequence:
    """Build a Sequence attached to a dummy LoRA adapter.

    The returned sequence wraps ``token_ids`` and carries a LoRARequest
    whose ``lora_int_id`` distinguishes it in prefix-caching hash tests.
    """
    dummy_lora = LoRARequest(lora_name="dummy",
                             lora_path="/dummy",
                             lora_int_id=lora_int_id)
    return Sequence(seq_id=request_id,
                    inputs=token_inputs(token_ids),
                    block_size=block_size,
                    lora_request=dummy_lora)


def create_dummy_sequence(request_id: int, token_ids: List[int],
block_size: int) -> Sequence:
return Sequence(
Expand Down
46 changes: 34 additions & 12 deletions vllm/core/block/block_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def get_num_required_blocks(token_ids: List[int],

def allocate(self,
token_ids: List[int],
device: Device = Device.GPU) -> None:
device: Device = Device.GPU,
extra_hash: Optional[int] = None) -> None:
"""Allocates memory blocks for storing the given sequence of token IDs.
This method allocates the required number of blocks to store the given
Expand All @@ -90,12 +91,16 @@ def allocate(self,
token_ids (List[int]): The sequence of token IDs to be stored.
device (Device, optional): The device on which the blocks should be
allocated. Defaults to Device.GPU.
extra_hash (Optional[int]): The hash value of additional
factors, such as adapters, that influence the block hash
in the prefixcaching block.
"""
assert not self._is_allocated
assert token_ids
blocks = self._allocate_blocks_for_token_ids(prev_block=None,
token_ids=token_ids,
device=device)
device=device,
extra_hash=extra_hash)
self.update(blocks)
self._num_full_slots = len(token_ids)

Expand All @@ -108,7 +113,8 @@ def update(self, blocks: List[Block]) -> None:
def append_token_ids(self,
token_ids: List[int],
num_lookahead_slots: int = 0,
num_computed_slots: Optional[int] = None) -> None:
num_computed_slots: Optional[int] = None,
extra_hash: Optional[int] = None) -> None:
"""Appends a sequence of token IDs to the existing blocks in the
BlockTable.
Expand All @@ -130,6 +136,9 @@ def append_token_ids(self,
Without sliding window, None can be passed.
Without chunked prefill, it should be the same as
_num_full_slots.
extra_hash (Optional[int]): The hash value of additional
factors such as adapters that influence the block, apart
from the token_ids.
"""
assert self._is_allocated, "no blocks have been allocated"
assert len(self._blocks) > 0
Expand All @@ -149,7 +158,8 @@ def append_token_ids(self,
# Ensure there are enough empty slots for the new tokens plus
# lookahead slots
self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots)
num_lookahead_slots,
extra_hash=extra_hash)

# Update the blocks with the new tokens
first_block_idx = self._num_full_slots // self._block_size
Expand All @@ -160,7 +170,9 @@ def append_token_ids(self,

self._num_full_slots += len(token_ids)

def ensure_num_empty_slots(self, num_empty_slots: int) -> None:
def ensure_num_empty_slots(self,
num_empty_slots: int,
extra_hash: Optional[int] = None) -> None:
"""Ensures that the BlockTable has at least the specified number of
empty slots available.
Expand All @@ -171,6 +183,9 @@ def ensure_num_empty_slots(self, num_empty_slots: int) -> None:
Args:
num_empty_slots (int): The minimum number of empty slots required.
extra_hash (Optional[int]): The hash value of additional
factors such as adapters that influence the block, apart
from the token_ids.
"""
# Currently the block table only supports
# appending tokens to GPU blocks.
Expand All @@ -187,7 +202,9 @@ def ensure_num_empty_slots(self, num_empty_slots: int) -> None:
assert len(self._blocks) > 0
self._blocks.append(
self._allocator.allocate_mutable_block(
prev_block=self._blocks[-1], device=device))
prev_block=self._blocks[-1],
device=device,
extra_hash=extra_hash))

def fork(self) -> "BlockTable":
"""Creates a new BlockTable instance with a copy of the blocks from the
Expand Down Expand Up @@ -259,9 +276,12 @@ def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
# ones after the appended ones.
return sequence_token_ids[self.num_full_slots:]

def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
token_ids: List[int],
device: Device) -> List[Block]:
def _allocate_blocks_for_token_ids(
self,
prev_block: Optional[Block],
token_ids: List[int],
device: Device,
extra_hash: Optional[int] = None) -> List[Block]:
blocks: List[Block] = []

block_token_ids = []
Expand All @@ -275,16 +295,18 @@ def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
if block_token_ids:
blocks.extend(
self._allocator.allocate_immutable_blocks(
prev_block, block_token_ids=block_token_ids,
device=device))
prev_block,
block_token_ids=block_token_ids,
device=device,
extra_hash=extra_hash))
prev_block = blocks[-1]

if tail_token_ids:
assert len(tail_token_ids) == 1
cur_token_ids = tail_token_ids[0]

block = self._allocator.allocate_mutable_block(
prev_block=prev_block, device=device)
prev_block=prev_block, device=device, extra_hash=extra_hash)
block.append_token_ids(cur_token_ids)

blocks.append(block)
Expand Down
19 changes: 13 additions & 6 deletions vllm/core/block/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ def __init__(self, block_size: int, create_block: Block.Factory,
token_ids=[],
block_size=self._block_size,
allocator=self._allocator,
block_id=None))
block_id=None,
extra_hash=None))

def increase_pool(self):
"""Doubles the internal pool size
Expand All @@ -194,10 +195,15 @@ def increase_pool(self):
token_ids=[],
block_size=self._block_size,
allocator=self._allocator,
block_id=None))

def init_block(self, prev_block: Optional[Block], token_ids: List[int],
block_size: int, physical_block_id: Optional[int]) -> Block:
block_id=None,
extra_hash=None))

def init_block(self,
prev_block: Optional[Block],
token_ids: List[int],
block_size: int,
physical_block_id: Optional[int],
extra_hash: Optional[int] = None) -> Block:
if len(self._free_ids) == 0:
self.increase_pool()
assert len(self._free_ids) > 0
Expand All @@ -210,7 +216,8 @@ def init_block(self, prev_block: Optional[Block], token_ids: List[int],
token_ids=token_ids,
block_size=block_size,
allocator=block._allocator, # type: ignore[attr-defined]
block_id=physical_block_id)
block_id=physical_block_id,
extra_hash=extra_hash)
block.pool_id = pool_id # type: ignore[attr-defined]
return block

Expand Down
43 changes: 32 additions & 11 deletions vllm/core/block/cpu_gpu_block_allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,23 +121,32 @@ def allocate_or_get_null_block(self) -> Block:
self.allocate_mutable_block(None, Device.GPU))
return self._null_block

def allocate_mutable_block(self, prev_block: Optional[Block],
device: Device) -> Block:
def allocate_mutable_block(self,
prev_block: Optional[Block],
device: Device,
extra_hash: Optional[int] = None) -> Block:
"""Allocates a new mutable block on the specified device.
Args:
prev_block (Optional[Block]): The previous block to in the sequence.
Used for prefix hashing.
device (Device): The device on which to allocate the new block.
extra_hash (Optional[int]): The hash value of additional
factors, such as adapters, that influence the block hash
in the prefix caching block.
Returns:
Block: The newly allocated mutable block.
"""
return self._allocators[device].allocate_mutable_block(prev_block)

def allocate_immutable_blocks(self, prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Device) -> List[Block]:
return self._allocators[device].allocate_mutable_block(
prev_block, extra_hash=extra_hash)

def allocate_immutable_blocks(
self,
prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Device,
extra_hash: Optional[int] = None) -> List[Block]:
"""Allocates a new group of immutable blocks with the provided block
token IDs on the specified device.
Expand All @@ -147,17 +156,22 @@ def allocate_immutable_blocks(self, prev_block: Optional[Block],
block_token_ids (List[int]): The list of block token IDs to be
stored in the new blocks.
device (Device): The device on which to allocate the new block.
extra_hash (Optional[int]): The hash value of additional
factors, such as adapters, that influence the block hash
in the prefix caching block.
Returns:
List[Block]: The newly allocated list of immutable blocks
containing the provided block token IDs.
"""
return self._allocators[device].allocate_immutable_blocks(
prev_block, block_token_ids)
prev_block, block_token_ids, extra_hash=extra_hash)

def allocate_immutable_block(self, prev_block: Optional[Block],
def allocate_immutable_block(self,
prev_block: Optional[Block],
token_ids: List[int],
device: Device) -> Block:
device: Device,
extra_hash: Optional[int] = None) -> Block:
"""Allocates a new immutable block with the provided token IDs on the
specified device.
Expand All @@ -167,13 +181,16 @@ def allocate_immutable_block(self, prev_block: Optional[Block],
token_ids (List[int]): The list of token IDs to be stored in the new
block.
device (Device): The device on which to allocate the new block.
extra_hash (Optional[int]): The hash value of additional
factors, such as adapters, that influence the block hash
in the prefix caching block.
Returns:
Block: The newly allocated immutable block containing the provided
token IDs.
"""
return self._allocators[device].allocate_immutable_block(
prev_block, token_ids)
prev_block, token_ids, extra_hash=extra_hash)

def free(self, block: Block) -> None:
"""Frees the memory occupied by the given block.
Expand Down Expand Up @@ -387,6 +404,10 @@ def is_full(self):
def prev_block(self):
return self._proxy.prev_block

@property
def extra_hash(self):
return None

@property
def computed(self):
return self._proxy.computed
Expand Down
Loading

0 comments on commit c31d4a5

Please sign in to comment.