From 28b3a1c7e596c08efac0fcfa59a629d16197be30 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Tue, 10 Dec 2024 01:28:14 -0500 Subject: [PATCH] [V1] Multiprocessing Tensor Parallel Support for v1 (#9856) Signed-off-by: Tyler Michael Smith --- .../test_basic_correctness.py | 16 + tests/conftest.py | 11 +- .../device_communicators/shm_broadcast.py | 76 ++-- vllm/executor/multiproc_gpu_executor.py | 47 +-- vllm/executor/multiproc_worker_utils.py | 42 ++ .../model_executor/layers/logits_processor.py | 5 +- vllm/platforms/cuda.py | 28 +- vllm/utils.py | 26 ++ vllm/v1/core/scheduler.py | 4 +- vllm/v1/engine/async_llm.py | 18 +- vllm/v1/engine/core.py | 74 ++-- vllm/v1/engine/core_client.py | 13 +- vllm/v1/engine/llm_engine.py | 19 +- vllm/v1/executor/abstract.py | 48 +++ vllm/v1/executor/multiproc_executor.py | 375 ++++++++++++++++++ .../{gpu_executor.py => uniproc_executor.py} | 12 +- vllm/v1/outputs.py | 6 +- vllm/v1/sample/sampler.py | 3 +- vllm/v1/utils.py | 33 +- vllm/v1/worker/gpu_model_runner.py | 12 +- vllm/v1/worker/gpu_worker.py | 11 +- 21 files changed, 733 insertions(+), 146 deletions(-) create mode 100644 vllm/v1/executor/abstract.py create mode 100644 vllm/v1/executor/multiproc_executor.py rename vllm/v1/executor/{gpu_executor.py => uniproc_executor.py} (90%) diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index fcba253d159f3..11d05cefb7313 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -26,6 +26,14 @@ TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4") +@pytest.fixture(autouse=True) +def v1(run_with_both_engines): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + def test_vllm_gc_ed(): """Verify vllm instance is GC'ed when it is deleted""" llm = LLM("facebook/opt-125m") @@ -36,6 +44,7 @@ def test_vllm_gc_ed(): assert weak_llm() is None +@pytest.mark.skip_v1 @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"]) @pytest.mark.parametrize("dtype", ["half"]) @@ -118,6 +127,11 @@ def test_models_distributed( if attention_backend: os.environ["VLLM_ATTENTION_BACKEND"] = attention_backend + # Import VLLM_USE_V1 dynamically to handle patching + from vllm.envs import VLLM_USE_V1 + if VLLM_USE_V1 and distributed_executor_backend != "mp": + pytest.skip(f"Skip {distributed_executor_backend} for V1") + dtype = "half" max_tokens = 5 @@ -143,6 +157,7 @@ def test_models_distributed( ) +@pytest.mark.skip_v1 def test_model_with_failure(vllm_runner) -> None: try: with patch("vllm.model_executor.models.opt.OPTForCausalLM.forward", @@ -169,6 +184,7 @@ def test_model_with_failure(vllm_runner) -> None: os.remove(filename) +@pytest.mark.skip_v1 def test_failure_with_async_out_proc(vllm_runner) -> None: filename = None diff --git a/tests/conftest.py b/tests/conftest.py index d6be8f5b00af8..7606e0f11dfeb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ from enum import Enum from typing import (Any, Callable, Dict, List, Optional, Tuple, Type, TypedDict, TypeVar, Union) -from unittest.mock import patch import numpy as np import pytest @@ -110,7 +109,7 @@ def prompts(self, prompts: _VideoAssetPrompts) -> List[str]: @pytest.fixture(params=[True, False]) -def run_with_both_engines(request): +def run_with_both_engines(request, monkeypatch): # Automatically runs tests 
twice, once with V1 and once without use_v1 = request.param # Tests decorated with `@skip_v1` are only run without v1 @@ -119,11 +118,11 @@ def run_with_both_engines(request): if use_v1: if skip_v1: pytest.skip("Skipping test on vllm V1") - with patch('vllm.envs.VLLM_USE_V1', True): - yield + monkeypatch.setenv('VLLM_USE_V1', '1') else: - with patch('vllm.envs.VLLM_USE_V1', False): - yield + monkeypatch.setenv('VLLM_USE_V1', '0') + + yield @pytest.fixture(autouse=True) diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 2ff1a1ead99c1..9a2d8918d96e5 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -1,10 +1,11 @@ import os import pickle +import sys import time from contextlib import contextmanager from dataclasses import dataclass, field from multiprocessing import shared_memory -from typing import List, Optional +from typing import List, Optional, Tuple from unittest.mock import patch import torch @@ -21,6 +22,20 @@ logger = init_logger(__name__) +# We prefer to use os.sched_yield as it results in tighter polling loops, +# measured to be around 3e-7 seconds. However on earlier versions of Python +# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0) +USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1)) + or (sys.version_info[:2] == (3, 10) + and sys.version_info[2] >= 8)) + + +def sched_yield(): + if USE_SCHED_YIELD: + os.sched_yield() + else: + time.sleep(0) + class ShmRingBuffer: @@ -114,11 +129,14 @@ def __init__(self, # and we should suppress the error pass + def handle(self): + return (self.n_reader, self.max_chunk_bytes, self.max_chunks, + self.shared_memory.name) + def __reduce__(self): return ( self.__class__, - (self.n_reader, self.max_chunk_bytes, self.max_chunks, - self.shared_memory.name), + self.handle(), ) def __del__(self): @@ -147,7 +165,7 @@ class Handle: connect_ip: str local_reader_ranks: List[int] = field(default_factory=list) - buffer: Optional[ShmRingBuffer] = None + buffer_handle: Optional[Tuple[int, int, int, str]] = None local_subscribe_port: Optional[int] = None remote_subscribe_port: Optional[int] = None @@ -228,7 +246,7 @@ def __init__( self.handle = Handle( connect_ip=connect_ip, local_reader_ranks=local_reader_ranks, - buffer=self.buffer, + buffer_handle=self.buffer.handle(), local_subscribe_port=local_subscribe_port, remote_subscribe_port=remote_subscribe_port, ) @@ -247,8 +265,8 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": context = Context() if rank in handle.local_reader_ranks: - assert handle.buffer is not None - self.buffer = handle.buffer + assert handle.buffer_handle is not None + self.buffer = ShmRingBuffer(*handle.buffer_handle) self.current_idx = 0 self.local_reader_rank = handle.local_reader_ranks.index(rank) self._is_local_reader = True @@ -314,7 +332,7 @@ def wait_until_ready(self): assert recv == b"READY" @contextmanager - def acquire_write(self): + def acquire_write(self, timeout: Optional[float] = None): assert self._is_writer, "Only writers can acquire write" start_time = time.monotonic() n_warning = 1 @@ -329,16 +347,20 @@ def acquire_write(self): # we need to wait until it is read by all readers # Release the processor to other threads - os.sched_yield() + sched_yield() - # if we wait for a long time, we should warn the user + # if we wait for a long time, log a message if (time.monotonic() - start_time > 
VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): - logger.warning( - "No available block found in %s second. ", - VLLM_RINGBUFFER_WARNING_INTERVAL) + logger.debug("No available block found in %s second. ", + VLLM_RINGBUFFER_WARNING_INTERVAL) n_warning += 1 + # if we time out, raise an exception + if (timeout is not None + and time.monotonic() - start_time > timeout): + raise TimeoutError + continue # found a block that is either # (1) not written @@ -365,7 +387,7 @@ def acquire_write(self): break @contextmanager - def acquire_read(self): + def acquire_read(self, timeout: Optional[float] = None): assert self._is_local_reader, "Only readers can acquire read" start_time = time.monotonic() n_warning = 1 @@ -383,16 +405,20 @@ def acquire_read(self): # we need to wait until it is written # Release the processor to other threads - os.sched_yield() + sched_yield() - # if we wait for a long time, we should warn the user + # if we wait for a long time, log a message if (time.monotonic() - start_time > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): - logger.warning( - "No available block found in %s second. ", - VLLM_RINGBUFFER_WARNING_INTERVAL) + logger.debug("No available block found in %s second. ", + VLLM_RINGBUFFER_WARNING_INTERVAL) n_warning += 1 + # if we time out, raise an exception + if (timeout is not None + and time.monotonic() - start_time > timeout): + raise TimeoutError + continue # found a block that is not read by this reader # let caller read from the buffer @@ -406,24 +432,26 @@ def acquire_read(self): 1) % self.buffer.max_chunks break - def enqueue(self, obj): + def enqueue(self, obj, timeout: Optional[float] = None): + """ Write to message queue with optional timeout (in seconds) """ assert self._is_writer, "Only writers can enqueue" serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) if self.n_local_reader > 0: if len(serialized_obj) >= self.buffer.max_chunk_bytes: - with self.acquire_write() as buf: + with self.acquire_write(timeout) as buf: buf[0] = 1 # overflow self.local_socket.send(serialized_obj) else: - with self.acquire_write() as buf: + with self.acquire_write(timeout) as buf: buf[0] = 0 # not overflow buf[1:len(serialized_obj) + 1] = serialized_obj if self.n_remote_reader > 0: self.remote_socket.send(serialized_obj) - def dequeue(self): + def dequeue(self, timeout: Optional[float] = None): + """ Read from message queue with optional timeout (in seconds) """ if self._is_local_reader: - with self.acquire_read() as buf: + with self.acquire_read(timeout) as buf: overflow = buf[0] == 1 if not overflow: # no need to know the size of serialized object diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index c450209f0eb91..fc58163cade64 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -3,25 +3,19 @@ from functools import partial from typing import Any, List, Optional -import torch - from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.gpu_executor import create_worker -from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, - ResultHandler, WorkerMonitor) +from vllm.executor.multiproc_worker_utils import ( + ProcessWorkerWrapper, ResultHandler, WorkerMonitor, + set_multiprocessing_worker_envs) from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest -from 
vllm.triton_utils.importing import HAS_TRITON from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless, - cuda_is_initialized, get_distributed_init_method, - get_open_port, make_async, + get_distributed_init_method, get_open_port, make_async, update_environment_variables) -if HAS_TRITON: - from vllm.triton_utils import maybe_set_triton_cache_manager - logger = init_logger(__name__) @@ -37,30 +31,8 @@ def _init_executor(self) -> None: world_size = self.parallel_config.world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size - # Disable torch async compiling which won't work with daemonic processes - os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" - - # Configure thread parallelism if OMP_NUM_THREADS isn't set - # - # Helps to avoid CPU contention. The default of spawning a thread per - # core combined with multiprocessing for each GPU can have a negative - # impact on performance. The contention is amplified when running in a - # container where CPU limits can cause throttling. - default_omp_num_threads = 1 - if "OMP_NUM_THREADS" not in os.environ and ( - current_parallelism := - torch.get_num_threads()) > default_omp_num_threads: - logger.warning( - "Reducing Torch parallelism from %d threads to %d to avoid " - "unnecessary CPU contention. Set OMP_NUM_THREADS in the " - "external environment to tune this value as needed.", - current_parallelism, default_omp_num_threads) - os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads) - torch.set_num_threads(default_omp_num_threads) - - # workaround for https://github.com/vllm-project/vllm/issues/6103 - if HAS_TRITON and world_size > 1: - maybe_set_triton_cache_manager() + # Set multiprocessing envs that are common to V0 and V1 + set_multiprocessing_worker_envs(self.parallel_config) # Multiprocessing-based executor does not support multi-node setting. # Since it only works for single node, we can use the loopback address @@ -122,13 +94,6 @@ def _check_executor_parameters(self): "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size)))) }) - if (cuda_is_initialized() - and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"): - logger.warning("CUDA was previously initialized. We must use " - "the `spawn` multiprocessing start method. Setting " - "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.") - os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" - cuda_device_count = cuda_device_count_stateless() # Use confusing message for more common TP-only case. assert tensor_parallel_size <= cuda_device_count, ( diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py index 884267d23dfc8..fe475db6d3f57 100644 --- a/vllm/executor/multiproc_worker_utils.py +++ b/vllm/executor/multiproc_worker_utils.py @@ -11,8 +11,15 @@ from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO, TypeVar, Union) +import torch + import vllm.envs as envs from vllm.logger import init_logger +from vllm.triton_utils.importing import HAS_TRITON +from vllm.utils import cuda_is_initialized + +if HAS_TRITON: + from vllm.triton_utils import maybe_set_triton_cache_manager logger = init_logger(__name__) @@ -270,3 +277,38 @@ def write_with_prefix(s: str): def get_mp_context(): mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD return multiprocessing.get_context(mp_method) + + +def set_multiprocessing_worker_envs(parallel_config): + """ Set up environment variables that should be used when there are workers + in a multiprocessing environment. 
This should be called by the parent + process before worker processes are created""" + + if (cuda_is_initialized() + and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"): + logger.warning("CUDA was previously initialized. We must use " + "the `spawn` multiprocessing start method. Setting " + "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.") + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + # Configure thread parallelism if OMP_NUM_THREADS isn't set + # + # Helps to avoid CPU contention. The default of spawning a thread per + # core combined with multiprocessing for each GPU can have a negative + # impact on performance. The contention is amplified when running in a + # container where CPU limits can cause throttling. + default_omp_num_threads = 1 + if "OMP_NUM_THREADS" not in os.environ and ( + current_parallelism := + torch.get_num_threads()) > default_omp_num_threads: + logger.warning( + "Reducing Torch parallelism from %d threads to %d to avoid " + "unnecessary CPU contention. Set OMP_NUM_THREADS in the " + "external environment to tune this value as needed.", + current_parallelism, default_omp_num_threads) + os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads) + torch.set_num_threads(default_omp_num_threads) + + # workaround for https://github.com/vllm-project/vllm/issues/6103 + if HAS_TRITON and parallel_config.world_size > 1: + maybe_set_triton_cache_manager() diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index fb76b1b17925e..2bc7e458494f7 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn +import vllm.envs as envs from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_gather) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -42,7 +43,9 @@ def __init__(self, # Soft cap the logits. Used in Gemma 2. self.soft_cap = soft_cap # Whether to use gather or all-gather to gather the logits. 
- self.use_gather = not current_platform.is_tpu() + + self.use_gather = not current_platform.is_tpu( + ) and not envs.VLLM_USE_V1 def forward( self, diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index edaf377b501df..10f83fd304281 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -12,6 +12,7 @@ # import custom ops, trigger op registration import vllm._C # noqa +import vllm.envs as envs from vllm.logger import init_logger from .interface import DeviceCapability, Platform, PlatformEnum @@ -110,17 +111,28 @@ def log_warnings(cls): def check_and_update_config(cls, vllm_config: VllmConfig) -> None: parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config + if parallel_config.worker_cls == "auto": if scheduler_config.is_multi_step: - parallel_config.worker_cls = \ - "vllm.worker.multi_step_worker.MultiStepWorker" + if envs.VLLM_USE_V1: + raise NotImplementedError + else: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_worker.MultiStepWorker" elif vllm_config.speculative_config: - parallel_config.worker_cls = \ - "vllm.spec_decode.spec_decode_worker.create_spec_worker" - parallel_config.sd_worker_cls = \ - "vllm.worker.worker.Worker" + if envs.VLLM_USE_V1: + raise NotImplementedError + else: + parallel_config.worker_cls = \ + "vllm.spec_decode.spec_decode_worker.create_spec_worker" + parallel_config.sd_worker_cls = \ + "vllm.worker.worker.Worker" else: - parallel_config.worker_cls = "vllm.worker.worker.Worker" + if envs.VLLM_USE_V1: + parallel_config.worker_cls = \ + "vllm.v1.worker.gpu_worker.Worker" + else: + parallel_config.worker_cls = "vllm.worker.worker.Worker" # NVML utils @@ -249,4 +261,4 @@ def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool: if not isinstance(pynvml, _MockModule): CudaPlatform.log_warnings() except ModuleNotFoundError: - CudaPlatform.log_warnings() \ No newline at end of file + CudaPlatform.log_warnings() diff --git a/vllm/utils.py b/vllm/utils.py index 2bb1fb2af40f4..7cdb2cb320b05 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -10,6 +10,7 @@ import inspect import ipaddress import os +import signal import socket import subprocess import sys @@ -1652,3 +1653,28 @@ def resolve_obj_by_qualname(qualname: str) -> Any: module_name, obj_name = qualname.rsplit(".", 1) module = importlib.import_module(module_name) return getattr(module, obj_name) + + +def kill_process_tree(pid: int): + """ + Kills all descendant processes of the given pid by sending SIGKILL. 
+ + Args: + pid (int): Process ID of the parent process + """ + try: + parent = psutil.Process(pid) + except psutil.NoSuchProcess: + return + + # Get all children recursively + children = parent.children(recursive=True) + + # Send SIGKILL to all children first + for child in children: + with contextlib.suppress(ProcessLookupError): + os.kill(child.pid, signal.SIGKILL) + + # Finally kill the parent + with contextlib.suppress(ProcessLookupError): + os.kill(pid, signal.SIGKILL) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 1203d35fc985f..a3e85c20cc664 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -5,6 +5,8 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.logger import init_logger +from vllm.multimodal import MultiModalKwargs +from vllm.multimodal.base import PlaceholderRange from vllm.sampling_params import SamplingParams from vllm.v1.core.encoder_cache_manager import EncoderCacheManager from vllm.v1.core.kv_cache_manager import KVCacheManager @@ -383,7 +385,7 @@ def update_from_output( model_runner_output: "ModelRunnerOutput", ) -> List[EngineCoreOutput]: # NOTE(woosuk): This method doesn't consider speculative decoding. - sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist() + sampled_token_ids = model_runner_output.sampled_token_ids num_scheduled_tokens = scheduler_output.num_scheduled_tokens new_running: List[Request] = [] engine_core_outputs: List[EngineCoreOutput] = [] diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 0bcccda2bf329..26fd650aee4b7 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -20,7 +20,7 @@ from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.detokenizer import Detokenizer from vllm.v1.engine.processor import Processor -from vllm.v1.executor.gpu_executor import GPUExecutor +from vllm.v1.executor.abstract import Executor logger = init_logger(__name__) @@ -30,7 +30,7 @@ class AsyncLLM(EngineClient): def __init__( self, vllm_config: VllmConfig, - executor_class: Type[GPUExecutor], + executor_class: Type[Executor], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, @@ -119,14 +119,24 @@ def from_engine_args( def shutdown(self): """Shutdown, cleaning up the background proc and IPC.""" - self.engine_core.shutdown() + if engine_core := getattr(self, "engine_core", None): + engine_core.shutdown() if handler := getattr(self, "output_handler", None): handler.cancel() @classmethod def _get_executor_cls(cls, vllm_config: VllmConfig): - return GPUExecutor + distributed_executor_backend = ( + vllm_config.parallel_config.distributed_executor_backend) + if distributed_executor_backend == "mp": + from vllm.v1.executor.multiproc_executor import MultiprocExecutor + executor_class = MultiprocExecutor + else: + assert (distributed_executor_backend is None) + from vllm.v1.executor.uniproc_executor import UniprocExecutor + executor_class = UniprocExecutor + return executor_class async def add_request( self, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 751eb3b40a68d..fdb241e6753fb 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,12 +1,12 @@ import multiprocessing import pickle import queue +import signal import threading import time -from contextlib import contextmanager from multiprocessing.process import BaseProcess from multiprocessing.sharedctypes import Synchronized -from typing import 
Any, Iterator, List, Tuple, Type, Union +from typing import List, Tuple, Type, Union import zmq import zmq.asyncio @@ -20,9 +20,10 @@ EngineCoreProfile, EngineCoreRequest, EngineCoreRequestType) from vllm.v1.engine.mm_input_mapper import MMInputMapper -from vllm.v1.executor.gpu_executor import GPUExecutor +from vllm.v1.executor.abstract import Executor from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import PickleEncoder +from vllm.v1.utils import make_zmq_socket from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) @@ -38,7 +39,7 @@ class EngineCore: def __init__( self, vllm_config: VllmConfig, - executor_class: Type[GPUExecutor], + executor_class: Type[Executor], usage_context: UsageContext, ): assert vllm_config.model_config.task != "embedding" @@ -80,7 +81,7 @@ def _initialize_kv_caches(self, num_gpu_blocks = num_gpu_blocks_override num_cpu_blocks = 0 - self.model_executor.initialize_cache(num_gpu_blocks) + self.model_executor.initialize(num_gpu_blocks) elapsed = time.time() - start logger.info(("init engine (profile, create kv cache, " "warmup model) took %.2f seconds"), elapsed) @@ -112,8 +113,11 @@ def step(self) -> List[EngineCoreOutput]: scheduler_output, output) return engine_core_outputs + def shutdown(self): + self.model_executor.shutdown() + def profile(self, is_start=True): - self.model_executor.worker.profile(is_start) + self.model_executor.profile(is_start) class EngineCoreProc(EngineCore): @@ -124,7 +128,7 @@ class EngineCoreProc(EngineCore): def __init__( self, vllm_config: VllmConfig, - executor_class: Type[GPUExecutor], + executor_class: Type[Executor], usage_context: UsageContext, input_path: str, output_path: str, @@ -151,32 +155,9 @@ def __init__( daemon=True).start() # Send Readiness signal to EngineClient. - with self.make_socket(ready_path, zmq.constants.PUSH) as ready_socket: + with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket: ready_socket.send_string(EngineCoreProc.READY_STR) - @contextmanager - def make_socket(self, path: str, type: Any) -> Iterator[zmq.Socket]: - """Context manager for use """ - - ctx = zmq.Context() - try: - socket = ctx.socket(type) - - if type == zmq.constants.PULL: - socket.connect(path) - elif type == zmq.constants.PUSH: - socket.bind(path) - else: - raise ValueError(f"Unknown Socket Type: {type}") - - yield socket - - except KeyboardInterrupt: - logger.debug("EngineCore had Keyboard Interrupt.") - - finally: - ctx.destroy(linger=0) - @staticmethod def wait_for_startup( proc: BaseProcess, @@ -209,7 +190,7 @@ def wait_for_startup( @staticmethod def make_engine_core_process( vllm_config: VllmConfig, - executor_class: Type[GPUExecutor], + executor_class: Type[Executor], usage_context: UsageContext, input_path: str, output_path: str, @@ -244,17 +225,38 @@ def make_engine_core_process( def run_engine_core(*args, **kwargs): """Launch EngineCore busy loop in background process.""" + # Signal handler used for graceful termination. 
+ # SystemExit exception is only raised once to allow this and worker + # processes to terminate without error + shutdown_requested = False + + def signal_handler(signum, frame): + nonlocal shutdown_requested + if not shutdown_requested: + shutdown_requested = True + raise SystemExit() + + # Either SIGTERM or SIGINT will terminate the engine_core + signal.signal(signal.SIGTERM, signal_handler) + signal.signal(signal.SIGINT, signal_handler) + + engine_core = None try: engine_core = EngineCoreProc(*args, **kwargs) engine_core.run_busy_loop() - except KeyboardInterrupt: + except SystemExit: logger.debug("EngineCore interrupted.") except BaseException as e: logger.exception(e) raise e + finally: + if engine_core is not None: + engine_core.shutdown() + engine_core = None + def run_busy_loop(self): """Core busy loop of the EngineCore.""" @@ -272,6 +274,8 @@ def run_busy_loop(self): logger.debug("EngineCore busy loop waiting.") if self.should_shutdown: return + except BaseException: + raise # 2) Handle any new client requests (Abort or Add). while not self.input_queue.empty(): @@ -321,7 +325,7 @@ def process_input_socket(self, input_path: str): decoder_add_req = PickleEncoder() decoder_abort_req = PickleEncoder() - with self.make_socket(input_path, zmq.constants.PULL) as socket: + with make_zmq_socket(input_path, zmq.constants.PULL) as socket: while True: # (RequestType, RequestData) type_frame, data_frame = socket.recv_multipart(copy=False) @@ -349,7 +353,7 @@ def process_output_socket(self, output_path: str): # Reuse send buffer. buffer = bytearray() - with self.make_socket(output_path, zmq.constants.PUSH) as socket: + with make_zmq_socket(output_path, zmq.constants.PUSH) as socket: while True: engine_core_outputs = self.output_queue.get() outputs = EngineCoreOutputs(outputs=engine_core_outputs) diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 835963f7ee86c..ee89cece73141 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -1,5 +1,4 @@ import multiprocessing -import time from typing import List, Union import msgspec @@ -7,7 +6,7 @@ import zmq.asyncio from vllm.logger import init_logger -from vllm.utils import get_open_zmq_ipc_path +from vllm.utils import get_open_zmq_ipc_path, kill_process_tree from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, EngineCoreProfile, EngineCoreRequest, EngineCoreRequestType) @@ -99,6 +98,12 @@ def add_request(self, request: EngineCoreRequest) -> None: def abort_requests(self, request_ids: List[str]) -> None: self.engine_core.abort_requests(request_ids) + def shutdown(self): + self.engine_core.shutdown() + + def __del__(self): + self.shutdown() + async def profile(self, is_start=True) -> None: self.engine_core.profile(is_start) @@ -163,10 +168,10 @@ def shutdown(self): # Shutdown the process if needed. 
if hasattr(self, "proc") and self.proc.is_alive(): self.proc.terminate() + self.proc.join(5) - time.sleep(5) if self.proc.is_alive(): - self.proc.kill() + kill_process_tree(self.proc.pid) def __del__(self): self.shutdown() diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 994e68669108e..1b3a9f12d009e 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -20,7 +20,7 @@ from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.detokenizer import Detokenizer from vllm.v1.engine.processor import Processor -from vllm.v1.executor.gpu_executor import GPUExecutor +from vllm.v1.executor.abstract import Executor logger = init_logger(__name__) @@ -33,7 +33,7 @@ class LLMEngine: def __init__( self, vllm_config: VllmConfig, - executor_class: Type[GPUExecutor], + executor_class: Type[Executor], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, @@ -104,10 +104,17 @@ def from_engine_args( @classmethod def _get_executor_cls(cls, vllm_config: VllmConfig): - return GPUExecutor - - def stop_remote_worker_execution_loop(self) -> None: - raise NotImplementedError("TP not implemented yet.") + distributed_executor_backend = ( + vllm_config.parallel_config.distributed_executor_backend) + if distributed_executor_backend == "mp": + from vllm.v1.executor.multiproc_executor import MultiprocExecutor + executor_class = MultiprocExecutor + else: + assert (distributed_executor_backend is None) + from vllm.v1.executor.uniproc_executor import UniprocExecutor + executor_class = UniprocExecutor + + return executor_class def get_num_unfinished_requests(self) -> int: return self.detokenizer.get_num_unfinished_requests() diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py new file mode 100644 index 0000000000000..9cd267581ad18 --- /dev/null +++ b/vllm/v1/executor/abstract.py @@ -0,0 +1,48 @@ +from abc import ABC, abstractmethod +from typing import Dict, Optional, Tuple + +from vllm.config import VllmConfig +from vllm.v1.outputs import ModelRunnerOutput + + +class Executor(ABC): + """Abstract class for executors.""" + + @abstractmethod + def __init__(self, vllm_config: VllmConfig) -> None: + raise NotImplementedError + + @abstractmethod + def initialize(self, num_gpu_blocks: int) -> None: + raise NotImplementedError + + @abstractmethod + def determine_num_available_blocks(self) -> Tuple[int, int]: + raise NotImplementedError + + @abstractmethod + def execute_model( + self, + scheduler_output, + ) -> ModelRunnerOutput: + raise NotImplementedError + + @abstractmethod + def profile(self, is_start=True): + raise NotImplementedError + + @abstractmethod + def shutdown(self): + pass + + @abstractmethod + def check_health(self) -> None: + raise NotImplementedError + + @abstractmethod + def collective_rpc(self, + method: str, + timeout: Optional[float] = None, + args: Tuple = (), + kwargs: Optional[Dict] = None) -> []: + raise NotImplementedError diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py new file mode 100644 index 0000000000000..f8f3d583618cf --- /dev/null +++ b/vllm/v1/executor/multiproc_executor.py @@ -0,0 +1,375 @@ +import atexit +import os +import pickle +import signal +import sys +import time +from dataclasses import dataclass +from enum import Enum, auto +from multiprocessing.process import BaseProcess +from typing import Dict, List, Optional, Tuple + +import zmq + +from vllm.config import VllmConfig +from 
vllm.distributed import (destroy_distributed_environment, + destroy_model_parallel) +from vllm.distributed.device_communicators.shm_broadcast import (Handle, + MessageQueue) +from vllm.executor.multiproc_worker_utils import ( + _add_prefix, get_mp_context, set_multiprocessing_worker_envs) +from vllm.logger import init_logger +from vllm.utils import (get_distributed_init_method, get_open_port, + get_open_zmq_ipc_path) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.utils import make_zmq_socket +from vllm.worker.worker_base import WorkerWrapperBase + +logger = init_logger(__name__) + +POLLING_TIMEOUT_MS = 5000 +POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000 + + +class MultiprocExecutor: + + def __init__(self, vllm_config: VllmConfig) -> None: + # Call self.shutdown at exit to clean up + # and ensure workers will be terminated. + atexit.register(self.shutdown) + + self.vllm_config = vllm_config + self.parallel_config = vllm_config.parallel_config + + self.world_size = self.parallel_config.world_size + tensor_parallel_size = self.parallel_config.tensor_parallel_size + assert self.world_size == tensor_parallel_size, ( + f"world_size ({self.world_size}) must be equal to the " + f"tensor_parallel_size ({tensor_parallel_size}). " + f"Pipeline parallelism is not yet implemented in v1") + + # Set multiprocessing envs that are common to V0 and V1 + set_multiprocessing_worker_envs(self.parallel_config) + + # Multiprocessing-based executor does not support multi-node setting. + # Since it only works for single node, we can use the loopback address + # 127.0.0.1 for communication. + distributed_init_method = get_distributed_init_method( + "127.0.0.1", get_open_port()) + + # Initialize worker and set up message queues for SchedulerOutputs + # and ModelRunnerOutputs + self.rpc_broadcast_mq = MessageQueue(self.world_size, self.world_size) + scheduler_output_handle = self.rpc_broadcast_mq.export_handle() + + # Create workers + self.workers: List[WorkerProcHandle] = [] + for rank in range(self.world_size): + worker = WorkerProc.make_worker_process(vllm_config, rank, rank, + distributed_init_method, + scheduler_output_handle) + self.workers.append(worker) + + # Ensure message queues are ready. Will deadlock if re-ordered + # Must be kept consistent with the WorkerProc + self.rpc_broadcast_mq.wait_until_ready() + for w in self.workers: + w.worker_response_mq.wait_until_ready() + + def initialize(self, num_gpu_blocks: int) -> None: + """ + Initialize the KV caches and begin the model execution loop of the + underlying workers. + """ + self.collective_rpc("initialize_cache", args=(num_gpu_blocks, )) + self.collective_rpc("compile_or_warm_up_model") + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """ + Determine the number of available KV blocks by invoking the + underlying worker. + """ + num_blocks = self.collective_rpc("determine_num_available_blocks") + + # Since we use a shared centralized controller, we take the minimum + # number of blocks across all workers to make sure all the memory + # operators can be applied to all workers. + num_gpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + + return num_gpu_blocks, num_cpu_blocks + + def collective_rpc(self, + method: str, + timeout: Optional[float] = None, + args: Tuple = (), + kwargs: Optional[Dict] = None) -> []: + """ + Execute an RPC call on workers. + + Args: + method: Name of the worker method to execute + timeout: Maximum time in seconds to wait for execution. 
Raises a
+                TimeoutError on timeout. None means wait indefinitely.
+            args: Positional arguments to pass to the worker method
+            kwargs: Keyword arguments to pass to the worker method
+
+        Returns:
+            List of results from each worker
+        """
+        start_time = time.monotonic()
+        kwargs = kwargs or {}
+
+        try:
+            self.rpc_broadcast_mq.enqueue((method, args, kwargs))
+
+            responses = [None] * self.world_size
+            for w in self.workers:
+                dequeue_timeout = timeout - (time.monotonic() - start_time
+                                             ) if timeout is not None else None
+                status, result = w.worker_response_mq.dequeue(
+                    timeout=dequeue_timeout)
+
+                if status != WorkerProc.ResponseStatus.SUCCESS:
+                    if isinstance(result, Exception):
+                        raise result
+                    else:
+                        raise RuntimeError("Worker failed")
+
+                responses[w.rank] = result
+
+            return responses
+        except TimeoutError as e:
+            raise TimeoutError(f"RPC call to {method} timed out.") from e
+        except Exception as e:
+            # Re-raise any other exceptions
+            raise e
+
+    def execute_model(
+        self,
+        scheduler_output,
+    ) -> ModelRunnerOutput:
+        model_output = self.collective_rpc("execute_model",
+                                           args=(scheduler_output, ))[0]
+        return model_output
+
+    def profile(self, is_start=True):
+        self.collective_rpc("profile", args=(is_start, ))
+        return
+
+    def _ensure_worker_termination(self):
+        """Ensure that all worker processes are terminated. Assumes workers have
+        received termination requests. Waits for processing, then sends
+        termination and kill signals if needed."""
+
+        def wait_for_termination(procs, timeout):
+            start_time = time.time()
+            while time.time() - start_time < timeout:
+                if all(not proc.is_alive() for proc in procs):
+                    return True
+                time.sleep(0.1)
+            return False
+
+        # Send SIGTERM if still running
+        active_procs = [w.proc for w in self.workers if w.proc.is_alive()]
+        self.workers = None
+        for p in active_procs:
+            p.terminate()
+        if wait_for_termination(active_procs, 4):
+            return
+
+        # Send SIGKILL if still running
+        active_procs = [p for p in active_procs if p.is_alive()]
+        for p in active_procs:
+            p.kill()
+
+    def shutdown(self):
+        """Properly shut down the executor and its workers"""
+        if (hasattr(self, 'workers') and self.workers is not None):
+            for w in self.workers:  #TODO: not sure if needed
+                w.worker_response_mq = None
+            self._ensure_worker_termination()
+
+        self.rpc_broadcast_mq = None
+
+    def check_health(self) -> None:
+        self.collective_rpc("check_health", timeout=10)
+        return
+
+
+@dataclass
+class WorkerProcHandle:
+    proc: BaseProcess
+    rank: int
+    ready_path: str
+    worker_response_mq: MessageQueue  # The worker process writes to this MQ
+
+
+class WorkerProc:
+    """Wrapper that runs one Worker in a separate process."""
+
+    READY_STR = "READY"
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        input_shm_handle: Handle,
+        ready_path: str,
+    ):
+        self.rank = rank
+        wrapper = WorkerWrapperBase(vllm_config=vllm_config)
+        wrapper.init_worker(vllm_config, local_rank, rank,
+                            distributed_init_method)
+        self.worker = wrapper.worker
+
+        pid = os.getpid()
+        _add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
+        _add_prefix(sys.stderr, f"VllmWorker rank={rank}", pid)
+
+        # Initialize MessageQueue for receiving SchedulerOutput
+        self.rpc_broadcast_mq = MessageQueue.create_from_handle(
+            input_shm_handle, self.worker.rank)
+
+        # Initializes a message queue for sending the model output
+        self.worker_response_mq = MessageQueue(1, 1)
+        worker_response_mq_handle = self.worker_response_mq.export_handle()
+
+        # Send Readiness signal to
EngineCore process.
+        with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket:
+            payload = pickle.dumps(worker_response_mq_handle,
+                                   protocol=pickle.HIGHEST_PROTOCOL)
+            ready_socket.send_string(WorkerProc.READY_STR)
+            ready_socket.send(payload)
+
+        self.worker.initialize()
+        self.worker.load_model()
+
+    @staticmethod
+    def make_worker_process(
+            vllm_config: VllmConfig,
+            local_rank: int,
+            rank: int,
+            distributed_init_method: str,
+            input_shm_handle,  # Receive SchedulerOutput
+    ) -> WorkerProcHandle:
+        context = get_mp_context()
+
+        # ZMQ path for worker to send ready message and shm_broadcast handle
+        # back to core process.
+        ready_path = get_open_zmq_ipc_path()
+
+        process_kwargs = {
+            "vllm_config": vllm_config,
+            "local_rank": local_rank,
+            "rank": rank,
+            "distributed_init_method": distributed_init_method,
+            "input_shm_handle": input_shm_handle,
+            "ready_path": ready_path,
+        }
+        # Run the Worker busy loop in a background process.
+        proc = context.Process(target=WorkerProc.worker_main,
+                               kwargs=process_kwargs,
+                               daemon=True)
+        proc.start()
+
+        # Wait for startup
+        worker_response_mq_handle = WorkerProc.wait_for_startup(
+            proc, ready_path)
+
+        worker_response_mq = MessageQueue.create_from_handle(
+            worker_response_mq_handle, 0)
+
+        return WorkerProcHandle(proc, rank, ready_path, worker_response_mq)
+
+    def shutdown(self):
+        self.rpc_broadcast_mq = None
+        self.worker_response_mq = None
+        destroy_model_parallel()
+        destroy_distributed_environment()
+
+    @staticmethod
+    def worker_main(*args, **kwargs):
+        """ Worker initialization and execution loops.
+        This runs in a background process """
+
+        # Signal handler used for graceful termination.
+        # SystemExit exception is only raised once to allow this and worker
+        # processes to terminate without error
+        shutdown_requested = False
+
+        def signal_handler(signum, frame):
+            nonlocal shutdown_requested
+            if not shutdown_requested:
+                shutdown_requested = True
+                raise SystemExit()
+
+        # Either SIGTERM or SIGINT will terminate the worker
+        signal.signal(signal.SIGTERM, signal_handler)
+        signal.signal(signal.SIGINT, signal_handler)
+
+        worker = None
+        try:
+            worker = WorkerProc(*args, **kwargs)
+
+            # Ensure message queues are ready. Will deadlock if re-ordered.
+            # Must be kept consistent with the Executor
+            worker.rpc_broadcast_mq.wait_until_ready()
+            worker.worker_response_mq.wait_until_ready()
+
+            worker.worker_busy_loop()
+
+        except SystemExit:
+            logger.debug("Worker interrupted.")
+
+        except BaseException as e:
+            logger.exception(e)
+            raise
+
+        finally:
+            # Clean up once worker exits busy loop
+            if worker is not None:
+                worker.shutdown()
+                worker = None
+
+    @staticmethod
+    def wait_for_startup(
+        proc: BaseProcess,
+        ready_path: str,
+    ) -> Optional[Handle]:
+        """Wait until the Worker is ready."""
+        with make_zmq_socket(ready_path, zmq.constants.PULL) as socket:
+
+            # Wait for Worker to send READY.
+ while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + logger.debug("Waiting for WorkerProc to startup.") + + if not proc.is_alive(): + raise RuntimeError("WorkerProc failed to start.") + + message = socket.recv_string() + assert message == WorkerProc.READY_STR + handle_frame = socket.recv(copy=False) + handle = pickle.loads(handle_frame.buffer) + return handle + + class ResponseStatus(Enum): + SUCCESS = auto() + FAILURE = auto() + + def worker_busy_loop(self): + """Main busy loop for Multiprocessing Workers""" + while True: + method, args, kwargs = self.rpc_broadcast_mq.dequeue() + + try: + output = getattr(self.worker, method)(*args, **kwargs) + except BaseException as e: + self.worker_response_mq.enqueue( + (WorkerProc.ResponseStatus.FAILURE, e)) + continue + + self.worker_response_mq.enqueue( + (WorkerProc.ResponseStatus.SUCCESS, output)) diff --git a/vllm/v1/executor/gpu_executor.py b/vllm/v1/executor/uniproc_executor.py similarity index 90% rename from vllm/v1/executor/gpu_executor.py rename to vllm/v1/executor/uniproc_executor.py index f71fa16b16e27..9b1d9a40950c6 100644 --- a/vllm/v1/executor/gpu_executor.py +++ b/vllm/v1/executor/uniproc_executor.py @@ -10,7 +10,7 @@ logger = init_logger(__name__) -class GPUExecutor: +class UniprocExecutor: def __init__(self, vllm_config: VllmConfig) -> None: self.vllm_config = vllm_config @@ -54,7 +54,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: """ return self.worker.determine_num_available_blocks() - def initialize_cache(self, num_gpu_blocks: int) -> None: + def initialize(self, num_gpu_blocks: int) -> None: """Initialize the KV cache by invoking the underlying worker. """ # NOTE: This is logged in the executor because there can be >1 worker @@ -71,7 +71,13 @@ def execute_model( output = self.worker.execute_model(scheduler_output) return output + def profile(self, is_start: bool = True): + self.worker.profile(is_start) + + def shutdown(self): + self.worker = None + def check_health(self) -> None: - # GPUExecutor will always be healthy as long as + # UniprocExecutor will always be healthy as long as # it's running. return diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 8574987728844..acc3a944e21b9 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -8,7 +8,7 @@ class SamplerOutput: # [num_reqs] - sampled_token_ids: torch.Tensor + sampled_token_ids: List[int] # [num_reqs, max_num_logprobs + 1] logprob_token_ids: Optional[torch.Tensor] @@ -20,6 +20,8 @@ class SamplerOutput: prompt_logprobs: Optional[torch.Tensor] +# ModelRunnerOutput is serialized and sent to the scheduler process. +# This is expensive for torch.Tensor so prefer to use List instead. @dataclass class ModelRunnerOutput: @@ -29,7 +31,7 @@ class ModelRunnerOutput: req_id_to_index: Dict[str, int] # [num_reqs] - sampled_token_ids_cpu: torch.Tensor + sampled_token_ids: List[int] # [num_reqs, max_num_logprobs + 1] logprob_token_ids_cpu: Optional[torch.Tensor] diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py index 927f274541c4d..d1a755be01ff7 100644 --- a/vllm/v1/sample/sampler.py +++ b/vllm/v1/sample/sampler.py @@ -37,8 +37,9 @@ def forward( topk_logprobs = None topk_indices = None + # NOTE: CPU-GPU synchronization happens here. 
sampler_output = SamplerOutput( - sampled_token_ids=sampled, + sampled_token_ids=sampled.tolist(), logprob_token_ids=topk_indices, logprobs=topk_logprobs, prompt_logprob_token_ids=None, diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 4b26749712e32..6e7a7d4fe12cd 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -1,4 +1,11 @@ -from typing import Generic, List, TypeVar, overload +from contextlib import contextmanager +from typing import Any, Generic, Iterator, List, TypeVar, overload + +import zmq + +from vllm.logger import init_logger + +logger = init_logger(__name__) T = TypeVar("T") @@ -62,3 +69,27 @@ def __contains__(self, item): def __len__(self): return len(self._x) + + +@contextmanager +def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]: + """Context manager for a ZMQ socket""" + + ctx = zmq.Context() + try: + socket = ctx.socket(type) + + if type == zmq.constants.PULL: + socket.connect(path) + elif type == zmq.constants.PUSH: + socket.bind(path) + else: + raise ValueError(f"Unknown Socket Type: {type}") + + yield socket + + except KeyboardInterrupt: + logger.debug("Worker had Keyboard Interrupt.") + + finally: + ctx.destroy(linger=0) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c601aca13feaf..0a5adfb28c9bd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -34,6 +34,7 @@ class GPUModelRunner: def __init__( self, vllm_config: VllmConfig, + device: torch.device, input_registry: InputRegistry = INPUT_REGISTRY, ): self.vllm_config = vllm_config @@ -43,7 +44,6 @@ def __init__( self.load_config = vllm_config.load_config self.parallel_config = vllm_config.parallel_config self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config self.speculative_config = vllm_config.speculative_config self.prompt_adapter_config = vllm_config.prompt_adapter_config self.observability_config = vllm_config.observability_config @@ -52,7 +52,7 @@ def __init__( cache_config = self.cache_config scheduler_config = self.scheduler_config parallel_config = self.parallel_config - self.device = self.device_config.device + self.device = device self.pin_memory = is_pin_memory_available() self.dtype = self.model_config.dtype if cache_config.cache_dtype == "auto": @@ -477,9 +477,7 @@ def execute_model( sampling_metadata=sampling_metadata, ) - # NOTE: CPU-GPU synchronization happens here. - sampled_token_ids = sampler_output.sampled_token_ids.cpu() - sampled_token_ids_list = sampled_token_ids.tolist() + sampled_token_ids = sampler_output.sampled_token_ids # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. num_reqs = self.input_batch.num_reqs @@ -490,7 +488,7 @@ def execute_model( assert seq_len <= req_state.num_tokens if seq_len == req_state.num_tokens: # Append the sampled token to the output token ids. 
- token_id = sampled_token_ids_list[i] + token_id = sampled_token_ids[i] self.input_batch.token_ids_cpu[i, seq_len] = token_id req_state.output_token_ids.append(token_id) else: @@ -512,7 +510,7 @@ def execute_model( model_runner_output = ModelRunnerOutput( req_ids=self.input_batch.req_ids[:num_reqs], req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids_cpu=sampled_token_ids, + sampled_token_ids=sampled_token_ids, logprob_token_ids_cpu=logprob_token_ids, logprobs_cpu=logprobs, ) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index d33b55a8a9f9a..d32848c3775ae 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -15,6 +15,7 @@ from vllm.model_executor import set_random_seed from vllm.platforms import current_platform from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size +from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.gpu_model_runner import GPUModelRunner @@ -56,7 +57,6 @@ def __init__( from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self.model_runner = GPUModelRunner(vllm_config) # Torch profiler. Enabled and configured through env vars: # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace if envs.VLLM_TORCH_PROFILER_DIR: @@ -103,6 +103,9 @@ def initialize(self): # Set random seed. set_random_seed(self.model_config.seed) + # Construct the model runner + self.model_runner = GPUModelRunner(self.vllm_config, self.device) + def load_model(self) -> None: self.model_runner.load_model() @@ -198,7 +201,7 @@ def execute_model( scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput: output = self.model_runner.execute_model(scheduler_output) - # TODO(woosuk): Send the output to the engine process. + return output if self.rank == 0 else None return output def profile(self, is_start=True): @@ -209,6 +212,10 @@ def profile(self, is_start=True): else: self.profiler.stop() + def check_health(self) -> None: + # worker will always be healthy as long as it's running. + return + def init_worker_distributed_environment( parallel_config: ParallelConfig,
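
The shm_broadcast changes in this patch add an optional timeout to the busy-wait in acquire_read/acquire_write, use sched_yield for tight polling, and demote the periodic warning to a debug log. Below is a minimal, self-contained sketch of that polling-with-deadline pattern, assuming the illustrative names wait_until, ready, and log_interval, which are not part of the patch; the version check mirrors the patch's USE_SCHED_YIELD logic.

# A sketch of the polling-with-timeout loop added to MessageQueue.
# `wait_until` and `ready` are placeholder names for illustration.
import os
import sys
import time
from typing import Callable, Optional

# Same fallback as the patch: os.sched_yield keeps the polling loop tight,
# but on older CPython releases it does not release the GIL.
USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
                   or (sys.version_info[:2] == (3, 10)
                       and sys.version_info[2] >= 8))


def sched_yield() -> None:
    if USE_SCHED_YIELD:
        os.sched_yield()
    else:
        time.sleep(0)


def wait_until(ready: Callable[[], bool],
               timeout: Optional[float] = None,
               log_interval: float = 60.0) -> None:
    """Spin until ready() is True; raise TimeoutError after `timeout` seconds."""
    start = time.monotonic()
    n_warning = 1
    while not ready():
        # Release the processor to other threads between checks.
        sched_yield()
        elapsed = time.monotonic() - start
        if elapsed > log_interval * n_warning:
            print(f"still waiting after {elapsed:.0f} seconds")
            n_warning += 1
        if timeout is not None and elapsed > timeout:
            raise TimeoutError


if __name__ == "__main__":
    deadline = time.monotonic() + 0.2
    wait_until(lambda: time.monotonic() > deadline, timeout=1.0)
    print("block became available")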
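The core of the new MultiprocExecutor is the collective_rpc round trip: the executor broadcasts (method, args, kwargs) to every worker over a shared-memory MessageQueue, then dequeues one (status, result) pair per worker and re-raises the worker's exception if any rank failed. The sketch below reproduces that request/response shape with plain multiprocessing.Queue objects; since a regular queue cannot broadcast the way the shared-memory ring buffer does, it uses one request queue per worker, and toy_worker plus its handler names are invented for illustration.

# A sketch of the collective_rpc request/response pattern. The real executor
# broadcasts over one shared-memory MessageQueue read by all workers; plain
# multiprocessing queues cannot broadcast, so this stand-in uses one request
# queue per worker. `toy_worker` and the handlers are invented for illustration.
import multiprocessing as mp


def toy_worker(rank, request_q, response_q):
    """Busy loop: dequeue an RPC, run it, enqueue a (status, result) pair."""
    handlers = {
        "determine_num_available_blocks": lambda: (1000 + rank, 200),
        "initialize_cache": lambda n: f"rank {rank} allocated {n} blocks",
    }
    while True:
        method, args, kwargs = request_q.get()
        if method == "shutdown":
            break
        try:
            response_q.put(("SUCCESS", handlers[method](*args, **kwargs)))
        except Exception as e:  # report failures instead of crashing the loop
            response_q.put(("FAILURE", e))


def collective_rpc(request_qs, response_qs, method, args=(), kwargs=None):
    """Broadcast one RPC and gather one response per worker, in rank order."""
    kwargs = kwargs or {}
    for q in request_qs:      # rpc_broadcast_mq.enqueue(...) in the patch
        q.put((method, args, kwargs))
    results = []
    for q in response_qs:     # worker_response_mq.dequeue(...) per worker
        status, result = q.get()
        if status != "SUCCESS":
            raise RuntimeError("Worker failed") from result
        results.append(result)
    return results


if __name__ == "__main__":
    world_size = 2
    request_qs = [mp.Queue() for _ in range(world_size)]
    response_qs = [mp.Queue() for _ in range(world_size)]
    procs = [
        mp.Process(target=toy_worker, args=(r, request_qs[r], response_qs[r]))
        for r in range(world_size)
    ]
    for p in procs:
        p.start()

    blocks = collective_rpc(request_qs, response_qs,
                            "determine_num_available_blocks")
    # Mirror the executor: take the minimum across all workers.
    num_gpu_blocks = min(b[0] for b in blocks)
    print(collective_rpc(request_qs, response_qs, "initialize_cache",
                         args=(num_gpu_blocks, )))

    for q in request_qs:
        q.put(("shutdown", (), {}))
    for p in procs:
        p.join()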
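Both run_engine_core and WorkerProc.worker_main rely on the same shutdown discipline: SIGTERM/SIGINT raise SystemExit exactly once, the busy loop unwinds quietly, and cleanup runs in a finally block. Here is a stripped-down sketch of that pattern, assuming a placeholder run_busy_loop with a dummy resource in place of the real EngineCoreProc/WorkerProc objects.

# A sketch of the graceful-shutdown pattern, with a dummy resource and sleep
# loop standing in for the real EngineCoreProc / WorkerProc busy loops.
import signal
import time


def run_busy_loop():
    shutdown_requested = False

    def signal_handler(signum, frame):
        nonlocal shutdown_requested
        if not shutdown_requested:
            shutdown_requested = True
            raise SystemExit()

    # Either SIGTERM or SIGINT terminates the loop, but only the first
    # signal raises; repeated signals are ignored while unwinding.
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    resource = None
    try:
        resource = object()  # stands in for EngineCoreProc(...) / WorkerProc(...)
        while True:
            time.sleep(0.1)  # stands in for dequeue/step work
    except SystemExit:
        print("interrupted, shutting down")
    finally:
        # Cleanup always runs, mirroring engine_core.shutdown() / worker.shutdown()
        if resource is not None:
            resource = None


if __name__ == "__main__":
    run_busy_loop()  # Ctrl-C exits cleanly instead of dumping a traceback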