Enable chunked-prefill and prefix cache with flash-attn backend
Signed-off-by: Tao He <sighingnow@gmail.com>
sighingnow committed Jul 5, 2024
1 parent 56b325e · commit 3963c76
Showing 1 changed file: vllm/worker/model_runner.py (21 additions, 2 deletions)
@@ -403,6 +403,7 @@ def _prepare_model_input_tensors(
 computed_block_nums = seq_group_metadata.computed_block_nums
 if (self.scheduler_config is not None
         and self.scheduler_config.chunked_prefill_enabled
+        and self.attn_backend.get_name() != "flash-attn"
         and not (computed_block_nums is None
                  or computed_block_nums == [])):
     raise RuntimeError(
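
This first hunk relaxes the guard that previously rejected chunked prefill whenever prefix-cache blocks had already been computed: the flash-attn backend is now exempted from the check. Below is a minimal standalone sketch of the resulting behaviour, not the vLLM implementation; the function name and the error message are placeholders, and the real check lives inside _prepare_model_input_tensors.

# Standalone sketch of the updated guard; names and error text are
# illustrative placeholders, not actual vLLM code.
def check_chunked_prefill_with_prefix_cache(backend_name: str,
                                            chunked_prefill_enabled: bool,
                                            computed_block_nums) -> None:
    # Mirrors `not (computed_block_nums is None or computed_block_nums == [])`.
    prefix_cache_hit = bool(computed_block_nums)
    if (chunked_prefill_enabled
            and backend_name != "flash-attn"   # flash-attn is now exempted
            and prefix_cache_hit):
        raise RuntimeError(
            "chunked prefill cannot be used with prefix caching on this "
            "attention backend")

check_chunked_prefill_with_prefix_cache("flash-attn", True, [0, 1, 2])   # allowed
# check_chunked_prefill_with_prefix_cache("xformers", True, [0, 1, 2])   # would raise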
@@ -464,8 +465,26 @@ def _prepare_model_input_tensors(
 # NOTE: This only works for oooooooxxx style attention.
 if prefix_cache_hit:
     assert computed_block_nums is not None
-    context_len = len(computed_block_nums) * self.block_size
-    tokens = tokens[context_len:]
+    prefix_cache_len = (len(computed_block_nums) *
+                        self.block_size)
+
+    # When prefix caching meets chunked prefill, we can be in
+    # one of the following three cases:
+    #
+    # - prefix_cache_len <= context_len:
+    #   - do normal chunked prefill; nothing special is needed.
+    # - context_len < prefix_cache_len < seq_len:
+    #   - advance context_len to prefix_cache_len and compute
+    #     only the non-cached part of the sequence.
+    # - prefix_cache_len >= seq_len:
+    #   - the current partial sequence is fully covered by the
+    #     prefix cache, so no further computation is needed.
+    if context_len < prefix_cache_len < seq_len:
+        tokens = tokens[(prefix_cache_len - context_len):]
+        context_len = prefix_cache_len
+    elif seq_len <= prefix_cache_len:
+        tokens = []
+        context_len = seq_len
 
 # need to think what to set it to when we have both sliding
 # window and prefix caching...
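To make the three cases concrete, here is a small standalone sketch (not part of the commit; all values are made up) that mirrors the new branch: block_size = 16 and computed_block_nums = [0, 1, 2] give prefix_cache_len = 48, so a chunk with context_len = 32 and seq_len = 80 falls into the middle case and only the non-cached 32 tokens are computed.

# Standalone illustration of the three cases; all numbers are made up.
block_size = 16
computed_block_nums = [0, 1, 2]
prefix_cache_len = len(computed_block_nums) * block_size  # 48

def apply_prefix_cache(tokens, context_len, seq_len):
    # Case 1: prefix_cache_len <= context_len -> plain chunked prefill,
    #         nothing to adjust.
    # Case 2: context_len < prefix_cache_len < seq_len -> drop the cached
    #         tokens and advance context_len to prefix_cache_len.
    # Case 3: prefix_cache_len >= seq_len -> the whole partial sequence is
    #         already cached, nothing left to compute.
    if context_len < prefix_cache_len < seq_len:
        tokens = tokens[(prefix_cache_len - context_len):]
        context_len = prefix_cache_len
    elif seq_len <= prefix_cache_len:
        tokens = []
        context_len = seq_len
    return tokens, context_len

chunk = list(range(32, 80))  # tokens of the current prefill chunk
tokens, context_len = apply_prefix_cache(chunk, context_len=32, seq_len=80)
assert context_len == 48 and len(tokens) == 80 - 48  # 32 tokens left to compute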
