diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 05067a6a192d5..6d6d7895b2101 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -107,7 +107,13 @@ class RPCStartupResponse: tracing_enabled: bool -RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest] +class RPCUProfileRequest(Enum): + START_PROFILE = 1 + STOP_PROFILE = 2 + + +RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, + RPCUProfileRequest] REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 239ca52ef13e2..700e65000e052 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -21,7 +21,8 @@ IPC_OUTPUT_EXT, RPC_REQUEST_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCError, RPCProcessRequest, - RPCStartupRequest, RPCStartupResponse) + RPCStartupRequest, RPCStartupResponse, + RPCUProfileRequest) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT from vllm.inputs import PromptType @@ -38,10 +39,10 @@ class MQClientClosedError(Exception): """Exception class raised when the client is used post-close. - + The client can be closed, which closes the ZMQ context. This normally - happens on server shutdown. In some cases, methods like abort and - do_log_stats will still be called and then try to open a socket, which + happens on server shutdown. In some cases, methods like abort and + do_log_stats will still be called and then try to open a socket, which causes a ZMQError and creates a huge stack trace. So, we throw this error such that we can suppress it. """ @@ -345,7 +346,7 @@ async def do_log_stats(self): async def check_health(self): """ The check health loop probes the health status of the - Engine's health every N seconds and sets _errored_with + Engine's health every N seconds and sets _errored_with if the engine is unhealthy. """ if self._errored_with is not None: @@ -561,3 +562,15 @@ async def _process_request( await self.abort(request_id) finally: self.output_queues.pop(request_id) + + async def start_profile(self) -> None: + """Start profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket) + + async def stop_profile(self) -> None: + """Stop profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index b406d4a759667..eecca82cd2f7d 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -18,9 +18,11 @@ IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, RPCError, RPCProcessRequest, - RPCStartupRequest, RPCStartupResponse) + RPCStartupRequest, RPCStartupResponse, + RPCUProfileRequest) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT +from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext @@ -249,6 +251,11 @@ def handle_new_input(self): self._handle_process_request(request) elif isinstance(request, RPCAbortRequest): self._handle_abort_request(request) + elif isinstance(request, RPCUProfileRequest): + if request == RPCUProfileRequest.START_PROFILE: + self.start_profile() + else: + self.stop_profile() else: raise ValueError("Unknown RPCRequest Type: " f"{type(request)}") @@ -356,6 +363,18 @@ def _set_errored(self, e: BaseException): def _alive(self): self._last_alive_time = time.time() + def start_profile(self) -> None: + if type(self.engine.model_executor) is GPUExecutor: + self.engine.model_executor.start_profile() + else: + self.engine.model_executor._run_workers("start_profile") + + def stop_profile(self) -> None: + if type(self.engine.model_executor) is GPUExecutor: + self.engine.model_executor.stop_profile() + else: + self.engine.model_executor._run_workers("stop_profile") + def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str):