Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamic model class loading #101

Merged
merged 2 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
[project.optional-dependencies]
srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
"zmq", "vllm>=0.2.5", "interegular", "lark", "numba",
"pydantic", "diskcache", "cloudpickle"]
"pydantic", "diskcache", "cloudpickle", "pillow"]
openai = ["openai>=1.0", "numpy"]
anthropic = ["anthropic", "numpy"]
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
Expand Down
58 changes: 31 additions & 27 deletions python/sglang/srt/managers/router/model_runner.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import importlib
import logging
from dataclasses import dataclass
from enum import Enum, auto
from functools import lru_cache
from pathlib import Path
from typing import List

import numpy as np
import torch
import sglang
from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.utils import is_multimodal_model
Expand All @@ -20,6 +23,32 @@
global_model_mode: List[str] = []


@lru_cache()
def import_model_classes():
    """Discover all model classes shipped under ``sglang/srt/models``.

    Every ``*.py`` module in that directory is imported; modules that export
    an ``EntryClass`` attribute are registered under that class's name.

    Returns:
        dict: architecture class name -> model class.

    The result is memoized with ``lru_cache`` so the directory scan and the
    imports run at most once per process.
    """
    registry = {}
    models_dir = Path(sglang.__file__).parent / "srt" / "models"
    for candidate in models_dir.glob("*.py"):
        mod = importlib.import_module(f"sglang.srt.models.{candidate.stem}")
        entry = getattr(mod, "EntryClass", None)
        if entry is not None:
            registry[entry.__name__] = entry
    return registry


def get_model_cls_by_arch_name(model_arch_names):
    """Resolve a HuggingFace architecture name to a registered model class.

    Args:
        model_arch_names: iterable of architecture names (typically
            ``hf_config.architectures``), tried in order.

    Returns:
        The model class registered for the first supported name.

    Raises:
        ValueError: if no name is supported, including the empty-input case.
    """
    model_arch_name_to_cls = import_model_classes()

    # Return on the first architecture we know how to build.
    for arch in model_arch_names:
        if arch in model_arch_name_to_cls:
            return model_arch_name_to_cls[arch]

    # Fix: the original raised inside the loop's else-branch with a message
    # referencing the loop variable `arch`, which is unbound when
    # model_arch_names is empty (NameError instead of ValueError) and which
    # named only the last architecture checked. Report the full input instead.
    raise ValueError(
        f"Unsupported architectures: {list(model_arch_names)}. "
        f"Supported list: {list(model_arch_name_to_cls.keys())}"
    )


@dataclass
class InputMetadata:
model_runner: "ModelRunner"
Expand Down Expand Up @@ -237,34 +266,9 @@ def __init__(

def load_model(self):
"""See also vllm/model_executor/model_loader.py::get_model"""
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.models.llava import LlavaLlamaForCausalLM
from sglang.srt.models.mixtral import MixtralForCausalLM
from sglang.srt.models.qwen import QWenLMHeadModel

# Select model class
architectures = getattr(self.model_config.hf_config, "architectures", [])

model_class = None
for arch in architectures:
if arch == "LlamaForCausalLM":
model_class = LlamaForCausalLM
break
if arch == "MistralForCausalLM":
model_class = LlamaForCausalLM
break
if arch == "LlavaLlamaForCausalLM":
model_class = LlavaLlamaForCausalLM
break
if arch == "MixtralForCausalLM":
model_class = MixtralForCausalLM
break
if arch == "QWenLMHeadModel":
model_class = QWenLMHeadModel
break
if model_class is None:
raise ValueError(f"Unsupported architectures: {architectures}")

model_class = get_model_cls_by_arch_name(architectures)
logger.info(f"Rank {self.tp_rank}: load weight begin.")

# Load weights
Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/models/llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,5 @@ def load_weights(
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

EntryClass = LlamaForCausalLM
2 changes: 2 additions & 0 deletions python/sglang/srt/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,3 +330,5 @@ def monkey_path_clip_vision_embed_forward():
"forward",
clip_vision_embed_forward,
)

EntryClass = LlavaLlamaForCausalLM
2 changes: 2 additions & 0 deletions python/sglang/srt/models/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,3 +376,5 @@ def load_weights(
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

EntryClass = MixtralForCausalLM
2 changes: 2 additions & 0 deletions python/sglang/srt/models/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,5 @@ def load_weights(
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)

EntryClass = QWenLMHeadModel