
[Core] support LoRA and prompt adapter in content-based hashing for Block Manager v2 prefix caching #8240

Merged: 16 commits merged into vllm-project:main on Dec 13, 2024

Conversation

llsj14 (Contributor) commented Sep 6, 2024

Summary

Unlike v1, Block Manager v2 did not include LoRA or prompt adapter identity in the block hash used for prefix caching.
This PR adds logic to inject the LoRA ID and prompt adapter ID into the block hash function, so prefix caching works correctly with LoRA and prompt adapters under Block Manager v2.

Detail

Block Manager v1 uses the following hash_of_block function to generate a content hash in prefix caching mode:

vllm/sequence.py (lines 460 to 468 at baa5467):

def hash_of_block(self, logical_idx: int) -> int:
    # TODO This can produce incorrect hash when block size > prompt size
    # Compute the number of tokens in the sequence
    # TODO: The current hashing function is O(L^2). We should optimize
    # this in the future.
    num_tokens = self.num_hashed_tokens_of_block(logical_idx)
    hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
    return hash((hashed_tokens, self.lora_int_id))

However, Block Manager v2 only uses token IDs, as shown here:

return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
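
To align the two, the change folds an extra adapter-derived hash into the v2 block hash. Below is a minimal sketch of the idea, not the exact merged code; the function name block_content_hash is hypothetical:

from typing import List, Optional

def block_content_hash(is_first_block: bool,
                       prev_block_hash: Optional[int],
                       cur_block_token_ids: List[int],
                       contextual_hash: Optional[int]) -> int:
    # Token identity alone is not enough: two sequences with the same prefix
    # but different LoRA or prompt adapters must not share a cached block,
    # so an adapter-derived hash participates in the block hash as well.
    return hash((is_first_block, prev_block_hash, *cur_block_token_ids,
                 contextual_hash))

# The adapter-derived value is computed once per sequence, e.g.:
# contextual_hash = hash((seq.prompt_adapter_id, seq.lora_int_id))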

github-actions bot commented Sep 6, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: by default, PRs do not trigger a full CI run. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@@ -149,7 +149,10 @@ def _allocate_sequence(self, seq: Sequence) -> BlockTable:
            block_allocator=self.block_allocator,
            max_block_sliding_window=self.max_block_sliding_window,
        )
        block_table.allocate(seq.get_token_ids())

+       contextual_hash = hash((seq.prompt_adapter_id, seq.lora_int_id))
llsj14 (Contributor Author):

This injects the hash value computed from the prompt adapter ID and LoRA ID; it will be folded into the hash of each prefix caching block.
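
For illustration, a minimal sketch of how that value could be threaded into allocation; the truncated diff does not show the updated allocate() call, so the keyword argument below is an assumption consistent with the allocate signature change later in this PR:

from typing import Optional

def allocate_with_adapter_context(block_table, seq) -> None:
    # Hypothetical helper: derive the contextual hash from adapter identity
    # and pass it through, assuming allocate() accepts a contextual_hash
    # keyword as in the signature change shown further down.
    contextual_hash: Optional[int] = hash(
        (seq.prompt_adapter_id, seq.lora_int_id))
    block_table.allocate(seq.get_token_ids(), contextual_hash=contextual_hash)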


        Returns:
            - int: The computed hash value for the block.
        """
        assert (prev_block_hash is None) == is_first_block
-       return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
+       return hash((is_first_block, prev_block_hash, *cur_block_token_ids,
+                    contextual_hash))
llsj14 (Contributor Author):

This is where the final block hash is computed, now including the contextual hash.

llsj14 (Contributor Author) commented Sep 12, 2024

@alexm-neuralmagic @youkaichao
May I request a review? This change applies prefix caching together with LoRA and prompt adapters.

mergify bot commented Nov 26, 2024

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @llsj14.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Nov 26, 2024
@llsj14 llsj14 force-pushed the feat/block-manager-v2-hash branch from cf4d3d7 to b264d7d Compare December 11, 2024 14:13
@mergify mergify bot added the documentation and frontend labels Dec 11, 2024
…ock Manager v2

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
@llsj14 llsj14 force-pushed the feat/block-manager-v2-hash branch from b264d7d to 44fb3f0 Compare December 11, 2024 14:22
@mergify mergify bot removed the needs-rebase label Dec 11, 2024
llsj14 (Contributor Author) commented Dec 11, 2024

I'm sorry for the incorrect labels and the extra notifications; I made a mistake while rebasing my code.

@rickyyx Could you also review my code, please? This PR applies LoRA and prompt adapter support while leveraging prefix caching, which is closely related to your recent PRs.

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
@llsj14 llsj14 force-pushed the feat/block-manager-v2-hash branch from 9159cc3 to 3daa8dc Compare December 11, 2024 15:39
@comaniac comaniac self-assigned this Dec 11, 2024
rickyyx (Contributor) left a comment:

The approach looks good to me. I had previously tried something similar that went even further, including the token IDs as part of the contextual_hash, which I think could yield performance benefits by avoiding appending token IDs to the blocks.

A few nits:

  1. The name contextual_hash is not very self-explanatory to me. Personally, I would prefer something like aux_hash_metadata or extra_hash_data. Just a nit, though.
  2. Could we add some tests?

Also, cc @comaniac

comaniac (Collaborator) left a comment:

Overall LGTM. A few things:

  1. Please fix the default values. We should use None instead of 0.
  2. Please add unit tests.
  3. For naming, I agree with @rickyyx that extra_hash might be more generic.

@@ -80,7 +80,8 @@ def get_num_required_blocks(token_ids: List[int],

    def allocate(self,
                 token_ids: List[int],
-                device: Device = Device.GPU) -> None:
+                device: Device = Device.GPU,
+                contextual_hash: Optional[int] = 0) -> None:
comaniac (Collaborator):

The default value should be None as 0 is potentially a valid hash value.
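
A small illustration of the point, independent of vLLM; it relies only on the fact that 0 is itself a reachable hash value in CPython:

from typing import Optional

# In CPython, small integers hash to themselves, so 0 is a legitimate,
# reachable hash value and cannot double as a "no contextual hash" sentinel.
assert hash(0) == 0

def has_contextual_hash(contextual_hash: Optional[int]) -> bool:
    # Compare against the None sentinel, never against 0.
    return contextual_hash is not None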

@@ -177,7 +177,8 @@ def __init__(self, block_size: int, create_block: Block.Factory,
            token_ids=[],
            block_size=self._block_size,
            allocator=self._allocator,
-           block_id=None))
+           block_id=None,
+           contextual_hash=0))
comaniac (Collaborator):

Again, do not use 0 as the default value.

@@ -50,6 +50,11 @@ def is_full(self) -> bool:
    def prev_block(self) -> Optional["Block"]:
        pass

+   @property
+   @abstractmethod
+   def contextual_hash(self):
comaniac (Collaborator):

Missing type annotation: the property should declare its return type.

vllm/core/block/interfaces.py (resolved review thread)
@@ -99,18 +106,21 @@ def content_hash(self) -> Optional[int]:
class BlockAllocator(ABC):

    @abstractmethod
-   def allocate_mutable_block(self, prev_block: Optional[Block]) -> Block:
+   def allocate_mutable_block(self, prev_block: Optional[Block],
+                              contextual_hash: Optional[int]) -> Block:
comaniac (Collaborator):

Ditto to all similar places.

Suggested change
contextual_hash: Optional[int]) -> Block:
contextual_hash: Optional[int] = None) -> Block:

vllm/sequence.py (outdated)
@@ -527,6 +527,15 @@ def hash_of_block(self, logical_idx: int) -> int:
        hashed_tokens = self.data.get_prefix_token_ids(num_tokens)
        return hash((hashed_tokens, self.lora_int_id))

+   def contextual_hash_of_block(self) -> int:
comaniac (Collaborator):

Sequence should not have the concept of "block".

llsj14 (Contributor Author):

Yes, I agree. The hash_of_block function is used by Block Manager v1, so I initially followed that convention. Since the additional hash is computed only from the LoRA ID and prompt adapter ID, I plan to remove the notion of a "block" from this method.
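
A hedged sketch of the sequence-level helper described above, with no block terminology; the name extra_hash follows the reviewers' naming suggestion, and the zero-means-no-adapter convention is an assumption for this sketch rather than the exact merged code:

from dataclasses import dataclass
from typing import Optional

@dataclass
class AdapterIdentity:
    # An ID of 0 is taken to mean "no adapter attached" (assumption).
    lora_int_id: int = 0
    prompt_adapter_id: int = 0

    def extra_hash(self) -> Optional[int]:
        # With no adapter attached, return None so plain prefix caching
        # keeps its original block hashes.
        if self.lora_int_id == 0 and self.prompt_adapter_id == 0:
            return None
        return hash((self.prompt_adapter_id, self.lora_int_id))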

Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com>
comaniac (Collaborator) left a comment:

Otherwise LGTM

vllm/core/block/prefix_caching_block.py (resolved review thread)
@comaniac comaniac added the ready label and removed the documentation and frontend labels Dec 13, 2024
llsj14 (Contributor Author) commented Dec 13, 2024

Thank you for the detailed feedback and great reviews! @comaniac @rickyyx

@comaniac comaniac merged commit c31d4a5 into vllm-project:main Dec 13, 2024
64 checks passed
BKitor pushed a commit to BKitor/vllm that referenced this pull request Dec 30, 2024
Labels: ready (ONLY add when PR is ready to merge/full CI is needed)
Projects: None yet
3 participants