diff --git a/python/sglang/backend/openai.py b/python/sglang/backend/openai.py
index 6cad2f6aa1..f2dd2f0678 100644
--- a/python/sglang/backend/openai.py
+++ b/python/sglang/backend/openai.py
@@ -9,8 +9,9 @@
 from sglang.lang.ir import SglSamplingParams
 
 try:
-    import openai
     import tiktoken
+
+    import openai
 except ImportError as e:
     openai = tiktoken = e
 
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index 6468c2d5ff..04fd5ffaf7 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -6,7 +6,6 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import List
 
-import numpy as np
 import rpyc
 import torch
 from rpyc.utils.classic import obtain
@@ -36,8 +35,8 @@
 logger = logging.getLogger("model_rpc")
 
 
-class ModelRpcServer(rpyc.Service):
-    def exposed_init_model(
+class ModelRpcServer:
+    def __init__(
         self,
         tp_rank: int,
         server_args: ServerArgs,
@@ -608,14 +607,19 @@ def handle_finished_requests(self, batch: Batch):
         batch.reqs = []
 
 
+class ModelRpcService(rpyc.Service):
+    exposed_ModelRpcServer = ModelRpcServer
+
+
 class ModelRpcClient:
     def __init__(self, server_args: ServerArgs, port_args: PortArgs):
         tp_size = server_args.tp_size
 
         if tp_size == 1:
             # Init model
-            self.model_server = ModelRpcServer()
-            self.model_server.exposed_init_model(0, server_args, port_args)
+            self.model_server = ModelRpcService().exposed_ModelRpcServer(
+                0, server_args, port_args
+            )
 
             # Wrap functions
             def async_wrap(f):
@@ -629,14 +633,16 @@ async def _func(*args, **kwargs):
             with ThreadPoolExecutor(tp_size) as executor:
                 # Launch model processes
                 rets = executor.map(start_model_process, port_args.model_rpc_ports)
-                self.model_servers = [x[0] for x in rets]
+                self.remote_services = [x[0] for x in rets]
                 self.procs = [x[1] for x in rets]
 
                 # Init model
                 def init_model(i):
-                    return self.model_servers[i].init_model(i, server_args, port_args)
+                    return self.remote_services[i].ModelRpcServer(
+                        i, server_args, port_args
+                    )
 
-                rets = [obtain(x) for x in executor.map(init_model, range(tp_size))]
+                self.model_servers = executor.map(init_model, range(tp_size))
 
                 # Wrap functions
                 def async_wrap(func_name):
@@ -654,7 +660,7 @@ async def _func(*args, **kwargs):
 
 def _init_service(port):
     t = ThreadedServer(
-        ModelRpcServer(),
+        ModelRpcService(),
         port=port,
         protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
     )
diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index afbe03abe6..f6d9adc3fb 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -1,10 +1,10 @@
 import importlib
-import logging
+import importlib.resources
 import inspect
+import logging
+import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from pathlib import Path
-import importlib.resources
 
 import numpy as np
 import torch
@@ -18,11 +18,6 @@
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
-import importlib
-import pkgutil
-
-import sglang
-
 QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
 
 logger = logging.getLogger("model_runner")
@@ -37,7 +32,7 @@ def import_model_classes():
     model_arch_name_to_cls = {}
     package_name = "sglang.srt.models"
     package = importlib.import_module(package_name)
-    for finder, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + '.'):
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
         if not ispkg:
             module = importlib.import_module(name)
             if hasattr(module, "EntryClass"):
@@ -144,9 +139,12 @@ def init_flashinfer_args(self, tp_size):
 
             # flashinfer >= 0.0.3
             # FIXME: Drop this when flashinfer updates to 0.0.4
-            if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
+            if (
+                len(inspect.signature(self.prefill_wrapper.begin_forward).parameters)
+                == 7
+            ):
                 args.append(self.model_runner.model_config.head_dim)
-                
+
             self.prefill_wrapper.begin_forward(*args)
         else:
             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
@@ -307,9 +305,11 @@ def load_model(self):
             hf_quant_method = hf_quant_config["quant_method"]
 
             # compat: autogptq uses is_marlin_format within quant config
-            if (hf_quant_method == "gptq"
-                and "is_marlin_format" in hf_quant_config
-                and hf_quant_config["is_marlin_format"]):
+            if (
+                hf_quant_method == "gptq"
+                and "is_marlin_format" in hf_quant_config
+                and hf_quant_config["is_marlin_format"]
+            ):
                 hf_quant_method = "marlin"
 
             quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)