adapt vllm distributed module to sglang #2244

Merged
merged 14 commits on Dec 1, 2024
Changes from 7 commits
120 changes: 120 additions & 0 deletions python/sglang/srt/_custom_ops.py
@@ -0,0 +1,120 @@
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/_custom_ops.py
import contextlib
import functools
import importlib
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch
import torch.library

from sglang.srt.utils import is_hpu

# from vllm.scalar_type import ScalarType

logger = logging.getLogger(__name__)

if not is_hpu():
try:
import custom_ar
except ImportError as e:
logger.warning("Failed to import from custom_ar with %r", e)


def hint_on_error(fn):

@functools.wraps(fn)
def wrapper(*args, **kwargs):
try:
return fn(*args, **kwargs)

except NotImplementedError as e:
msg = (
"Error in calling custom op %s: %s\n"
"Not implemented or built, mostly likely because the current current device "
"does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set "
"incorrectly while building)"
)
logger.error(msg, fn.__name__, e)
raise NotImplementedError(msg % (fn.__name__, e)) from e
except AttributeError as e:
msg = (
"Error in calling custom op %s: %s\n"
"Possibly you have built or installed an obsolete version of vllm.\n"
"Please try a clean build and install of vllm,"
"or remove old built files such as vllm/*cpython*.so and build/ ."
)
logger.error(msg, fn.__name__, e)
raise e

return wrapper


# custom ar
def init_custom_ar(
ipc_tensors: List[torch.Tensor],
rank_data: torch.Tensor,
rank: int,
full_nvlink: bool,
) -> int:
return torch.ops._C_vllm_ar.init_custom_ar(
ipc_tensors, rank_data, rank, full_nvlink
)


def all_reduce(
fa: int,
inp: torch.Tensor,
out: torch.Tensor,
reg_buffer: int,
reg_buffer_sz_bytes: int,
) -> None:
torch.ops._C_vllm_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)


def dispose(fa: int) -> None:
torch.ops._C_vllm_ar.dispose(fa)


def meta_size() -> int:
return torch.ops._C_vllm_ar.meta_size()


def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
return torch.ops._C_vllm_ar.register_buffer(fa, ipc_tensors)


def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
return torch.ops._C_vllm_ar.get_graph_buffer_ipc_meta(fa)


def register_graph_buffers(
fa: int, handles: List[List[int]], offsets: List[List[int]]
) -> None:
torch.ops._C_vllm_ar.register_graph_buffers(fa, handles, offsets)


# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
names_and_values = globals()
names_and_values_to_update = {}
# prepare variables to avoid dict size change during iteration
k, v, arg = None, None, None
fn_type = type(lambda x: x)
for k, v in names_and_values.items():
# find functions that are defined in this file and have torch.Tensor
# in their annotations. `arg == "torch.Tensor"` is used to handle
# the case when users use `from __future__ import annotations` to turn
# type hints into strings.
if (
isinstance(v, fn_type)
and v.__code__.co_filename == __file__
and any(
arg is torch.Tensor or arg == "torch.Tensor"
for arg in v.__annotations__.values()
)
):
names_and_values_to_update[k] = hint_on_error(v)

names_and_values.update(names_and_values_to_update)
del names_and_values_to_update, names_and_values, v, k, fn_type
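
Note on the block above: it rewrites the module's own globals() so that every function defined in this file whose annotations mention torch.Tensor gets wrapped by hint_on_error. A minimal, self-contained sketch of the same pattern follows; the names (fake_op, _globals, _updates) are illustrative only and not part of this PR.

# Illustrative sketch of the globals()-rewriting pattern used above.
import functools

import torch


def hint_on_error(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except NotImplementedError as e:
            # Re-raise with a hint that names the failing op.
            raise NotImplementedError(
                f"custom op {fn.__name__} is not built for this device: {e}"
            ) from e

    return wrapper


def fake_op(x: torch.Tensor) -> torch.Tensor:  # hypothetical example op
    raise NotImplementedError("kernel not compiled")


# Collect replacements first so the dict is not mutated while iterating.
_globals = globals()
_updates = {
    name: hint_on_error(obj)
    for name, obj in _globals.items()
    if callable(obj)
    and getattr(obj, "__module__", None) == __name__
    and any(
        ann is torch.Tensor or ann == "torch.Tensor"
        for ann in getattr(obj, "__annotations__", {}).values()
    )
}
_globals.update(_updates)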
3 changes: 3 additions & 0 deletions python/sglang/srt/distributed/__init__.py
@@ -0,0 +1,3 @@
from .communication_op import *
from .parallel_state import *
from .utils import *
34 changes: 34 additions & 0 deletions python/sglang/srt/distributed/communication_op.py
@@ -0,0 +1,34 @@
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/communication_op.py
from typing import Any, Dict, Optional, Union

import torch
import torch.distributed

from .parallel_state import get_tp_group


def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_)


def tensor_model_parallel_all_gather(
input_: torch.Tensor, dim: int = -1
) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
return get_tp_group().all_gather(input_, dim)


def tensor_model_parallel_gather(
input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
"""Gather the input tensor across model parallel group."""
return get_tp_group().gather(input_, dst, dim)


def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0
):
if not torch.distributed.is_initialized():
return tensor_dict
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
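
These wrappers only delegate to the tensor-parallel group, so they assume the process groups from parallel_state have already been initialized by the caller (e.g. the model runner). A hedged usage sketch under that assumption; shard_forward and gather_logits are hypothetical helper names, not part of this file.

# Hypothetical usage sketch; assumes torch.distributed and the model-parallel
# groups have already been initialized elsewhere.
import torch

from sglang.srt.distributed.communication_op import (
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)


def shard_forward(partial_out: torch.Tensor) -> torch.Tensor:
    # Each TP rank holds a partial sum of the layer output; the all-reduce
    # combines them into the full activation on every rank.
    return tensor_model_parallel_all_reduce(partial_out)


def gather_logits(local_logits: torch.Tensor) -> torch.Tensor:
    # Each TP rank holds a vocabulary shard of the logits; the all-gather
    # concatenates the shards along the last dimension.
    return tensor_model_parallel_all_gather(local_logits, dim=-1)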
Empty file.
182 changes: 182 additions & 0 deletions python/sglang/srt/distributed/device_communicators/cuda_wrapper.py
@@ -0,0 +1,182 @@
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/cuda_wrapper.py
"""This file is a pure Python wrapper for the cudart library.
It avoids the need to compile a separate shared library, and is
convenient for use when we just need to call a few functions.
"""

import ctypes
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# this line makes it possible to directly load `libcudart.so` using `ctypes`
import torch # noqa

logger = logging.getLogger(__name__)

# === export types and functions from cudart to Python ===
# for the original cudart definition, please check
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html

cudaError_t = ctypes.c_int
cudaMemcpyKind = ctypes.c_int


class cudaIpcMemHandle_t(ctypes.Structure):
_fields_ = [("internal", ctypes.c_byte * 128)]


@dataclass
class Function:
name: str
restype: Any
argtypes: List[Any]


def find_loaded_library(lib_name) -> Optional[str]:
"""
According to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html,
the file `/proc/self/maps` contains the memory maps of the process, which includes the
shared libraries loaded by the process. We can use this file to find the path of
a loaded library.
""" # noqa
found = False
with open("/proc/self/maps") as f:
for line in f:
if lib_name in line:
found = True
break
if not found:
# the library is not loaded in the current process
return None
# if lib_name is libcudart, we need to match a line with:
# address /path/to/libcudart-hash.so.11.0
start = line.index("/")
path = line[start:].strip()
filename = path.split("/")[-1]
assert filename.rpartition(".so")[0].startswith(
lib_name
), f"Unexpected filename: {filename} for library {lib_name}"
return path


class CudaRTLibrary:
exported_functions = [
# ​cudaError_t cudaSetDevice ( int device )
Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
# cudaError_t cudaDeviceSynchronize ( void )
Function("cudaDeviceSynchronize", cudaError_t, []),
# ​cudaError_t cudaDeviceReset ( void )
Function("cudaDeviceReset", cudaError_t, []),
# const char* cudaGetErrorString ( cudaError_t error )
Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
# ​cudaError_t cudaMalloc ( void** devPtr, size_t size )
Function(
"cudaMalloc",
cudaError_t,
[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t],
),
# ​cudaError_t cudaFree ( void* devPtr )
Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
# ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
Function(
"cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]
),
# ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
Function(
"cudaMemcpy",
cudaError_t,
[ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind],
),
# cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
Function(
"cudaIpcGetMemHandle",
cudaError_t,
[ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p],
),
# ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
Function(
"cudaIpcOpenMemHandle",
cudaError_t,
[ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint],
),
]

# class attribute to store the mapping from the path to the library
# to avoid loading the same library multiple times
path_to_library_cache: Dict[str, Any] = {}

# class attribute to store the mapping from library path
# to the corresponding dictionary
path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}

def __init__(self, so_file: Optional[str] = None):
if so_file is None:
so_file = find_loaded_library("libcudart")
assert so_file is not None, "libcudart is not loaded in the current process"
if so_file not in CudaRTLibrary.path_to_library_cache:
lib = ctypes.CDLL(so_file)
CudaRTLibrary.path_to_library_cache[so_file] = lib
self.lib = CudaRTLibrary.path_to_library_cache[so_file]

if so_file not in CudaRTLibrary.path_to_dict_mapping:
_funcs = {}
for func in CudaRTLibrary.exported_functions:
f = getattr(self.lib, func.name)
f.restype = func.restype
f.argtypes = func.argtypes
_funcs[func.name] = f
CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]

def CUDART_CHECK(self, result: cudaError_t) -> None:
if result != 0:
error_str = self.cudaGetErrorString(result)
raise RuntimeError(f"CUDART error: {error_str}")

def cudaGetErrorString(self, error: cudaError_t) -> str:
return self.funcs["cudaGetErrorString"](error).decode("utf-8")

def cudaSetDevice(self, device: int) -> None:
self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))

def cudaDeviceSynchronize(self) -> None:
self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())

def cudaDeviceReset(self) -> None:
self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())

def cudaMalloc(self, size: int) -> ctypes.c_void_p:
devPtr = ctypes.c_void_p()
self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
return devPtr

def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))

def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, count: int) -> None:
self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))

def cudaMemcpy(
self, dst: ctypes.c_void_p, src: ctypes.c_void_p, count: int
) -> None:
cudaMemcpyDefault = 4
kind = cudaMemcpyDefault
self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))

def cudaIpcGetMemHandle(self, devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
handle = cudaIpcMemHandle_t()
self.CUDART_CHECK(
self.funcs["cudaIpcGetMemHandle"](ctypes.byref(handle), devPtr)
)
return handle

def cudaIpcOpenMemHandle(self, handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
cudaIpcMemLazyEnablePeerAccess = 1
devPtr = ctypes.c_void_p()
self.CUDART_CHECK(
self.funcs["cudaIpcOpenMemHandle"](
ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess
)
)
return devPtr
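
As the module docstring says, the point of this wrapper is to call a handful of cudart functions without compiling anything extra. A brief single-process usage sketch, illustrative only (in practice the IPC handle would be serialized and opened by a peer process via cudaIpcOpenMemHandle):

# Illustrative sketch, not part of the PR: allocate, zero, and export a device
# buffer via the ctypes wrapper. Requires a CUDA-capable process that has
# already loaded libcudart (importing torch is enough).
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

cudart = CudaRTLibrary()
cudart.cudaSetDevice(0)

buf = cudart.cudaMalloc(1 << 20)          # 1 MiB device buffer
cudart.cudaMemset(buf, 0, 1 << 20)        # zero it
handle = cudart.cudaIpcGetMemHandle(buf)  # 128-byte exportable IPC handle
# A peer process would map the buffer with cudaIpcOpenMemHandle(handle).
cudart.cudaDeviceSynchronize()
cudart.cudaFree(buf)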