Fixes the common_computed_block_nums updating.
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
sighingnow committed Aug 21, 2024
1 parent b4867ba commit a043643
Showing 6 changed files with 20 additions and 11 deletions.
5 changes: 2 additions & 3 deletions tests/worker/test_model_runner.py
@@ -4,15 +4,14 @@
 import pytest
 import torch
 
-from tests.kernels.utils import (STR_FLASH_ATTN_VAL,
-                                 override_backend_env_variable)
+from tests.kernels.utils import override_backend_env_variable
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                              init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SamplingParams,
                            SequenceData, SequenceGroupMetadata)
-from vllm.utils import get_open_port
+from vllm.utils import STR_FLASH_ATTN_VAL, get_open_port
 from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
14 changes: 10 additions & 4 deletions vllm/core/block_manager_v1.py
@@ -680,10 +680,15 @@ def access_all_blocks_in_seq(
         for block in block_table:
             block.last_accessed = access_time
 
-    def compute_full_blocks_in_seq(self, seq: Sequence):
+    def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int):
         if seq.seq_id not in self.block_tables:
             return
-        max_full_block = seq.get_len() // self.block_size - 1
+        # We ensure at least 1 token to prefill even fully matched in the
+        # model runner, so "computing computed_blocks as it is" is safe here.
+        max_full_block = min(
+            seq.get_prompt_len(),
+            seq.data.get_num_computed_tokens() +
+            token_chunk_size) // self.block_size
         block_table = self.block_tables[seq.seq_id]
         if max_full_block == -1:
             return
@@ -717,10 +722,11 @@ def get_common_computed_block_ids(
         ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
         return commonprefix([ids for ids in ids_list if ids != []])
 
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         if self.enable_caching:
             for seq in seq_group.get_seqs():
-                self.compute_full_blocks_in_seq(seq)
+                self.compute_full_blocks_in_seq(seq, token_chunk_size)
 
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
         if device == Device.GPU:
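
To make the new bound in compute_full_blocks_in_seq concrete, here is a small standalone sketch of the old and new max_full_block arithmetic. The numbers are illustrative only, and prompt_len stands in for seq.get_len(), which equals the prompt length during prefill:

# Illustrative numbers only; not taken from the commit.
block_size = 16
prompt_len = 48            # seq.get_prompt_len()
num_computed_tokens = 0    # seq.data.get_num_computed_tokens(): nothing prefilled yet
token_chunk_size = 16      # first chunk scheduled by chunked prefill

# Old computation: derived from the whole sequence length, regardless of
# how many tokens have actually been prefilled so far.
old_max_full_block = prompt_len // block_size - 1    # 48 // 16 - 1 == 2

# New computation: bounded by the tokens computed through this chunk.
new_max_full_block = min(
    prompt_len,
    num_computed_tokens + token_chunk_size) // block_size    # 16 // 16 == 1

print(old_max_full_block, new_max_full_block)    # 2 1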
3 changes: 2 additions & 1 deletion vllm/core/block_manager_v2.py
@@ -286,7 +286,8 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float):
             self._last_access_blocks_tracker.update_last_access(
                 seq.seq_id, now)
 
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         # The only need for mark block as computed is for prefix caching,
         # while currently we could determine whether one block is computed
         # or not by check whether it has content hash.
3 changes: 2 additions & 1 deletion vllm/core/embedding_model_block_manager.py
@@ -80,7 +80,8 @@ def get_common_computed_block_ids(self,
                                       seq_group: SequenceGroup) -> List[int]:
         return None  # type: ignore
 
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         pass
 
     def get_prefix_cache_hit_rate(self, device: Device) -> float:
3 changes: 2 additions & 1 deletion vllm/core/interfaces.py
@@ -115,7 +115,8 @@ def get_common_computed_block_ids(
         pass
 
     @abstractmethod
-    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
+    def mark_blocks_as_computed(self, seq_group: SequenceGroup,
+                                token_chunk_size: int):
         pass
 
     @abstractmethod
3 changes: 2 additions & 1 deletion vllm/core/scheduler.py
@@ -1145,7 +1145,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
         # will crash the vLLM instance / will not retry.
         for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups:
             self.block_manager.mark_blocks_as_computed(
-                scheduled_seq_group.seq_group)
+                scheduled_seq_group.seq_group,
+                scheduled_seq_group.token_chunk_size)
 
         scheduler_time = time.perf_counter() - scheduler_start_time
         # Add this to scheduler time to all the sequences that are currently
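
The scheduler loop above is where token_chunk_size enters the chain. A minimal, self-contained sketch of that flow, using toy stand-ins rather than vLLM's real classes:

from dataclasses import dataclass
from typing import List


@dataclass
class SeqGroup:
    # Toy stand-in for vllm.sequence.SequenceGroup.
    seq_ids: List[int]


@dataclass
class ScheduledSeqGroup:
    # Toy stand-in for the scheduler's per-step output entry.
    seq_group: SeqGroup
    token_chunk_size: int  # tokens scheduled for this group in this step


class ToyBlockManager:
    # Toy stand-in for a block manager implementing the updated interface.
    def mark_blocks_as_computed(self, seq_group: SeqGroup,
                                token_chunk_size: int) -> None:
        # With the chunk size available, the manager can bound how many
        # full blocks it marks as computed (see block_manager_v1 above).
        print(f"mark {seq_group.seq_ids} with chunk {token_chunk_size}")


block_manager = ToyBlockManager()
scheduled = [ScheduledSeqGroup(SeqGroup([0]), 16),
             ScheduledSeqGroup(SeqGroup([1]), 8)]
# Mirrors the updated scheduler loop: the chunk size now rides along.
for ssg in scheduled:
    block_manager.mark_blocks_as_computed(ssg.seq_group,
                                          ssg.token_chunk_size)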
