Use model loader from vllm #459

Merged: 9 commits (May 21, 2024)
2 changes: 2 additions & 0 deletions examples/quick_start/srt_example_yi_vl.py
@@ -1,5 +1,7 @@
"""
Usage: python3 srt_example_yi_vl.py

Requirements: transformers==4.38
"""
import sglang as sgl

1 change: 1 addition & 0 deletions python/sglang/srt/managers/router/model_rpc.py
@@ -41,6 +41,7 @@
logger = logging.getLogger("model_rpc")
vllm_default_logger.setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)
logging.getLogger("vllm.selector").setLevel(logging.WARN)


class ModelRpcServer:
138 changes: 57 additions & 81 deletions python/sglang/srt/managers/router/model_runner.py
@@ -1,65 +1,31 @@
import importlib
import importlib.resources
import inspect
import logging
import pkgutil
from dataclasses import dataclass
from functools import lru_cache
from typing import List
from typing import List, Optional, Type

import numpy as np
import torch
import torch.nn as nn
from vllm.config import DeviceConfig, LoadConfig
from vllm.config import ModelConfig as VllmModelConfig
from vllm.distributed import initialize_model_parallel
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry

from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model

QUANTIZATION_CONFIG_MAPPING = {
"awq": AWQConfig,
"gptq": GPTQConfig,
"marlin": MarlinConfig,
}

logger = logging.getLogger("model_runner")

# for server args in model endpoints
global_server_args_dict = {}


@lru_cache()
def import_model_classes():
model_arch_name_to_cls = {}
package_name = "sglang.srt.models"
package = importlib.import_module(package_name)
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
if not ispkg:
module = importlib.import_module(name)
if hasattr(module, "EntryClass"):
model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
return model_arch_name_to_cls


def get_model_cls_by_arch_name(model_arch_names):
model_arch_name_to_cls = import_model_classes()

model_class = None
for arch in model_arch_names:
if arch in model_arch_name_to_cls:
model_class = model_arch_name_to_cls[arch]
break
else:
raise ValueError(
f"Unsupported architectures: {arch}. "
f"Supported list: {list(model_arch_name_to_cls.keys())}"
)
return model_class


@dataclass
class InputMetadata:
model_runner: "ModelRunner"
@@ -287,49 +253,32 @@ def __init__(
self.is_multimodal_model = is_multimodal_model(self.model_config)

def load_model(self):
"""See also vllm/model_executor/model_loader.py::get_model"""
# Select model class
architectures = getattr(self.model_config.hf_config, "architectures", [])
model_class = get_model_cls_by_arch_name(architectures)
logger.info(f"Rank {self.tp_rank}: load weight begin.")

# Load weights
quant_config = None

quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
if quant_cfg is not None:
quant_method = quant_cfg.get("quant_method", "").lower()
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_marlin_format: bool
is_format_marlin = quant_cfg.get(
"checkpoint_format"
) == "marlin" or quant_cfg.get("is_marlin_format", False)

# Use marlin if the GPTQ model is serialized in marlin format.
if quant_method == "gptq" and is_format_marlin:
quant_method = "marlin"

quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)

if quant_config_class is None:
raise ValueError(f"Unsupported quantization method: {quant_method}")

quant_config = quant_config_class.from_config(quant_cfg)
logger.info(f"quant_config: {quant_config}")

with set_default_torch_dtype(torch.float16):
with torch.device("cuda"):
model = model_class(
config=self.model_config.hf_config, quant_config=quant_config
)
model.load_weights(
self.model_config.path,
cache_dir=None,
load_format=self.load_format,
revision=None,
)
self.model = model.eval()

device_config = DeviceConfig()
load_config = LoadConfig()
vllm_model_config = VllmModelConfig(
model=self.model_config.path,
tokenizer=None,
tokenizer_mode=None,
trust_remote_code=self.model_config.trust_remote_code,
dtype=torch.float16,
seed=42,
revision=self.model_config.revision,
skip_tokenizer_init=True,
)
if self.model_config.model_overide_args is not None:
vllm_model_config.hf_config.update(self.model_config.model_overide_args)

self.model = get_model(
model_config=vllm_model_config,
device_config=device_config,
load_config=load_config,
lora_config=None,
vision_language_config=None,
parallel_config=None,
scheduler_config=None,
)
logger.info(f"Rank {self.tp_rank}: load weight end.")

def profile_max_num_token(self, total_gpu_memory):
@@ -455,3 +404,30 @@ def forward(self, batch: Batch, forward_mode: ForwardMode):
return self.forward_prefill(batch)
else:
raise ValueError(f"Invaid forward mode: {forward_mode}")


@lru_cache()
def import_model_classes():
model_arch_name_to_cls = {}
package_name = "sglang.srt.models"
package = importlib.import_module(package_name)
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
if not ispkg:
module = importlib.import_module(name)
if hasattr(module, "EntryClass"):
model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
return model_arch_name_to_cls


def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
model_arch_name_to_cls = import_model_classes()
if model_arch not in model_arch_name_to_cls:
raise ValueError(
f"Unsupported architectures: {model_arch}. "
f"Supported list: {list(model_arch_name_to_cls.keys())}"
)
return model_arch_name_to_cls[model_arch]


# Monkey patch model loader
setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
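
Taken together, the model_runner.py changes delegate weight loading to vLLM's `get_model` (driven by the `VllmModelConfig`, `DeviceConfig`, and `LoadConfig` built above), while architecture resolution is redirected to sglang's own model definitions by monkey-patching `ModelRegistry.load_model_cls`. A minimal sketch of exercising that hook, assuming this branch is installed and that `LlamaForCausalLM` is one of the `EntryClass` models under `sglang.srt.models`:

```python
# Sketch only: check that vLLM's registry resolves architecture names to
# sglang's model classes once the monkey patch has been applied.
from vllm.model_executor.models import ModelRegistry

# Importing the runner module runs the setattr(...) at the bottom of the file.
import sglang.srt.managers.router.model_runner  # noqa: F401

llama_cls = ModelRegistry.load_model_cls("LlamaForCausalLM")
print(llama_cls.__module__)  # expected to point into sglang.srt.models
```

Because `get_model` looks the class up through `ModelRegistry`, every architecture discovered by `import_model_classes()` can now be loaded through vLLM's loader, including its quantization and load-format handling, without duplicating that logic here.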
9 changes: 4 additions & 5 deletions python/sglang/srt/model_config.py
@@ -15,10 +15,9 @@ def __init__(
self.path = path
self.trust_remote_code = trust_remote_code
self.revision = revision
self.hf_config = get_config(self.path, trust_remote_code, revision)

if model_overide_args is not None:
self.hf_config.update(model_overide_args)
self.model_overide_args = model_overide_args
self.hf_config = get_config(self.path, trust_remote_code, revision,
model_overide_args=model_overide_args)

if context_length is not None:
self.context_len = context_length
@@ -44,4 +43,4 @@ def __init__(
self.num_key_value_heads = self.num_attention_heads
self.hidden_size = self.hf_config.hidden_size
self.num_hidden_layers = self.hf_config.num_hidden_layers
self.vocab_size = self.hf_config.vocab_size
self.vocab_size = self.hf_config.vocab_size
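
For context, passing `model_overide_args` into `get_config` means the overrides are applied to the Hugging Face config as soon as it is loaded (the runner applies the same overrides to `VllmModelConfig.hf_config`, as shown above). A rough sketch of that behavior with a hypothetical helper, not sglang's actual `get_config` implementation:

```python
from typing import Optional

from transformers import AutoConfig, PretrainedConfig


def get_config_sketch(
    model_path: str,
    trust_remote_code: bool,
    revision: Optional[str] = None,
    model_overide_args: Optional[dict] = None,
) -> PretrainedConfig:
    # Load the HF config, then overlay caller-supplied overrides
    # (e.g. a longer context length) before anything reads it.
    config = AutoConfig.from_pretrained(
        model_path, trust_remote_code=trust_remote_code, revision=revision
    )
    if model_overide_args is not None:
        config.update(model_overide_args)
    return config
```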
19 changes: 7 additions & 12 deletions python/sglang/srt/models/commandr.py
@@ -18,9 +18,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/commandr.py#L1

# This file is based on the LLama model definition file in transformers
"""PyTorch Cohere model."""
from typing import Optional, Tuple
from typing import Optional, Tuple, Iterable

import torch
import torch.utils.checkpoint
@@ -41,11 +44,11 @@
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


@torch.compile
@@ -324,13 +327,7 @@ def forward(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
)

def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
@@ -341,9 +338,7 @@ def load_weights(
]
params_dict = dict(self.named_parameters())
loaded_params = set()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision
):
for name, loaded_weight in weights:
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
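
Under the new contract, `load_weights` no longer receives a checkpoint path, cache directory, load format, or revision; it simply consumes the `(name, tensor)` pairs that vLLM's loader streams to it. A minimal sketch of driving the interface by hand, with purely illustrative parameter names and shapes:

```python
from typing import Iterable, Tuple

import torch


def toy_weight_iter() -> Iterable[Tuple[str, torch.Tensor]]:
    # Stand-in for the iterator vLLM builds from safetensors / .bin shards.
    yield "model.embed_tokens.weight", torch.zeros(32, 16)
    yield "model.layers.0.self_attn.q_proj.weight", torch.zeros(16, 16)


# With a constructed CohereForCausalLM, loading is now just:
#     model.load_weights(toy_weight_iter())
# The stacked_params_mapping above folds the q/k/v projections into the
# fused qkv_proj parameter as the pairs stream through.
```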
20 changes: 6 additions & 14 deletions python/sglang/srt/models/dbrx.py
@@ -1,7 +1,7 @@
# Adapted from:
# https://github.com/vllm-project/vllm/blob/14ccd94c89d0ffd9da283545d93ab1dfea5da340/vllm/model_executor/models/dbrx.py
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/dbrx.py#L1
# coding=utf-8
from typing import Optional
from typing import Iterable, Optional, Tuple

import torch
import torch.nn as nn
@@ -24,12 +24,12 @@
VocabParallelEmbedding,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.transformers_utils.configs.dbrx import DbrxConfig

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.models.dbrx_config import DbrxConfig
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator


class DbrxRouter(nn.Module):
@@ -377,13 +377,7 @@ def forward(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)

def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_params_mapping = [
(
"ws" if weight_name in ["w1", "v1"] else "w2s",
@@ -392,9 +386,7 @@ def load_weights(
for weight_name in ["w1", "v1", "w2"]
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision
):
for name, loaded_weight in weights:
for param_name, weight_name in expert_params_mapping:
if weight_name not in name:
continue
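
The DBRX variant follows the same pattern, with `expert_params_mapping` folding per-expert checkpoint tensors into the fused `ws`/`w2s` parameters. A small sketch of just the name-rewriting step; the checkpoint key and the tuple element hidden by the collapsed hunk are assumptions (taken to be `f"experts.mlp.{weight_name}"`, as in the vLLM file this is adapted from):

```python
# Reproduces only the name mapping from DbrxForCausalLM.load_weights; the
# real loop also calls each parameter's weight_loader to handle sharding.
expert_params_mapping = [
    (
        "ws" if weight_name in ["w1", "v1"] else "w2s",
        f"experts.mlp.{weight_name}",  # assumed; elided in the diff above
    )
    for weight_name in ["w1", "v1", "w2"]
]

checkpoint_key = "transformer.blocks.0.ffn.experts.mlp.w1"  # illustrative
for param_name, weight_name in expert_params_mapping:
    if weight_name in checkpoint_key:
        print(checkpoint_key.replace(weight_name, param_name))
        # prints: transformer.blocks.0.ffn.ws
        break
```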