diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py
index 0244919152cad..912abc554d63f 100644
--- a/tests/worker/test_model_input.py
+++ b/tests/worker/test_model_input.py
@@ -11,7 +11,7 @@
     ModelInputForGPUWithPoolingMetadata)
 from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
 from vllm.worker.multi_step_model_runner import (
-    MutableModelInputForGPUWithMultiStepMetadata)
+    StatefulModelInput)
 
 
 class MockAttentionBackend(AttentionBackend):
@@ -177,7 +177,7 @@ def test_multi_step_model_runner_input():
         sampling_metadata=sampling_metadata,
         attn_metadata=attn_metadata)
 
-    model_input = MutableModelInputForGPUWithMultiStepMetadata(
+    model_input = StatefulModelInput(
         frozen_model_input=frozen_model_input,
         is_last_step=True,
         is_first_multi_step=False,
@@ -190,12 +190,12 @@
     )
 
     assert isinstance(model_input,
-                      MutableModelInputForGPUWithMultiStepMetadata)
+                      StatefulModelInput)
 
     # Test round trip serialization.
     tensor_dict = model_input.as_broadcastable_tensor_dict()
     attn_backend = MockAttentionBackend()
-    received_model_input = (MutableModelInputForGPUWithMultiStepMetadata.
+    received_model_input = (StatefulModelInput.
                             from_broadcasted_tensor_dict(
                                 tensor_dict,
                                 attn_backend=attn_backend))
@@ -203,7 +203,7 @@
 
     # Check that received copy has correct values.
     assert isinstance(received_model_input,
-                      MutableModelInputForGPUWithMultiStepMetadata)
+                      StatefulModelInput)
     assert receieved_frozen_input.input_tokens is not None
     assert (receieved_frozen_input.input_tokens ==
             frozen_model_input.input_tokens).all()
diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py
index 5f6c31061e6c9..9c4eae08a9fbf 100644
--- a/vllm/worker/multi_step_model_runner.py
+++ b/vllm/worker/multi_step_model_runner.py
@@ -54,7 +54,7 @@ class ModelOutput:
 
     def pythonize(
             self,
-            input_metadata: "MutableModelInputForGPUWithMultiStepMetadata",
+            input_metadata: "StatefulModelInput",
             copy_stream: torch.cuda.Stream,
             pinned_sampled_token_buffer: torch.Tensor) -> None:
         """Pythonize the output. Blocking."""
@@ -65,7 +65,7 @@ def pythonize(
     def maybe_pythonize(
             self,
-            input_metadata: "MutableModelInputForGPUWithMultiStepMetadata",
+            input_metadata: "StatefulModelInput",
             copy_stream: torch.cuda.Stream,
             pinned_sampled_token_buffer: torch.Tensor) -> None:
         """Pythonize the output if ready, else return None.
         Non-blocking."""
@@ -76,7 +76,7 @@ def maybe_pythonize(
     def _pythonize_sampler_output(
             self,
-            input_metadata: "MutableModelInputForGPUWithMultiStepMetadata",
+            input_metadata: "StatefulModelInput",
             copy_stream: torch.cuda.Stream,
             pinned_sampled_token_buffer: torch.Tensor,
             blocking: bool) -> bool:
         """
@@ -97,7 +97,7 @@
 
 
 @dataclass(frozen=False)
-class MutableModelInputForGPUWithMultiStepMetadata(BroadcastableModelInput):
+class StatefulModelInput(BroadcastableModelInput):
     # actual frozen model input dataclass passed to _base_model_runner
     frozen_model_input: Optional[ModelInputForGPUWithSamplingMetadata] = None
 
@@ -137,7 +137,7 @@ def from_broadcasted_tensor_dict(
         cls,
         tensor_dict: Dict[str, Any],
         attn_backend: Optional["AttentionBackend"] = None,
-    ) -> "MutableModelInputForGPUWithMultiStepMetadata":
+    ) -> "StatefulModelInput":
         tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
         if attn_backend is not None:
             tensor_dict = _init_attn_metadata_from_tensor_dict(
@@ -182,7 +182,7 @@ def add_sampler_output(self,
 # metadata
 # mypy: disable-error-code=type-var
 class MultiStepModelRunner(
-        GPUModelRunnerBase[MutableModelInputForGPUWithMultiStepMetadata]):
+        GPUModelRunnerBase[StatefulModelInput]):
     # mypy: enable-error-code=type-var
 
     def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
@@ -199,8 +199,8 @@ def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs):
 
     def make_model_input_from_broadcasted_tensor_dict(
             self, tensor_dict: Dict[str, Any]
-    ) -> MutableModelInputForGPUWithMultiStepMetadata:
-        model_input = (MutableModelInputForGPUWithMultiStepMetadata.
+    ) -> StatefulModelInput:
+        model_input = (StatefulModelInput.
                        from_broadcasted_tensor_dict(
                            tensor_dict,
                            attn_backend=self.attn_backend,
@@ -212,11 +212,11 @@ def prepare_model_input(
         seq_group_metadata_list: List[SequenceGroupMetadata],
         virtual_engine: int = 0,
         finished_requests_ids: Optional[List[str]] = None
-    ) -> MutableModelInputForGPUWithMultiStepMetadata:
+    ) -> StatefulModelInput:
         frozen_model_input = self._base_model_runner.prepare_model_input(
             seq_group_metadata_list, virtual_engine, finished_requests_ids)
 
-        model_input = MutableModelInputForGPUWithMultiStepMetadata(
+        model_input = StatefulModelInput(
             frozen_model_input=frozen_model_input,
             num_seqs=len(frozen_model_input.seq_lens),
             num_queries=len(frozen_model_input.query_lens),
@@ -226,7 +226,7 @@
     @torch.inference_mode()
     def execute_model(
         self,
-        model_input: MutableModelInputForGPUWithMultiStepMetadata,
+        model_input: StatefulModelInput,
         kv_caches: List[torch.Tensor],
         intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
@@ -354,9 +354,9 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs,
             assert seq_group.query_len is None  # Decode
 
     def _advance_step(
-        self, model_input: MutableModelInputForGPUWithMultiStepMetadata,
+        self, model_input: StatefulModelInput,
         out: SamplerOutput
-    ) -> MutableModelInputForGPUWithMultiStepMetadata:
+    ) -> StatefulModelInput:
         frozen_model_input = model_input.frozen_model_input
         assert frozen_model_input is not None
         assert frozen_model_input.attn_metadata is not None
@@ -420,7 +420,7 @@ def vocab_size(self) -> int:
 
 
 def _pythonize_sampler_output(
-        model_input: MutableModelInputForGPUWithMultiStepMetadata,
+        model_input: StatefulModelInput,
         output: SamplerOutput, pinned_sampled_token_buffer: torch.Tensor,
         sampled_token_ids: torch.Tensor) -> SamplerOutput:
     """ This function is only called when the output tensors are ready.
diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py
index bbc18c28959dd..7f7b7bdd632d4 100644
--- a/vllm/worker/multi_step_worker.py
+++ b/vllm/worker/multi_step_worker.py
@@ -5,14 +5,14 @@
 from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.worker.model_runner_base import BroadcastableModelInput
 from vllm.worker.multi_step_model_runner import (
-    MultiStepModelRunner, MutableModelInputForGPUWithMultiStepMetadata)
+    MultiStepModelRunner, StatefulModelInput)
 from vllm.worker.worker import Worker, WorkerInput
 
 
 @dataclass
 class MultiStepState:
     worker_input: WorkerInput
-    model_input: MutableModelInputForGPUWithMultiStepMetadata
+    model_input: StatefulModelInput
 
 
 class MultiStepWorker(Worker):
@@ -54,7 +54,7 @@ def _get_driver_input_and_broadcast(
         # on first step we prepare the worker input and model input normally
         worker_input: WorkerInput = self.prepare_worker_input(
            execute_model_req=execute_model_req)
-        model_input: MutableModelInputForGPUWithMultiStepMetadata = (
+        model_input: StatefulModelInput = (
             self.model_runner.prepare_model_input(
                 execute_model_req.seq_group_metadata_list,
                 execute_model_req.virtual_engine,
@@ -90,7 +90,7 @@ def _get_driver_input_and_broadcast(
     def _prepare_last_sampled_token_ids_for_tp_workers(
         self,
         execute_model_req: ExecuteModelRequest,
-        model_input: MutableModelInputForGPUWithMultiStepMetadata,
+        model_input: StatefulModelInput,
     ) -> None:
         """
         Prepare the last sampled token ids for TP workers. If it's the last
@@ -130,7 +130,7 @@ def _prepare_last_sampled_token_ids_for_tp_workers(
     def prepare_input(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None,
-    ) -> Optional[Tuple[MutableModelInputForGPUWithMultiStepMetadata,
+    ) -> Optional[Tuple[StatefulModelInput,
                         WorkerInput]]:
         """
         Depending on the current state of the request and multi step worker,
@@ -152,7 +152,7 @@ def prepare_input(
             model_input, worker_input = self._get_driver_input_and_broadcast(
                 execute_model_req)
             assert isinstance(model_input,
-                              MutableModelInputForGPUWithMultiStepMetadata)
+                              StatefulModelInput)
             if execute_model_req.is_first_multi_step:
                 # cache the worker input and model input for the next steps
                 self.multi_step_states[virtual_engine] = MultiStepState(
@@ -166,7 +166,7 @@
                 return None
             model_input, worker_input = broadcast_data
             assert isinstance(model_input,
-                              MutableModelInputForGPUWithMultiStepMetadata)
+                              StatefulModelInput)
             virtual_engine = worker_input.virtual_engine
             if model_input.is_first_multi_step:
                 pass
@@ -180,7 +180,7 @@
                 # for the next step (sampled_token_ids from the previous step)
 
                 assert isinstance(
-                    model_input, MutableModelInputForGPUWithMultiStepMetadata)
+                    model_input, StatefulModelInput)
                 # we need to update the last sampled token ids in the model
                 # input for the workers so that they can run inplace
                 # advance_step
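
Not part of the diff above: a minimal, untested sketch of how the renamed StatefulModelInput round-trips through its broadcast serialization, following the pattern in tests/worker/test_model_input.py. The tensor values and the omission of sampling/attention metadata (and of an attention backend on deserialization) are illustrative assumptions, not requirements of vLLM.

# Hedged sketch: construct a StatefulModelInput around a frozen per-step input,
# flatten it to a broadcastable tensor dict (driver side), and rebuild it
# (worker side). Field values are placeholders.
import torch

from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from vllm.worker.multi_step_model_runner import StatefulModelInput

frozen = ModelInputForGPUWithSamplingMetadata(
    input_tokens=torch.ones(10, dtype=torch.long),
    input_positions=torch.arange(10),
)
model_input = StatefulModelInput(
    frozen_model_input=frozen,
    is_first_multi_step=False,
    is_last_step=True,
    num_seqs=5,
    num_queries=8,
)

# Serialize for broadcast to TP workers, then rebuild the stateful wrapper.
tensor_dict = model_input.as_broadcastable_tensor_dict()
rebuilt = StatefulModelInput.from_broadcasted_tensor_dict(tensor_dict)
assert isinstance(rebuilt, StatefulModelInput)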