[core] Sampling controller interface #6273
base: main
Changes from 11 commits
@@ -14,6 +14,7 @@
 if TYPE_CHECKING:
     from vllm.inputs import LLMInputs
+    from vllm.model_executor.sampling_metadata import SamplingMetadata
     from vllm.multimodal import MultiModalDataDict
     from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
@@ -185,9 +186,14 @@ def get_num_computed_tokens(self) -> int:

     def update_num_computed_tokens(self, num_new_computed_tokens: int):
         """Update number of tokens computed so far."""
+        seq_len = self.get_len()
         self._num_computed_tokens += num_new_computed_tokens
-        assert self._num_computed_tokens <= self.get_len(), (
-            self._num_computed_tokens, self.get_len())
+        # We can overflow by 1 if previous sampling was updated by
+        # SamplingController to generate an empty sequence of tokens.
+        if self._num_computed_tokens == seq_len + 1:
+            self._num_computed_tokens = seq_len
+        assert self._num_computed_tokens <= seq_len, (
+            self._num_computed_tokens, seq_len)
         # If all tokens are computed, it means it is in decoding phase.
         if self.get_num_uncomputed_tokens() == 0:
             self._stage = SequenceStage.DECODE
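To make the clamp above concrete, here is a small self-contained sketch of the off-by-one case (a hypothetical stand-alone helper with illustrative numbers, not vLLM code): when the SamplingController turns a step into an empty fast-forward, the sequence does not grow by the token the scheduler already counted as computed, so the counter can land exactly one past the sequence length and is clamped back.

```python
# Stand-alone illustration of the clamp in update_num_computed_tokens above
# (hypothetical helper, not part of the PR).
def clamp_num_computed(num_computed_tokens: int, seq_len: int) -> int:
    # The controller emitted zero tokens, so the sequence length did not
    # grow, but one more token was already counted as computed.
    if num_computed_tokens == seq_len + 1:
        num_computed_tokens = seq_len
    assert num_computed_tokens <= seq_len
    return num_computed_tokens

# A 10-token sequence whose sampled token was dropped by the controller:
# the counter reaches 11 and is clamped back to 10.
assert clamp_num_computed(11, 10) == 10
assert clamp_num_computed(10, 10) == 10
```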
@@ -468,8 +474,8 @@ def lora_int_id(self) -> int:

     def get_last_latency(self, now: float) -> Optional[float]:
         """Sets the last token time for Request level timings."""
-        # If still in prefill phase, raise Error.
-        if self.is_prefill():
+        # If still in initial prefill phase, raise Error.
+        if self.is_prefill() and self.get_seqs()[0].get_output_len() == 0:
             raise ValueError(
                 "seq_group.get_last_latency() should not be called "
                 "if the seq_group is in prefill phase.")

Review comment (on the "initial prefill phase" wording): What's the case where it is not "initial"?
Reply: This happens for fast-forward tokens - if there is more than one, we switch to prefill mode.
@@ -701,6 +707,29 @@ def __init__(
         self.parent_seq_id = parent_seq_id
         self.output_token = output_token
         self.logprobs = logprobs
+        # If present, these tokens should be appended to the output
+        # instead of output_token.
+        self.fast_forward_tokens: Optional[List[int]] = None
+
+    def append_to(self, seq: Sequence) -> None:
+        if self.fast_forward_tokens is not None:
+            logprobs = self.logprobs
+            for token in self.fast_forward_tokens:
+                # On first iteration, use the existing self.logprobs, provided
+                # they contain the token.
+                if token not in logprobs:
+                    logprobs = {
+                        token: Logprob(logprob=0.0, rank=1, decoded_token=None)
+                    }
+                seq.append_token_id(token, logprobs)
+                # On subsequent iterations always use artificially created
+                # logprobs.
+                logprobs = {}
+            # If more than one token was appended, switch to prefill stage.
+            if seq.data.get_num_uncomputed_tokens() > 1:
+                seq.data._stage = SequenceStage.PREFILL
+        else:
+            seq.append_token_id(self.output_token, self.logprobs)

     def __repr__(self) -> str:
         return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
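To show what the new field changes about output application, here is a hedged, self-contained stand-in for append_to (the apply_output helper and the token ids are invented for illustration; it ignores logprobs and the stage switch): when fast_forward_tokens is set, the whole forced list is appended instead of the single sampled token, and an empty list appends nothing at all.

```python
from typing import List, Optional

# Hypothetical stand-in mirroring the branch added to SequenceOutput.append_to
# above; it operates on a plain token list instead of a vLLM Sequence.
def apply_output(tokens: List[int], output_token: int,
                 fast_forward_tokens: Optional[List[int]]) -> None:
    if fast_forward_tokens is not None:
        # Zero, one, or many forced tokens replace the sampled token; more
        # than one means the sequence goes back to the prefill stage.
        tokens.extend(fast_forward_tokens)
    else:
        tokens.append(output_token)

seq_tokens = [1, 2, 3]
apply_output(seq_tokens, output_token=42, fast_forward_tokens=[7, 8, 9])
assert seq_tokens == [1, 2, 3, 7, 8, 9]
apply_output(seq_tokens, output_token=42, fast_forward_tokens=[])  # 0-token fast-forward
assert seq_tokens == [1, 2, 3, 7, 8, 9]
```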
@@ -912,6 +941,26 @@ def prune(self,
         self.seq_ids = seq_ids


+class SamplingController:
+
+    def prepare(self, sampling_metadata: "SamplingMetadata"):
+        """Prepare the sampling controller for the next step."""
+        pass
+
+    def empty_step(self):
+        """Called instead of prepare() when the scheduler found no sequences
+        to run."""
+        pass
+
+    def transform_logits(self, logits: torch.Tensor) -> torch.Tensor:
+        """Apply the sampling controller to the logits."""
+        return logits
+
+    def transform_sampler_output(self, output: SamplerOutput) -> SamplerOutput:
+        """Apply the sampling controller to the sampler output."""
+        return output
+
+
 @dataclass
 class ExecuteModelRequest:
     """The model execution request, containing CPU metadata only. The LLM

Review comment (on class SamplingController): Make it an abstract class?
Review comment: Also, the docstring should probably mention that this class is a singleton and stateful, and that prepare() and transform_logits() need to be called 1 to 1.
Reply: Added docstrings. It seems Python doesn't have abstract classes, only abstract methods. I can imagine use cases where you only override some of these methods, and none of them you have to override, so I don't think it makes sense to make any of them abstract. One thing we could do is to make the engine.sampling_controller field non-optional and use this base class as a no-op implementation. Not sure if that would be cleaner, though.
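As an illustration of the intended extension point, here is a hedged sketch of a subclass (the class name, the allow-list idea, and the import path are assumptions, not part of the PR): it overrides only transform_logits to mask everything outside an allowed token set, leaving the other hooks at the base class's no-op behavior, which matches the "override only what you need" usage described in the reply above.

```python
import torch

# Assumed import path for the class added in this PR.
from vllm.sequence import SamplingController


class AllowListController(SamplingController):
    """Hypothetical controller that only permits a fixed set of token ids."""

    def __init__(self, allowed_token_ids: list):
        self.allowed_token_ids = allowed_token_ids

    def transform_logits(self, logits: torch.Tensor) -> torch.Tensor:
        # logits is assumed to be [num_tokens, vocab_size]; everything outside
        # the allow-list is pushed to -inf so it can never be sampled.
        mask = torch.full_like(logits, float("-inf"))
        mask[:, self.allowed_token_ids] = 0.0
        return logits + mask
```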
@@ -936,6 +985,8 @@ class ExecuteModelRequest:
     num_steps: int = 1
     # Finished request ids since last step.
     finished_requests_ids: List[str] = field(default_factory=list)
+    # Sampling controller to use for this step.
+    sampling_controller: Optional[SamplingController] = None

     def clone(
         self, seq_group_metadata_list: List[SequenceGroupMetadata]

@@ -951,4 +1002,5 @@ def clone(
             running_queue_size=self.running_queue_size,
             previous_hidden_states=self.previous_hidden_states,
             num_steps=self.num_steps,
-            finished_requests_ids=self.finished_requests_ids)
+            finished_requests_ids=self.finished_requests_ids,
+            sampling_controller=self.sampling_controller)
Changes in the second file of the diff:
@@ -1240,6 +1240,11 @@ def execute_model(
             "finished_requests_ids": model_input.finished_requests_ids,
             "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
         } if self.has_seqlen_agnostic else {}
+
+        if (ctrl := model_input.sampling_controller) is not None:
+            assert model_input.sampling_metadata is not None
+            ctrl.prepare(model_input.sampling_metadata)
+
         hidden_or_intermediate_states = model_executable(
             input_ids=model_input.input_tokens,
             positions=model_input.input_positions,

Review comment (on the ctrl.prepare() call): This has been changed from the RFC (there we pass seq group metadata; here we are using sampling params). Is there any way to hook this with seq group metadata?
Reply: There are …
Reply: I see, yeah, I guess what we have in sampling metadata is probably sufficient. Let me get back to you soon about this.
@@ -1259,12 +1264,18 @@
         if not self.is_driver_worker:
             return []

+        if ctrl is not None:
+            logits = ctrl.transform_logits(logits)
+
         # Sample the next token.
         output: SamplerOutput = self.model.sample(
             logits=logits,
             sampling_metadata=model_input.sampling_metadata,
         )

+        if ctrl is not None:
+            output = ctrl.transform_sampler_output(output)
+
         if self.return_hidden_states:
             # we only need to pass hidden states of most recent token
             assert model_input.sampling_metadata is not None
Review comment: QQ: is this change related to this PR?
Reply: Yes, this happens when the "fast forward" is 0 tokens long - see the comment in the PR description about BlockTable.append_token_ids.
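Putting the hooks together, this is a hedged sketch of the per-step ordering that the execute_model changes above wire up (driver_step and the forward/compute_logits/sample callables are placeholders, not vLLM APIs): prepare() sees the step's SamplingMetadata before the forward pass, transform_logits() runs right before sampling, and transform_sampler_output() runs on the sampled result; when the scheduler has nothing to run, empty_step() is called instead of prepare().

```python
from typing import Callable, Optional

# Placeholder driver-side sketch of a single step; only the ordering of the
# SamplingController hooks is meant to match the diff above.
def driver_step(ctrl: Optional["SamplingController"],
                sampling_metadata,
                forward: Callable,
                compute_logits: Callable,
                sample: Callable):
    if ctrl is not None:
        # 1. Once per scheduled step, before the forward pass.
        ctrl.prepare(sampling_metadata)
    hidden_states = forward()
    logits = compute_logits(hidden_states)
    if ctrl is not None:
        # 2. Constrain or bias the logits before sampling.
        logits = ctrl.transform_logits(logits)
    output = sample(logits, sampling_metadata)
    if ctrl is not None:
        # 3. Post-process, e.g. attach fast_forward_tokens to SequenceOutputs.
        output = ctrl.transform_sampler_output(output)
    return output
```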