[Performance] Introducing Prefix-Cached Chunked Prefill with flash-attn backend and 10% throughput gained under prompt <1K #6819
Changes from all commits
```diff
@@ -246,6 +246,9 @@ def __init__(
         self.multi_modal_inputs = multi_modal_inputs
         self.prefix_cache_hit = prefix_cache_hit
+
+        # maybe dirty hack
+        self.cached_len = 0

         self.__post_init__()
```
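As a reading aid, here is a minimal sketch of the lifecycle this new attribute appears intended to have, based on the comments in the diff below: `cached_len` is written once when the first chunk of a sequence is built, and is then only read by later chunks of the same sequence. The class and method names are illustrative placeholders, not the real vLLM builder types.

```python
# Illustrative sketch only; InterDataSketch is a placeholder, not the
# real vLLM per-sequence-group builder class.
class InterDataSketch:
    def __init__(self) -> None:
        # Length in tokens of the prompt prefix already in the cache.
        # 0 doubles as "not yet measured for this sequence".
        self.cached_len = 0

    def on_first_chunk(self, computed_block_nums: list[int],
                       block_size: int) -> None:
        # Pinned exactly once, on the sequence's first chunk. Per the
        # PR's "[help wanted]" comment, the cached length is otherwise
        # mutable while the same sequence is being chunked, so later
        # chunks read self.cached_len instead of re-deriving it from
        # computed_block_nums.
        self.cached_len = len(computed_block_nums) * block_size
```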
```diff
@@ -365,20 +368,36 @@ def _compute_for_prefix_cache_hit(
                              and self.sliding_window is None
                              and inter_data.is_prompt)
         inter_data.prefix_cache_hit = prefix_cache_hit
-        if self.chunked_prefill_enabled and prefix_cache_hit:
-            raise RuntimeError(
-                "chunked prefill cannot be used with prefix caching now.")
+        # if self.chunked_prefill_enabled and prefix_cache_hit:
+        #     raise RuntimeError(
+        #         "chunked prefill cannot be used with prefix caching now.")

         # If prefix cache is hit, advance context length to bypass
         # hit blocks. Accordingly, input tokens, position and query length
         # have to be updated.
         if prefix_cache_hit:
             assert computed_block_nums is not None
-            context_len = len(computed_block_nums) * self.block_size
-            inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
-                seq_idx][context_len:]
-            inter_data.input_positions[seq_idx] = inter_data.input_positions[
-                seq_idx][context_len:]
+            # [help wanted]
+            # inter_data.cached_len is just a work-around for the mutable
+            # cached length when chunked prefill is enabled with prefix
+            # caching, in order to pin the cached length while the same
+            # seq is being chunked.
+            context_len = inter_data.context_lens[seq_idx]
+            if context_len == 0:
+                inter_data.cached_len = len(
+                    computed_block_nums) * self.block_size
+                context_len = min(inter_data.cached_len,
+                                  seq_group_metadata.token_chunk_size - 1)
+                inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
+                    seq_idx][context_len:]
+                inter_data.input_positions[seq_idx] = (
+                    inter_data.input_positions[seq_idx][context_len:])
+            else:
+                if inter_data.cached_len > context_len:
+                    delta_len = min(inter_data.cached_len - context_len,
+                                    seq_group_metadata.token_chunk_size - 1)
+                    context_len += delta_len
+                    inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
+                        seq_idx][delta_len:]
+                    inter_data.input_positions[seq_idx] = (
+                        inter_data.input_positions[seq_idx][delta_len:])
```
@sighingnow: Hi @Juelianqvq, I have read the diff, and if I understand it correctly, the key difference between this PR and #6144 is leaving at least 1 token to prefill for each sequence. I could add such logic to #6144.

@Juelianqvq: @sighingnow Actually it is not only that: you had previously missed the correctness point of keeping at least 1 token. Moreover, the modification in block_manager.py matters too. Try changing only the kept-token logic and see whether you get an inference speedup. I already know the answer, because I have exercised this in many cases and pointed out the existing problem in my PR, which certainly behaves faster with this work-around.
```diff

         inter_data.context_lens[seq_idx] = context_len
         inter_data.query_lens[
             seq_idx] = inter_data.seq_lens[seq_idx] - context_len
```
Review comment: remove it?
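To make the new control flow concrete, here is a self-contained replay of the arithmetic in the diff above, under the simplifying assumption of a fixed per-step chunk budget; all names and numbers are illustrative, not the actual vLLM scheduler. It also demonstrates the point from the review thread: the `token_chunk_size - 1` cap guarantees that every chunk keeps at least one query token to prefill, so even a fully cached prompt still produces logits for the first sampled token.

```python
# Hypothetical replay of the PR's prefix-cache skip under chunked prefill.
# `chunk_budget` stands in for the scheduler's token_chunk_size; the
# arithmetic mirrors the diff above, everything else is simplified.

def replay(prompt_len: int, cached_blocks: int, block_size: int,
           chunk_budget: int) -> int:
    """Return how many query tokens actually get prefilled."""
    cached_len = 0   # the PR's inter_data.cached_len, pinned on first chunk
    computed = 0     # tokens finished in earlier steps (context_lens entry)
    total_query = 0
    while computed < prompt_len:
        chunk = min(chunk_budget, prompt_len - computed)  # token_chunk_size
        seq_len = computed + chunk
        context_len = computed
        if context_len == 0:
            # First chunk: measure the cached prefix once, then skip at
            # most chunk - 1 of it so at least one query token remains.
            cached_len = cached_blocks * block_size
            context_len = min(cached_len, chunk - 1)
        elif cached_len > context_len:
            # Later chunks: keep consuming the cached prefix, again
            # leaving at least one token to prefill in this step.
            context_len += min(cached_len - context_len, chunk - 1)
        query_len = seq_len - context_len
        assert query_len >= 1, "every chunk must prefill at least 1 token"
        total_query += query_len
        computed = seq_len
    return total_query

# 1024-token prompt with 768 tokens cached (48 blocks of 16), 256-token
# chunks: only 259 of 1024 tokens need real prefill work (1 + 1 + 1 + 256).
assert replay(1024, 48, 16, 256) == 259

# A fully cached 256-token prompt still prefills exactly 1 token, so the
# model can produce logits for sampling instead of an empty query.
assert replay(256, 16, 16, 512) == 1
```

Skipping most cached tokens while never letting a chunk's query become empty is where the throughput gain claimed in the title would come from.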