From 6644c521af5aeceaefdbecd7dd4c9b22560741e1 Mon Sep 17 00:00:00 2001
From: Conroy Cheers
Date: Tue, 26 Nov 2024 16:46:11 +1100
Subject: [PATCH] NvmlCudaPlatform: remove some internal methods

---
 vllm/platforms/cuda.py | 58 ++++++++++++++++++------------------------
 1 file changed, 25 insertions(+), 33 deletions(-)

diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index 3e75031fb8bd1..0d07050fd1b6a 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -4,7 +4,7 @@
 
 import os
 from functools import lru_cache, wraps
-from typing import TYPE_CHECKING, Callable, List, Tuple, TypeVar
+from typing import TYPE_CHECKING, Callable, List, TypeVar
 
 import pynvml
 import torch
@@ -57,6 +57,19 @@ def device_id_to_physical_device_id(device_id: int) -> int:
     return device_id
 
 
+def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
+
+    @wraps(fn)
+    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
+        pynvml.nvmlInit()
+        try:
+            return fn(*args, **kwargs)
+        finally:
+            pynvml.nvmlShutdown()
+
+    return wrapper
+
+
 class CudaPlatformBase(Platform):
     _enum = PlatformEnum.CUDA
     device_type: str = "cuda"
@@ -103,36 +116,32 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
 # the major benefit of using NVML is that it will not initialize CUDA
 class NvmlCudaPlatform(CudaPlatformBase):
 
-    @staticmethod
-    def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
-
-        @wraps(fn)
-        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
-            pynvml.nvmlInit()
-            try:
-                return fn(*args, **kwargs)
-            finally:
-                pynvml.nvmlShutdown()
-
-        return wrapper
-
     @classmethod
+    @lru_cache(maxsize=8)
+    @with_nvml_context
     def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
         physical_device_id = device_id_to_physical_device_id(device_id)
-        major, minor = cls._get_physical_device_capability(physical_device_id)
+        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
+        major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
         return DeviceCapability(major=major, minor=minor)
 
     @classmethod
+    @lru_cache(maxsize=8)
+    @with_nvml_context
     def get_device_name(cls, device_id: int = 0) -> str:
         physical_device_id = device_id_to_physical_device_id(device_id)
         return cls._get_physical_device_name(physical_device_id)
 
     @classmethod
+    @lru_cache(maxsize=8)
+    @with_nvml_context
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         physical_device_id = device_id_to_physical_device_id(device_id)
-        return cls._get_physical_device_total_memory(physical_device_id)
+        handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
+        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)
 
     @classmethod
+    @with_nvml_context
     def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
         """
         query if the set of gpus are fully connected by nvlink (1 hop)
@@ -159,27 +168,10 @@ def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
         return True
 
     @classmethod
-    @lru_cache(maxsize=8)
-    @with_nvml_context
-    def _get_physical_device_capability(cls,
-                                        device_id: int = 0) -> Tuple[int, int]:
-        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
-        return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
-
-    @classmethod
-    @lru_cache(maxsize=8)
-    @with_nvml_context
     def _get_physical_device_name(cls, device_id: int = 0) -> str:
         handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
         return pynvml.nvmlDeviceGetName(handle)
 
-    @classmethod
-    @lru_cache(maxsize=8)
-    @with_nvml_context
-    def _get_physical_device_total_memory(cls, device_id: int = 0) -> int:
-        handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
-        return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)
-
     @classmethod
     @with_nvml_context
     def log_warnings(cls):
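
Note (illustration only, not part of the patch): a minimal standalone sketch of
the decorator stacking this change relies on. `classmethod` accepts any
callable, so `lru_cache` memoizes on the (cls, device_id) argument tuple and
`with_nvml_context` only initializes NVML on a cache miss. The `_Demo` class
and `device_name` method are hypothetical names, not vLLM API; running it
assumes pynvml (nvidia-ml-py), typing_extensions, and at least one visible
NVIDIA GPU.

    # Hypothetical sketch; assumes pynvml and an NVIDIA GPU are available.
    from functools import lru_cache, wraps
    from typing import Callable, TypeVar

    import pynvml
    from typing_extensions import ParamSpec

    _P = ParamSpec("_P")
    _R = TypeVar("_R")


    def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
        """Initialize NVML around the call and always shut it down."""

        @wraps(fn)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
            pynvml.nvmlInit()
            try:
                return fn(*args, **kwargs)
            finally:
                pynvml.nvmlShutdown()

        return wrapper


    class _Demo:

        @classmethod
        @lru_cache(maxsize=8)      # memoizes on the (cls, device_id) tuple
        @with_nvml_context         # NVML init/shutdown runs only on a cache miss
        def device_name(cls, device_id: int = 0) -> str:
            handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
            return pynvml.nvmlDeviceGetName(handle)


    print(_Demo.device_name(0))  # first call queries NVML
    print(_Demo.device_name(0))  # second call is served from the cache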