diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 90b844bf42139..4c30f20ff076e 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -264,6 +264,7 @@ def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
     def prepare_worker_input(
             self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
         virtual_engine = execute_model_req.virtual_engine
+        num_steps = execute_model_req.num_steps
         num_seq_groups = len(execute_model_req.seq_group_metadata_list)
         # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
         # they contain parameters to launch cudamemcpyasync.
@@ -286,6 +287,7 @@ def prepare_worker_input(
             blocks_to_swap_out=blocks_to_swap_out,
             blocks_to_copy=blocks_to_copy,
             virtual_engine=virtual_engine,
+            num_steps=num_steps,
         )
 
     @torch.inference_mode()
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 20db3dad1caab..85ab0d348e03d 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -129,6 +129,7 @@ class WorkerInput:
     blocks_to_swap_out: Optional[torch.Tensor] = None
     blocks_to_copy: Optional[torch.Tensor] = None
     virtual_engine: int = 0
+    num_steps: int = 1
 
     @classmethod
     def from_broadcasted_tensor_dict(
@@ -145,6 +146,7 @@ def from_broadcasted_tensor_dict(
             blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"),
             blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
             virtual_engine=tensor_dict["virtual_engine"],
+            num_steps=tensor_dict.pop("num_steps"),
         )
 
     def as_broadcastable_tensor_dict(
@@ -158,6 +160,7 @@ def as_broadcastable_tensor_dict(
             "blocks_to_swap_out": self.blocks_to_swap_out,
             "blocks_to_copy": self.blocks_to_copy,
             "virtual_engine": self.virtual_engine,
+            "num_steps": self.num_steps,
         }
         return tensor_dict
 
@@ -216,13 +219,50 @@ def execute_worker(self, worker_input: WorkerInput) -> None:
         """
         raise NotImplementedError
 
-    def execute_model(
+    def _get_worker_input_from_broadcast(
+            self) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
+        """ Get the worker input from the broadcasted tensor dict. """
+        assert self.do_metadata_broadcast
+        assert not self.is_driver_worker
+        broadcast_data = broadcast_tensor_dict(src=0)
+        if not broadcast_data:
+            return None
+
+        worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
+        model_input = (
+            self.model_runner.make_model_input_from_broadcasted_tensor_dict(
+                broadcast_data))
+
+        return model_input, worker_input
+
+    def _get_driver_input_and_broadcast(
+        self, execute_model_req: ExecuteModelRequest
+    ) -> Tuple[ModelRunnerInputBase, WorkerInput]:
+        """ Get the driver input and broadcast it to other workers. """
+        assert self.is_driver_worker
+
+        worker_input: WorkerInput = self.prepare_worker_input(
+            execute_model_req=execute_model_req)
+        model_input: ModelRunnerInputBase = (
+            self.model_runner.prepare_model_input(
+                execute_model_req.seq_group_metadata_list,
+                execute_model_req.virtual_engine,
+                execute_model_req.finished_requests_ids))
+
+        if self.do_metadata_broadcast:
+            broadcast_data = worker_input.as_broadcastable_tensor_dict()
+            broadcast_data.update(model_input.as_broadcastable_tensor_dict())
+            broadcast_tensor_dict(broadcast_data, src=0)
+
+        return model_input, worker_input
+
+    def prepare_input(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> Optional[List[SamplerOutput]]:
-        """Executes at least one model step on the given sequences, unless no
-        sequences are provided."""
-        start_time = time.perf_counter()
+    ) -> Optional[Tuple[ModelRunnerInputBase, WorkerInput]]:
+        """
+        Prepare the inputs to ModelRunner and workers.
+        """
         if self.is_driver_worker:
             if execute_model_req is None:
                 if self.do_metadata_broadcast:
@@ -233,34 +273,24 @@ def execute_model(
                     # notify all other workers to stop their execution loop.
                     broadcast_tensor_dict({}, src=0)
                 return None
-
-            worker_input: WorkerInput = self.prepare_worker_input(
-                execute_model_req=execute_model_req)
-            model_input: ModelRunnerInputBase = (
-                self.model_runner.prepare_model_input(
-                    execute_model_req.seq_group_metadata_list,
-                    execute_model_req.virtual_engine,
-                    execute_model_req.finished_requests_ids))
-            num_steps = execute_model_req.num_steps
-
-            if self.do_metadata_broadcast:
-                broadcast_data = worker_input.as_broadcastable_tensor_dict()
-                broadcast_data.update(
-                    model_input.as_broadcastable_tensor_dict())
-                broadcast_data["num_steps"] = num_steps
-                broadcast_tensor_dict(broadcast_data, src=0)
+            return self._get_driver_input_and_broadcast(execute_model_req)
         else:
-            assert self.do_metadata_broadcast
-            broadcast_data = broadcast_tensor_dict(src=0)
-            if not broadcast_data:
-                return None
+            return self._get_worker_input_from_broadcast()
+
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> Optional[List[SamplerOutput]]:
+        """Executes at least one model step on the given sequences, unless no
+        sequences are provided."""
+        start_time = time.perf_counter()
+
+        inputs = self.prepare_input(execute_model_req)
+        if inputs is None:
+            return None
 
-            num_steps = broadcast_data.pop("num_steps")
-            worker_input = WorkerInput.from_broadcasted_tensor_dict(
-                broadcast_data)
-            model_input = (
-                self.model_runner.
-                make_model_input_from_broadcasted_tensor_dict(broadcast_data))
+        model_input, worker_input = inputs
+        num_steps = worker_input.num_steps
 
         self.execute_worker(worker_input)