[Bugfix] Enable chunked-prefill and prefix cache with flash-attn backend #6144

Closed
108 changes: 107 additions & 1 deletion tests/worker/test_model_runner.py
@@ -4,13 +4,14 @@
import pytest
import torch

from tests.kernels.utils import override_backend_env_variable
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
SequenceData, SequenceGroupMetadata)
from vllm.utils import get_open_port
from vllm.utils import STR_FLASH_ATTN_VAL, get_open_port
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size


@@ -387,3 +388,108 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata),
vars(decode_meta_actual)):
assert attr_expected[1] == attr_actual[1]


@pytest.mark.parametrize("enforce_eager", [True, False])
def test_hybrid_batches_with_prefix_caching(enforce_eager, monkeypatch):
override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)

model_runner = _create_model_runner(
"facebook/opt-125m",
seed=0,
dtype="float16",
enforce_eager=enforce_eager,
max_num_batched_tokens=100000,
max_num_seqs=100000,
enable_chunked_prefill=True,
enable_prefix_caching=True,
)

seq_lens: List[int] = []
seq_group_metadata_list: List[SequenceGroupMetadata] = []
# Use a large number of blocks to test longer sequences
# with chunked prefill and prefix caching
block_tables = {0: list(range(128))}
Collaborator: Can you note the block size here for convenience?
Contributor Author: Fixed.


# case 1: prefix_cache_len <= context_len:
seq_len = 1000
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id="test_0",
is_prompt=True,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
token_chunk_size=200,
computed_block_nums=range(10),
)
seq_data.update_num_computed_tokens(200)
seq_group_metadata_list.append(seq_group_metadata)

# case 2: context_len < prefix_cache_len < seq_len
seq_len = 1000
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id="test_0",
is_prompt=True,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
token_chunk_size=100,
computed_block_nums=range(10),
)
seq_data.update_num_computed_tokens(80)
seq_group_metadata_list.append(seq_group_metadata)

# case 3: prefix_cache_len >= seq_len
seq_len = 1000
seq_lens.append(seq_len)
seq_data = SequenceData(list(range(seq_len)))
seq_group_metadata = SequenceGroupMetadata(
request_id="test_0",
is_prompt=True,
seq_data={0: seq_data},
sampling_params=SamplingParams(temperature=0),
block_tables=block_tables,
token_chunk_size=100,
computed_block_nums=range(10),
)
seq_data.update_num_computed_tokens(50)
seq_group_metadata_list.append(seq_group_metadata)

model_input = model_runner.prepare_model_input(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
)

# The numbers of tokens to be computed are:
# - for the first sequence: the cache hit is already within the
# computed context, so all 200 tokens are left to be recomputed
# - for the second sequence: partially cached, and 20 tokens are
# left to be recomputed
# - for the third sequence: fully cached, and only 1 token is
# left to be recomputed
assert len(input_tokens) == 221
assert len(input_positions) == 221

torch.testing.assert_close(
attn_metadata.query_start_loc,
torch.tensor([0, 200, 220, 221],
dtype=torch.int32,
device=attn_metadata.query_start_loc.device))

torch.testing.assert_close(
attn_metadata.seq_start_loc,
torch.tensor([0, 400, 580, 730],
dtype=torch.int32,
device=attn_metadata.seq_start_loc.device))

torch.testing.assert_close(
attn_metadata.context_lens_tensor,
torch.tensor([200, 160, 149],
dtype=torch.int32,
device=attn_metadata.context_lens_tensor.device))
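
For reference, the expected values above follow directly from the test setup, assuming the default block size of 16 (so computed_block_nums = range(10) corresponds to a 160-token cached prefix). This is a sketch of the arithmetic, not code from the change:

# Arithmetic behind the expected metadata, assuming block_size = 16.
block_size = 16
prefix_cache_len = 10 * block_size           # computed_block_nums = range(10) -> 160

# (num_computed_tokens, token_chunk_size) for the three sequences above.
cases = [(200, 200), (80, 100), (50, 100)]
query_lens, context_lens = [], []
for context_len, chunk_size in cases:
    chunk_end = context_len + chunk_size
    if prefix_cache_len <= context_len:      # case 1: cache already covered
        pass
    elif prefix_cache_len < chunk_end:       # case 2: partially cached
        context_len = prefix_cache_len
    else:                                    # case 3: fully cached, keep 1 token
        context_len = chunk_end - 1
    query_lens.append(chunk_end - context_len)
    context_lens.append(context_len)

assert query_lens == [200, 20, 1]            # 221 input tokens in total
assert context_lens == [200, 160, 149]       # matches context_lens_tensor
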
7 changes: 7 additions & 0 deletions vllm/attention/backends/flashinfer.py
@@ -425,6 +425,13 @@ def _add_seq_group(
block_tables = inter_data.block_tables
computed_block_nums = inter_data.computed_block_nums

if chunked_prefill_enabled and inter_data.prefix_cache_hit:
raise RuntimeError(
"Chunked prefill and prefix caching cannot be used "
"simultaneously with flashinfer backend, try switching "
"to flash-attn backend by setting the environment variable "
"\"VLLM_ATTENTION_BACKEND=FLASH_ATTN\"")

for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
7 changes: 7 additions & 0 deletions vllm/attention/backends/utils.py
@@ -150,6 +150,13 @@ def _add_seq_group(
block_tables = inter_data.block_tables
computed_block_nums = inter_data.computed_block_nums

if chunked_prefill_enabled and inter_data.prefix_cache_hit:
raise RuntimeError(
"Chunked prefill and prefix caching can only be used "
"simultaneously with flash-attn backend, try switching "
"to flash-attn backend by setting the environment variable "
"\"VLLM_ATTENTION_BACKEND=FLASH_ATTN\"")

for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
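
For anyone hitting either of these errors, the suggested fix is just an environment variable. A minimal sketch of switching backends; the model name and engine flags here are illustrative examples, not part of this change:

# Select the flash-attn backend before the engine is created.
# Typically done in the shell: export VLLM_ATTENTION_BACKEND=FLASH_ATTN
import os
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",        # example model only
    enable_chunked_prefill=True,
    enable_prefix_caching=True,
)
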
14 changes: 10 additions & 4 deletions vllm/core/block_manager_v1.py
Expand Up @@ -680,10 +680,15 @@ def access_all_blocks_in_seq(
for block in block_table:
block.last_accessed = access_time

def compute_full_blocks_in_seq(self, seq: Sequence):
def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int):
if seq.seq_id not in self.block_tables:
return
max_full_block = seq.get_len() // self.block_size - 1
# The model runner always leaves at least 1 token to prefill, even
# when the prompt is fully matched, so marking the computed blocks
# as-is is safe here.
max_full_block = min(
seq.get_prompt_len(),
seq.data.get_num_computed_tokens() +
token_chunk_size) // self.block_size
block_table = self.block_tables[seq.seq_id]
if max_full_block == -1:
return
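
As a quick sanity check of the new bound, a worked example with illustrative values, assuming block_size = 16:

# min(prompt_len, num_computed_tokens + token_chunk_size) // block_size
block_size = 16
prompt_len = 1000
num_computed_tokens = 80      # tokens already prefilled in earlier chunks
token_chunk_size = 100        # tokens scheduled for the current chunk

max_full_block = min(prompt_len, num_computed_tokens + token_chunk_size) // block_size
assert max_full_block == 11   # only blocks fully covered by scheduled tokens count
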
@@ -717,10 +722,11 @@ def get_common_computed_block_ids(
ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
return commonprefix([ids for ids in ids_list if ids != []])

def mark_blocks_as_computed(self, seq_group: SequenceGroup):
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
if self.enable_caching:
for seq in seq_group.get_seqs():
self.compute_full_blocks_in_seq(seq)
self.compute_full_blocks_in_seq(seq, token_chunk_size)

def get_prefix_cache_hit_rate(self, device: Device) -> float:
if device == Device.GPU:
3 changes: 2 additions & 1 deletion vllm/core/block_manager_v2.py
@@ -286,7 +286,8 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float):
self._last_access_blocks_tracker.update_last_access(
seq.seq_id, now)

def mark_blocks_as_computed(self, seq_group: SequenceGroup):
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
# The only reason to mark blocks as computed is prefix caching;
# currently we can determine whether a block is computed by
# checking whether it has a content hash.
3 changes: 2 additions & 1 deletion vllm/core/embedding_model_block_manager.py
@@ -80,7 +80,8 @@ def get_common_computed_block_ids(self,
seq_group: SequenceGroup) -> List[int]:
return None # type: ignore

def mark_blocks_as_computed(self, seq_group: SequenceGroup):
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
pass

def get_prefix_cache_hit_rate(self, device: Device) -> float:
3 changes: 2 additions & 1 deletion vllm/core/interfaces.py
@@ -115,7 +115,8 @@ def get_common_computed_block_ids(
pass

@abstractmethod
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
def mark_blocks_as_computed(self, seq_group: SequenceGroup,
token_chunk_size: int):
pass

@abstractmethod
3 changes: 2 additions & 1 deletion vllm/core/scheduler.py
@@ -1145,7 +1145,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
# will crash the vLLM instance / will not retry.
for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
self.block_manager.mark_blocks_as_computed(
scheduled_seq_group.seq_group)
scheduled_seq_group.seq_group,
scheduled_seq_group.token_chunk_size)

scheduler_time = time.perf_counter() - scheduler_start_time
# Add this to scheduler time to all the sequences that are currently
37 changes: 29 additions & 8 deletions vllm/worker/model_runner.py
@@ -498,20 +498,41 @@ def _compute_for_prefix_cache_hit(
and self.sliding_window is None
and inter_data.is_prompt)
inter_data.prefix_cache_hit = prefix_cache_hit
if self.chunked_prefill_enabled and prefix_cache_hit:
raise RuntimeError(
"chunked prefill cannot be used with prefix caching now.")

# If prefix cache is hit, advance context length to bypass
# hit blocks. Accordingly, input tokens, position and query length
# have to be updated.
if prefix_cache_hit:
assert computed_block_nums is not None
context_len = len(computed_block_nums) * self.block_size
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][context_len:]
inter_data.input_positions[seq_idx] = inter_data.input_positions[
seq_idx][context_len:]
prefix_cache_len = len(computed_block_nums) * self.block_size
# When prefix caching meets chunked prefill, we would be in
# one of the following three cases:
context_len = inter_data.context_lens[seq_idx]
seq_len = inter_data.seq_lens[seq_idx]
if prefix_cache_len <= context_len:
# Do normal chunked prefill.
pass
elif context_len < prefix_cache_len < seq_len:
# Advance the context_len to prefix_cache_len so that only the
# non-cached part of the prompt is prefilled.
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][(prefix_cache_len - context_len):]
inter_data.input_positions[
seq_idx] = inter_data.input_positions[seq_idx][(
prefix_cache_len - context_len):]
context_len = prefix_cache_len
elif seq_len <= prefix_cache_len:
# The current partial sequence is fully covered by the prefix
# cache, so no further computation is needed. We still leave at
# least 1 token for chunked prefill to prevent empty sequences
# in the attention computation.
inter_data.input_tokens[seq_idx] = inter_data.input_tokens[
seq_idx][(seq_len - 1 - context_len):]
inter_data.input_positions[
seq_idx] = inter_data.input_positions[seq_idx][(
seq_len - 1 - context_len):]
context_len = seq_len - 1

inter_data.context_lens[seq_idx] = context_len
inter_data.query_lens[
seq_idx] = inter_data.seq_lens[seq_idx] - context_len
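
Distilled into a standalone helper (illustrative names, not part of the change), the branch above reduces to picking an adjusted context length; the real code also slices input_tokens and input_positions by the same offset:

def adjusted_context_len(context_len: int, prefix_cache_len: int,
                         seq_len: int) -> int:
    """Sketch of the three prefix-cache / chunked-prefill cases."""
    if prefix_cache_len <= context_len:
        # Case 1: the cached prefix is already inside the computed context;
        # do a normal chunked prefill.
        return context_len
    if prefix_cache_len < seq_len:
        # Case 2: partially cached; skip ahead to the end of the cached prefix.
        return prefix_cache_len
    # Case 3: fully cached; keep one token so the attention computation never
    # sees an empty query for this sequence.
    return seq_len - 1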