[core] Sampling controller interface #6273

Open · wants to merge 12 commits into base: main
9 changes: 6 additions & 3 deletions vllm/core/block/block_table.py
@@ -147,10 +147,13 @@ def append_token_ids(self,

# Update the blocks with the new tokens
first_block_idx = self._num_full_slots // self._block_size
token_blocks = self._chunk_token_blocks_for_append(token_ids)

for i, token_block in enumerate(token_blocks):
self._blocks.append_token_ids(first_block_idx + i, token_block)
# don't bother appending anything, if no new token_ids were generated
Collaborator:
QQ: is this change related to this PR?

Contributor Author:
Yes, this happens when the "fast forward" is 0 tokens long - see the comment in the PR description about BlockTable.append_token_ids.

if token_ids:
token_blocks = self._chunk_token_blocks_for_append(token_ids)

for i, token_block in enumerate(token_blocks):
self._blocks.append_token_ids(first_block_idx + i, token_block)

self._num_full_slots += len(token_ids)

5 changes: 4 additions & 1 deletion vllm/engine/async_llm_engine.py
@@ -237,7 +237,8 @@ async def step_async(
virtual_engine=virtual_engine,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids)
finished_requests_ids=finished_requests_ids,
sampling_controller=self.sampling_controller)
output = await self.model_executor.execute_model_async(
execute_model_req)
else:
@@ -257,6 +258,8 @@

async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
if ctrl := self.sampling_controller:
ctrl.empty_step()
await self.model_executor.stop_remote_worker_execution_loop_async()

async def process_model_inputs_async(
12 changes: 9 additions & 3 deletions vllm/engine/llm_engine.py
@@ -29,8 +29,8 @@
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
PoolerOutput, SamplerOutput, Sequence,
SequenceGroup, SequenceGroupMetadata,
PoolerOutput, SamplerOutput, SamplingController,
Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
@@ -225,6 +225,7 @@ def __init__(
self.observability_config = observability_config or ObservabilityConfig(
)
self.log_stats = log_stats
self.sampling_controller: Optional[SamplingController] = None

if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer()
@@ -528,6 +529,8 @@ def _add_processed_request(
min_cost_scheduler.add_seq_group(seq_group)

def stop_remote_worker_execution_loop(self) -> None:
if ctrl := self.sampling_controller:
ctrl.empty_step()
self.model_executor.stop_remote_worker_execution_loop()

def process_model_inputs(
@@ -857,10 +860,13 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids)
finished_requests_ids=finished_requests_ids,
sampling_controller=self.sampling_controller)
output = self.model_executor.execute_model(
execute_model_req=execute_model_req)
else:
if ctrl := self.sampling_controller:
ctrl.empty_step()
output = []

request_outputs = self._process_model_outputs(
6 changes: 2 additions & 4 deletions vllm/engine/output_processor/single_step.py
@@ -102,15 +102,13 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
for child_sample in child_samples[:-1]:
new_child_seq_id: int = next(self.seq_counter)
child = parent.fork(new_child_seq_id)
child.append_token_id(child_sample.output_token,
child_sample.logprobs)
child_sample.append_to(child)
child_seqs.append((child, parent))
# Continue the parent sequence for the last child sample.
# We reuse the parent sequence here to reduce redundant memory
# copies, especially when using non-beam search sampling methods.
last_child_sample = child_samples[-1]
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs)
last_child_sample.append_to(parent)
child_seqs.append((parent, parent))

for seq, _ in child_seqs:
2 changes: 2 additions & 0 deletions vllm/model_executor/sampling_metadata.py
@@ -24,6 +24,7 @@ class SequenceGroupToSample:
# |-- query_len ---|

# Sequence ids for the sequence group in a previous step.
request_id: str
seq_ids: List[int]
sampling_params: SamplingParams
# seq_id -> sequence data.
@@ -273,6 +274,7 @@ def sample(logits):

seq_groups.append(
SequenceGroupToSample(
request_id=seq_group_metadata.request_id,
seq_ids=seq_ids,
sampling_params=sampling_params,
seq_data=seq_group_metadata.seq_data,
62 changes: 57 additions & 5 deletions vllm/sequence.py
@@ -14,6 +14,7 @@

if TYPE_CHECKING:
from vllm.inputs import LLMInputs
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MultiModalDataDict
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

@@ -185,9 +186,14 @@ def get_num_computed_tokens(self) -> int:

def update_num_computed_tokens(self, num_new_computed_tokens: int):
"""Update number of tokens computed so far."""
seq_len = self.get_len()
self._num_computed_tokens += num_new_computed_tokens
assert self._num_computed_tokens <= self.get_len(), (
self._num_computed_tokens, self.get_len())
# We can overflow by 1 if previous sampling was updated by
# SamplingController to generate an empty sequence of tokens.
if self._num_computed_tokens == seq_len + 1:
self._num_computed_tokens = seq_len
assert self._num_computed_tokens <= seq_len, (
self._num_computed_tokens, seq_len)
# If all tokens are computed, it means it is in decoding phase.
if self.get_num_uncomputed_tokens() == 0:
self._stage = SequenceStage.DECODE
@@ -468,8 +474,8 @@ def lora_int_id(self) -> int:

def get_last_latency(self, now: float) -> Optional[float]:
"""Sets the last token time for Request level timings."""
# If still in prefill phase, raise Error.
if self.is_prefill():
# If still in initial prefill phase, raise Error.
Collaborator:
What's the case where it is not "initial"?

Contributor Author:
This happens for fast-forward tokens - if there is more than one, we switch to prefill mode.

if self.is_prefill() and self.get_seqs()[0].get_output_len() == 0:
raise ValueError(
"seq_group.get_last_latency() should not be called "
"if the seq_group is in prefill phase.")
@@ -701,6 +707,29 @@ def __init__(
self.parent_seq_id = parent_seq_id
self.output_token = output_token
self.logprobs = logprobs
# If present, these tokens should be appended to the output
# instead of output_token.
self.fast_forward_tokens: Optional[List[int]] = None

def append_to(self, seq: Sequence) -> None:
if self.fast_forward_tokens is not None:
logprobs = self.logprobs
for token in self.fast_forward_tokens:
# On first iteration, use the existing self.logprobs, provided
# they contain the token.
if token not in logprobs:
logprobs = {
token: Logprob(logprob=0.0, rank=1, decoded_token=None)
}
seq.append_token_id(token, logprobs)
# On subsequent iterations always use artificially created
# logprobs.
logprobs = {}
# If more than one token was appended, switch to prefill stage.
if seq.data.get_num_uncomputed_tokens() > 1:
seq.data._stage = SequenceStage.PREFILL
else:
seq.append_token_id(self.output_token, self.logprobs)

def __repr__(self) -> str:
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
@@ -912,6 +941,26 @@ def prune(self,
self.seq_ids = seq_ids


class SamplingController:
Collaborator:
Make it an abstract class?

Collaborator:
Also, in the docstring we should probably mention that this class is a singleton and stateful, and that prepare() and transform_logits() need to be called 1-to-1.

Contributor Author:
Added docstrings.

It seems Python doesn't have abstract classes, only abstract methods. I can imagine use cases where you only override some of these methods, and none of them has to be overridden, so I don't think it makes sense to make any of them abstract.

One thing we could do is make the engine.sampling_controller field non-optional and use this base class as a no-op implementation. Not sure that would be cleaner, though.


def prepare(self, sampling_metadata: "SamplingMetadata"):
"""Prepare the sampling controller for the next step."""
pass

def empty_step(self):
"""Called instead of prepare() when the scheduler found no sequences
to run."""
pass

def transform_logits(self, logits: torch.Tensor) -> torch.Tensor:
"""Apply the sampling controller to the logits."""
return logits

def transform_sampler_output(self, output: SamplerOutput) -> SamplerOutput:
"""Apply the sampling controller to the sampler output."""
return output
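
To make the "override only what you need" point from the review discussion concrete, here is a hypothetical usage sketch - it is not part of this diff. It relies on the interface above, the fast_forward_tokens field added to SequenceOutput, and the engine.sampling_controller attribute added in this PR; the SamplerOutput.outputs[...].samples layout is assumed from the surrounding vLLM code.

from typing import List

import torch

from vllm.sequence import SamplerOutput, SamplingController


class GuidedController(SamplingController):
    """Hypothetical controller: mask one token id and replace the sampled
    token of every sequence with a multi-token fast forward."""

    def __init__(self, banned_token_id: int, fast_forward: List[int]):
        self.banned_token_id = banned_token_id
        self.fast_forward = fast_forward

    def transform_logits(self, logits: torch.Tensor) -> torch.Tensor:
        # logits holds one row per selected token position; setting one
        # vocabulary column to -inf prevents the sampler from picking it.
        logits[:, self.banned_token_id] = float("-inf")
        return logits

    def transform_sampler_output(self, output: SamplerOutput) -> SamplerOutput:
        # Attach fast-forward tokens to each sampled SequenceOutput;
        # SequenceOutput.append_to() then appends the whole list (and moves
        # the sequence back to prefill when more than one token is added).
        for group_output in output.outputs:
            for sample in group_output.samples:
                sample.fast_forward_tokens = list(self.fast_forward)
        return output


# Wiring (hypothetical): set it once on the driver-side engine, which then
# forwards it in every ExecuteModelRequest:
# engine.sampling_controller = GuidedController(banned_token_id=0, fast_forward=[1, 2])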


@dataclass
class ExecuteModelRequest:
"""The model execution request, containing CPU metadata only. The LLM
@@ -936,6 +985,8 @@ class ExecuteModelRequest:
num_steps: int = 1
# Finished request ids since last step.
finished_requests_ids: List[str] = field(default_factory=list)
# Sampling controller to use for this step.
sampling_controller: Optional[SamplingController] = None

def clone(
self, seq_group_metadata_list: List[SequenceGroupMetadata]
@@ -951,4 +1002,5 @@ def clone(
running_queue_size=self.running_queue_size,
previous_hidden_states=self.previous_hidden_states,
num_steps=self.num_steps,
finished_requests_ids=self.finished_requests_ids)
finished_requests_ids=self.finished_requests_ids,
sampling_controller=self.sampling_controller)
2 changes: 1 addition & 1 deletion vllm/transformers_utils/detokenizer.py
@@ -124,7 +124,7 @@ def decode_sequence_inplace(self, seq: Sequence,
)

# Decode logprobs
logprobs = seq.output_logprobs[-1]
logprobs = seq.output_logprobs[-1] if seq.output_logprobs else None
if logprobs:
previous_tokens = all_input_ids[:-1]
for token_id, sample_logprob in logprobs.items():
11 changes: 11 additions & 0 deletions vllm/worker/model_runner.py
@@ -1240,6 +1240,11 @@ def execute_model(
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_seqlen_agnostic else {}

if (ctrl := model_input.sampling_controller) is not None:
assert model_input.sampling_metadata is not None
ctrl.prepare(model_input.sampling_metadata)
Collaborator:
This has been changed from the RFC (there we pass seq group metadata; here we are using sampling params). Is there any way to hook this with seq group metadata?

Contributor Author:
There are SamplingParams in sampling_metadata (specifically in SequenceGroupToSample); there is also request_id and sequence_id - I think this is the stuff you said you needed. The main reason to use sampling_metadata is that it also has all the logit positions already computed.

Collaborator:
I see - yeah, I guess what we have in sampling_metadata is probably sufficient. Let me get back to you soon about this.
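
As a rough illustration of the reply above (not part of this diff): a prepare() override can read per-request ids and sampling params straight from SamplingMetadata. The request_id, seq_ids and sampling_params fields come from SequenceGroupToSample as extended in this PR; using sample_indices for the pre-computed logit positions is an assumption based on the existing sampling_metadata code.

from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplingController


class InspectingController(SamplingController):
    """Hypothetical controller that only records what each step will sample."""

    def prepare(self, sampling_metadata: SamplingMetadata) -> None:
        for seq_group in sampling_metadata.seq_groups:
            # One SequenceGroupToSample per scheduled request in this step.
            request_id = seq_group.request_id
            seq_ids = seq_group.seq_ids
            temperature = seq_group.sampling_params.temperature
            # Logit rows that will be sampled for this group (assumed name).
            rows = seq_group.sample_indices
            print(request_id, seq_ids, temperature, rows)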


hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
@@ -1259,12 +1264,18 @@
if not self.is_driver_worker:
return []

if ctrl is not None:
logits = ctrl.transform_logits(logits)

# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)

if ctrl is not None:
output = ctrl.transform_sampler_output(output)

if self.return_hidden_states:
# we only need to pass hidden states of most recent token
assert model_input.sampling_metadata is not None
4 changes: 3 additions & 1 deletion vllm/worker/model_runner_base.py
@@ -6,7 +6,7 @@
import torch

from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
SamplingController, SequenceGroupMetadata)

if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
@@ -92,6 +92,8 @@ class ModelRunnerInputBase(ABC):
serialize/deserialize a ModelInput for broadcast between workers.
"""

sampling_controller: Optional[SamplingController] = None

def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
"""
Extract broadcastable fields. Override for fields that require some
7 changes: 7 additions & 0 deletions vllm/worker/worker_base.py
@@ -244,6 +244,13 @@ def execute_model(
model_input.as_broadcastable_tensor_dict())
broadcast_data["num_steps"] = num_steps
broadcast_tensor_dict(broadcast_data, src=0)

# SamplingController is only used in the driver worker, so it
# doesn't need to be broadcasted.
ctrl = execute_model_req.sampling_controller
if ctrl is not None:
model_input = dataclasses.replace(model_input,
sampling_controller=ctrl)
else:
assert self.do_metadata_broadcast
broadcast_data = broadcast_tensor_dict(src=0)