Skip to content

Commit

Permalink
NvmlCudaPlatform: remove some internal methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Conroy Cheers committed Nov 26, 2024
1 parent 7b024d6 commit 6644c52
Showing 1 changed file with 25 additions and 33 deletions.
58 changes: 25 additions & 33 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import os
from functools import lru_cache, wraps
from typing import TYPE_CHECKING, Callable, List, Tuple, TypeVar
from typing import TYPE_CHECKING, Callable, List, TypeVar

import pynvml
import torch
Expand Down Expand Up @@ -57,6 +57,19 @@ def device_id_to_physical_device_id(device_id: int) -> int:
return device_id


def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:

@wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
pynvml.nvmlInit()
try:
return fn(*args, **kwargs)
finally:
pynvml.nvmlShutdown()

return wrapper


class CudaPlatformBase(Platform):
_enum = PlatformEnum.CUDA
device_type: str = "cuda"
Expand Down Expand Up @@ -103,36 +116,32 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# the major benefit of using NVML is that it will not initialize CUDA
class NvmlCudaPlatform(CudaPlatformBase):

@staticmethod
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:

@wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
pynvml.nvmlInit()
try:
return fn(*args, **kwargs)
finally:
pynvml.nvmlShutdown()

return wrapper

@classmethod
@lru_cache(maxsize=8)
@with_nvml_context
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
physical_device_id = device_id_to_physical_device_id(device_id)
major, minor = cls._get_physical_device_capability(physical_device_id)
handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
return DeviceCapability(major=major, minor=minor)

@classmethod
@lru_cache(maxsize=8)
@with_nvml_context
def get_device_name(cls, device_id: int = 0) -> str:
physical_device_id = device_id_to_physical_device_id(device_id)
return cls._get_physical_device_name(physical_device_id)

@classmethod
@lru_cache(maxsize=8)
@with_nvml_context
def get_device_total_memory(cls, device_id: int = 0) -> int:
physical_device_id = device_id_to_physical_device_id(device_id)
return cls._get_physical_device_total_memory(physical_device_id)
handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

@classmethod
@with_nvml_context
def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
"""
query if the set of gpus are fully connected by nvlink (1 hop)
Expand All @@ -159,27 +168,10 @@ def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
return True

@classmethod
@lru_cache(maxsize=8)
@with_nvml_context
def _get_physical_device_capability(cls,
device_id: int = 0) -> Tuple[int, int]:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
return pynvml.nvmlDeviceGetCudaComputeCapability(handle)

@classmethod
@lru_cache(maxsize=8)
@with_nvml_context
def _get_physical_device_name(cls, device_id: int = 0) -> str:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
return pynvml.nvmlDeviceGetName(handle)

@classmethod
@lru_cache(maxsize=8)
@with_nvml_context
def _get_physical_device_total_memory(cls, device_id: int = 0) -> int:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)

@classmethod
@with_nvml_context
def log_warnings(cls):
Expand Down

0 comments on commit 6644c52

Please sign in to comment.