From 95746797bc92709cb21b171f1b986c55216c570e Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 21 May 2024 06:52:41 -0700 Subject: [PATCH 1/9] use model loader from vllm --- .../sglang/srt/managers/router/model_rpc.py | 1 + .../srt/managers/router/model_runner.py | 134 +++++++----------- python/sglang/srt/model_config.py | 2 +- python/sglang/srt/models/llama2.py | 22 ++- 4 files changed, 64 insertions(+), 95 deletions(-) diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index fcfdd0cb04..d5a029f515 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -41,6 +41,7 @@ logger = logging.getLogger("model_rpc") vllm_default_logger.setLevel(logging.WARN) logging.getLogger("vllm.utils").setLevel(logging.WARN) +logging.getLogger("vllm.selector").setLevel(logging.WARN) class ModelRpcServer: diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py index 02cc747144..f76c8ccb96 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/router/model_runner.py @@ -1,29 +1,23 @@ import importlib import importlib.resources -import inspect import logging import pkgutil from dataclasses import dataclass from functools import lru_cache -from typing import List +from typing import List, Optional, Type import numpy as np import torch +import torch.nn as nn +from vllm.config import ModelConfig, DeviceConfig, LoadConfig from vllm.distributed import initialize_model_parallel -from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.layers.quantization.gptq import GPTQConfig -from vllm.model_executor.layers.quantization.marlin import MarlinConfig -from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models import ModelRegistry from sglang.srt.managers.router.infer_batch import Batch, ForwardMode from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model -QUANTIZATION_CONFIG_MAPPING = { - "awq": AWQConfig, - "gptq": GPTQConfig, - "marlin": MarlinConfig, -} logger = logging.getLogger("model_runner") @@ -31,35 +25,6 @@ global_server_args_dict = {} -@lru_cache() -def import_model_classes(): - model_arch_name_to_cls = {} - package_name = "sglang.srt.models" - package = importlib.import_module(package_name) - for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."): - if not ispkg: - module = importlib.import_module(name) - if hasattr(module, "EntryClass"): - model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass - return model_arch_name_to_cls - - -def get_model_cls_by_arch_name(model_arch_names): - model_arch_name_to_cls = import_model_classes() - - model_class = None - for arch in model_arch_names: - if arch in model_arch_name_to_cls: - model_class = model_arch_name_to_cls[arch] - break - else: - raise ValueError( - f"Unsupported architectures: {arch}. 
" - f"Supported list: {list(model_arch_name_to_cls.keys())}" - ) - return model_class - - @dataclass class InputMetadata: model_runner: "ModelRunner" @@ -287,49 +252,29 @@ def __init__( self.is_multimodal_model = is_multimodal_model(self.model_config) def load_model(self): - """See also vllm/model_executor/model_loader.py::get_model""" - # Select model class - architectures = getattr(self.model_config.hf_config, "architectures", []) - model_class = get_model_cls_by_arch_name(architectures) logger.info(f"Rank {self.tp_rank}: load weight begin.") - # Load weights - quant_config = None - - quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None) - if quant_cfg is not None: - quant_method = quant_cfg.get("quant_method", "").lower() - # compat: autogptq >=0.8.0 use checkpoint_format: str - # compat: autogptq <=0.7.1 is_marlin_format: bool - is_format_marlin = quant_cfg.get( - "checkpoint_format" - ) == "marlin" or quant_cfg.get("is_marlin_format", False) - - # Use marlin if the GPTQ model is serialized in marlin format. - if quant_method == "gptq" and is_format_marlin: - quant_method = "marlin" - - quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method) - - if quant_config_class is None: - raise ValueError(f"Unsupported quantization method: {quant_method}") - - quant_config = quant_config_class.from_config(quant_cfg) - logger.info(f"quant_config: {quant_config}") - - with set_default_torch_dtype(torch.float16): - with torch.device("cuda"): - model = model_class( - config=self.model_config.hf_config, quant_config=quant_config - ) - model.load_weights( - self.model_config.path, - cache_dir=None, - load_format=self.load_format, - revision=None, - ) - self.model = model.eval() - + device_config = DeviceConfig() + load_config = LoadConfig() + model_config = ModelConfig( + model=self.model_config.path, + tokenizer=None, + tokenizer_mode=None, + trust_remote_code=self.model_config.trust_remote_code, + dtype="auto", + seed=42, + revision=self.model_config.revision, + skip_tokenizer_init=True, + ) + self.model = get_model( + model_config=model_config, + device_config=device_config, + load_config=load_config, + lora_config=None, + vision_language_config=None, + parallel_config=None, + scheduler_config=None, + ) logger.info(f"Rank {self.tp_rank}: load weight end.") def profile_max_num_token(self, total_gpu_memory): @@ -455,3 +400,30 @@ def forward(self, batch: Batch, forward_mode: ForwardMode): return self.forward_prefill(batch) else: raise ValueError(f"Invaid forward mode: {forward_mode}") + + +@lru_cache() +def import_model_classes(): + model_arch_name_to_cls = {} + package_name = "sglang.srt.models" + package = importlib.import_module(package_name) + for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."): + if not ispkg: + module = importlib.import_module(name) + if hasattr(module, "EntryClass"): + model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass + return model_arch_name_to_cls + + +def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]: + model_arch_name_to_cls = import_model_classes() + if model_arch not in model_arch_name_to_cls: + raise ValueError( + f"Unsupported architectures: {model_arch}. 
" + f"Supported list: {list(model_arch_name_to_cls.keys())}" + ) + return model_arch_name_to_cls[model_arch] + + +# Monkey patch model loader +setattr(ModelRegistry, "load_model_cls", load_model_cls_srt) \ No newline at end of file diff --git a/python/sglang/srt/model_config.py b/python/sglang/srt/model_config.py index af4a6c103d..e675697f0b 100644 --- a/python/sglang/srt/model_config.py +++ b/python/sglang/srt/model_config.py @@ -44,4 +44,4 @@ def __init__( self.num_key_value_heads = self.num_attention_heads self.hidden_size = self.hf_config.hidden_size self.num_hidden_layers = self.hf_config.num_hidden_layers - self.vocab_size = self.hf_config.vocab_size + self.vocab_size = self.hf_config.vocab_size \ No newline at end of file diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py index fde8ebb064..cf292eeb10 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama2.py @@ -1,7 +1,7 @@ # Adapted from -# https://github.com/vllm-project/vllm/blob/671af2b1c0b3ed6d856d37c21a561cc429a10701/vllm/model_executor/models/llama.py#L1 +# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1 """Inference-only LLaMA model compatible with HuggingFace weights.""" -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Iterable import torch from torch import nn @@ -20,11 +20,11 @@ ParallelLMHead, VocabParallelEmbedding, ) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata -from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class LlamaMLP(nn.Module): @@ -152,6 +152,10 @@ def __init__( self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = LlamaAttention( hidden_size=self.hidden_size, @@ -270,13 +274,7 @@ def forward( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -286,9 +284,7 @@ def load_weights( ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision - ): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: continue if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: From 458cf1c8efeb6f27ac29aa3f4fecdeef693a6656 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 21 May 2024 07:42:55 -0700 Subject: [PATCH 2/9] fix some --- .../srt/managers/router/model_runner.py | 2 +- python/sglang/srt/models/commandr.py | 19 +- python/sglang/srt/models/dbrx.py | 20 +- python/sglang/srt/models/dbrx_config.py | 281 
------------------ python/sglang/srt/models/llava.py | 21 +- 5 files changed, 20 insertions(+), 323 deletions(-) delete mode 100644 python/sglang/srt/models/dbrx_config.py diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py index f76c8ccb96..33882fad72 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/router/model_runner.py @@ -261,7 +261,7 @@ def load_model(self): tokenizer=None, tokenizer_mode=None, trust_remote_code=self.model_config.trust_remote_code, - dtype="auto", + dtype=torch.float16, seed=42, revision=self.model_config.revision, skip_tokenizer_init=True, diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index b485db264b..ab685ed94c 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -18,9 +18,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Adapted from +# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/commandr.py#L1 + # This file is based on the LLama model definition file in transformers """PyTorch Cohere model.""" -from typing import Optional, Tuple +from typing import Optional, Tuple, Iterable import torch import torch.utils.checkpoint @@ -41,11 +44,11 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata -from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator @torch.compile @@ -324,13 +327,7 @@ def forward( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -341,9 +338,7 @@ def load_weights( ] params_dict = dict(self.named_parameters()) loaded_params = set() - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision - ): + for name, loaded_weight in weights: for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index e4bce189d3..6b435bd561 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -1,7 +1,7 @@ # Adapted from: -# https://github.com/vllm-project/vllm/blob/14ccd94c89d0ffd9da283545d93ab1dfea5da340/vllm/model_executor/models/dbrx.py +# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/dbrx.py#L1 # coding=utf-8 -from typing import Optional +from typing import Iterable, Optional, Tuple import torch import torch.nn as nn @@ -24,12 +24,12 @@ VocabParallelEmbedding, ) from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.model_loader.weight_utils import default_weight_loader 
+from vllm.transformers_utils.configs.dbrx import DbrxConfig from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata -from sglang.srt.models.dbrx_config import DbrxConfig -from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class DbrxRouter(nn.Module): @@ -377,13 +377,7 @@ def forward( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): expert_params_mapping = [ ( "ws" if weight_name in ["w1", "v1"] else "w2s", @@ -392,9 +386,7 @@ def load_weights( for weight_name in ["w1", "v1", "w2"] ] params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision - ): + for name, loaded_weight in weights: for param_name, weight_name in expert_params_mapping: if weight_name not in name: continue diff --git a/python/sglang/srt/models/dbrx_config.py b/python/sglang/srt/models/dbrx_config.py deleted file mode 100644 index 7fb062eb16..0000000000 --- a/python/sglang/srt/models/dbrx_config.py +++ /dev/null @@ -1,281 +0,0 @@ -# Adapted from: -# https://github.com/vllm-project/vllm/blob/14ccd94c89d0ffd9da283545d93ab1dfea5da340/vllm/transformers_utils/configs/dbrx.py -# yapf: disable -# ruff: noqa: E501 -# coding=utf-8 -# Copied from -# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py -"""Dbrx configuration.""" - -# FIXME: remove this once vllm releases a new version - -from typing import Any, Optional - -from transformers.configuration_utils import PretrainedConfig -from transformers.utils import logging - -logger = logging.get_logger(__name__) - -DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} - - -class DbrxAttentionConfig(PretrainedConfig): - """Configuration class for Dbrx Attention. - - [`DbrxAttention`] class. It is used to instantiate attention layers - according to the specified arguments, defining the layers architecture. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - attn_pdrop (`float`, *optional*, defaults to 0.0): - The dropout probability for the attention layers. - clip_qkv (`float`, *optional*, defaults to None): - If not `None`, clip the queries, keys, and values in the attention layer to this value. - kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. - rope_theta (float): The base frequency for rope. 
- """ - - def __init__( - self, - attn_pdrop: float = 0, - clip_qkv: Optional[float] = None, - kv_n_heads: int = 1, - rope_theta: float = 10000.0, - **kwargs: Any, - ): - super().__init__(**kwargs) - self.attn_pdrop = attn_pdrop - self.clip_qkv = clip_qkv - self.kv_n_heads = kv_n_heads - self.rope_theta = rope_theta - - for k in ["model_type"]: - if k in kwargs: - kwargs.pop(k) - if len(kwargs) != 0: - raise ValueError(f"Found unknown {kwargs=}") - - @classmethod - def from_pretrained( - cls, pretrained_model_name_or_path: str, **kwargs: Any - ) -> "PretrainedConfig": - cls._set_token_in_kwargs(kwargs) - - config_dict, kwargs = cls.get_config_dict( - pretrained_model_name_or_path, **kwargs - ) - - if config_dict.get("model_type") == "dbrx": - config_dict = config_dict["attn_config"] - - if ( - "model_type" in config_dict - and hasattr(cls, "model_type") - and config_dict["model_type"] != cls.model_type - ): - logger.warning( - f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " - + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." - ) - - return cls.from_dict(config_dict, **kwargs) - - -class DbrxFFNConfig(PretrainedConfig): - """Configuration class for Dbrx FFN. - - [`DbrxFFN`] class. It is used to instantiate feedforward layers according to - the specified arguments, defining the layers architecture. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - Args: - ffn_act_fn (dict, optional): A dict specifying activation function for the FFN. - The dict should have a key 'name' with the value being the name of - the activation function along with any additional keyword arguments. - ffn_hidden_size (int, optional): The hidden size of the feedforward network. - moe_num_experts (int, optional): The number of experts in the mixture of experts layer. - moe_top_k (int, optional): The number of experts to use in the mixture of experts layer. - moe_jitter_eps (float, optional): The jitter epsilon for the mixture of experts layer. - moe_loss_weight (float, optional): The loss weight for the mixture of experts layer. - moe_normalize_expert_weights (float, optional): The normalization factor for the expert weights. - uniform_expert_assignment (bool, optional): Whether to use uniform expert assignment. - This should only be used for benchmarking purposes. 
- """ - - def __init__( - self, - ffn_act_fn: Optional[dict] = None, - ffn_hidden_size: int = 3584, - moe_num_experts: int = 4, - moe_top_k: int = 1, - moe_jitter_eps: Optional[float] = None, - moe_loss_weight: float = 0.01, - moe_normalize_expert_weights: Optional[float] = 1, - uniform_expert_assignment: bool = False, - **kwargs: Any, - ): - super().__init__() - if ffn_act_fn is None: - ffn_act_fn = {"name": "silu"} - self.ffn_act_fn = ffn_act_fn - self.ffn_hidden_size = ffn_hidden_size - self.moe_num_experts = moe_num_experts - self.moe_top_k = moe_top_k - self.moe_jitter_eps = moe_jitter_eps - self.moe_loss_weight = moe_loss_weight - self.moe_normalize_expert_weights = moe_normalize_expert_weights - self.uniform_expert_assignment = uniform_expert_assignment - - for k in ["model_type"]: - if k in kwargs: - kwargs.pop(k) - if len(kwargs) != 0: - raise ValueError(f"Found unknown {kwargs=}") - - @classmethod - def from_pretrained( - cls, pretrained_model_name_or_path: str, **kwargs: Any - ) -> "PretrainedConfig": - cls._set_token_in_kwargs(kwargs) - - config_dict, kwargs = cls.get_config_dict( - pretrained_model_name_or_path, **kwargs - ) - - if config_dict.get("model_type") == "dbrx": - config_dict = config_dict["ffn_config"] - - if ( - "model_type" in config_dict - and hasattr(cls, "model_type") - and config_dict["model_type"] != cls.model_type - ): - logger.warning( - f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " - + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." - ) - - return cls.from_dict(config_dict, **kwargs) - - -class DbrxConfig(PretrainedConfig): - """Configuration class for Dbrx. - - [`DbrxModel`]. It is used to instantiate a Dbrx model according to the - specified arguments, defining the model architecture. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - d_model (`int`, *optional*, defaults to 6144): - Dimensionality of the embeddings and hidden states. - n_heads (`int`, *optional*, defaults to 48): - Number of attention heads for each attention layer in the Transformer encoder. - n_layers (`int`, *optional*, defaults to 40): - Number of hidden layers in the Transformer encoder. - max_seq_len (`int`, *optional*, defaults to 32768): - The maximum sequence length of the model. - vocab_size (`int`, *optional*, defaults to 100352): - Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by - the `inputs_ids` passed when calling [`DbrxModel`]. - resid_pdrop (`float`, *optional*, defaults to 0.0): - The dropout probability applied to the attention output before combining with residual. - emb_pdrop (`float`, *optional*, defaults to 0.0): - The dropout probability for the embedding layer. - attn_config (`dict`, *optional*): - A dictionary used to configure the model's attention module. - ffn_config (`dict`, *optional*): - A dictionary used to configure the model's FFN module. - use_cache (`bool`, *optional*, defaults to `False`): - Whether or not the model should return the last key/values attentions (not used by all models). - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
- output_router_logits (`bool`, *optional*, defaults to `False`): - Whether or not the router logits should be returned by the model. Enabling this will also - allow the model to output the auxiliary loss. See [here]() for more details - router_aux_loss_coef (`float`, *optional*, defaults to 0.001): - The aux loss factor for the total loss. - - - Example: - ```python - >>> from transformers import DbrxConfig, DbrxModel - - >>> # Initializing a Dbrx configuration - >>> configuration = DbrxConfig() - - >>> # Initializing a model (with random weights) from the configuration - >>> model = DbrxModel(configuration) - - >>> # Accessing the model configuration - >>> configuration = model.config - ``` - """ - - model_type = "dbrx" - attribute_map = { - "num_attention_heads": "n_heads", - "hidden_size": "d_model", - "num_hidden_layers": "n_layers", - "max_position_embeddings": "max_seq_len", - } - - def __init__( - self, - d_model: int = 2048, - n_heads: int = 16, - n_layers: int = 24, - max_seq_len: int = 2048, - vocab_size: int = 32000, - resid_pdrop: float = 0.0, - emb_pdrop: float = 0.0, - attn_config: Optional[DbrxAttentionConfig] = None, - ffn_config: Optional[DbrxFFNConfig] = None, - use_cache: bool = True, - initializer_range: float = 0.02, - output_router_logits: bool = False, - router_aux_loss_coef: float = 0.05, - **kwargs: Any, - ): - if attn_config is None: - self.attn_config = DbrxAttentionConfig() - elif isinstance(attn_config, dict): - self.attn_config = DbrxAttentionConfig(**attn_config) - else: - self.attn_config = attn_config - - if ffn_config is None: - self.ffn_config = DbrxFFNConfig() - elif isinstance(ffn_config, dict): - self.ffn_config = DbrxFFNConfig(**ffn_config) - else: - self.ffn_config = ffn_config - - self.d_model = d_model - self.n_heads = n_heads - self.n_layers = n_layers - self.max_seq_len = max_seq_len - self.vocab_size = vocab_size - self.resid_pdrop = resid_pdrop - self.emb_pdrop = emb_pdrop - self.use_cache = use_cache - self.initializer_range = initializer_range - self.output_router_logits = output_router_logits - self.router_aux_loss_coef = router_aux_loss_coef - - tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) - if tie_word_embeddings: - raise ValueError( - "tie_word_embeddings is not supported for Dbrx models." 
- ) - - super().__init__( - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index abce92061a..5d4726b548 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -1,6 +1,6 @@ """Inference-only LLaVa model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import List, Iterable, Optional, Tuple import numpy as np import torch @@ -8,6 +8,7 @@ from transformers import CLIPVisionModel, LlavaConfig from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.managers.router.infer_batch import ForwardMode from sglang.srt.managers.router.model_runner import InputMetadata @@ -17,7 +18,6 @@ unpad_image_shape, ) from sglang.srt.models.llama2 import LlamaForCausalLM -from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class LlavaLlamaForCausalLM(nn.Module): @@ -233,13 +233,7 @@ def forward( elif input_metadata.forward_mode == ForwardMode.DECODE: return self.language_model(input_ids, positions, input_metadata) - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # load clip vision model by cfg['mm_vision_tower']: # huggingface_name or path_of_clip_relative_to_llava_model_dir vision_path = self.config.mm_vision_tower @@ -272,9 +266,8 @@ def load_weights( "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). } params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision - ): + weights = list(weights) + for name, loaded_weight in weights: # FIXME: why projector weights read two times? 
if "projector" in name or "vision_tower" in name: for weight_name, param_name in projector_weights.items(): @@ -285,9 +278,7 @@ def load_weights( weight_loader(param, loaded_weight) # load language model - self.language_model.load_weights( - model_name_or_path, cache_dir, load_format, revision - ) + self.language_model.load_weights(weights) monkey_path_clip_vision_embed_forward() From 75a4d13c1c78ff491a40ceefcb7542048d836690 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 21 May 2024 07:44:34 -0700 Subject: [PATCH 3/9] fix gemma --- python/sglang/srt/models/gemma.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index 712af09815..8ad77f12a9 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -1,7 +1,7 @@ # Adapted from: -# https://github.com/vllm-project/vllm/blob/d65fac2738f0287a41955b45df76a2d5a919bff6/vllm/model_executor/models/gemma.py +# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/gemma.py#L1 """Inference-only Gemma model compatible with HuggingFace weights.""" -from typing import Optional, Tuple +from typing import Iterable, Optional, Tuple import torch from torch import nn @@ -18,11 +18,11 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata -from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class GemmaMLP(nn.Module): @@ -285,13 +285,7 @@ def forward( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -302,9 +296,7 @@ def load_weights( ] params_dict = dict(self.named_parameters()) loaded_params = set() - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision - ): + for name, loaded_weight in weights: for param_name, shard_name, shard_id in stacked_params_mapping: if shard_name not in name: continue From 8a79861462f9ca6db917f6cdcab733db2d889f1e Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 21 May 2024 07:53:34 -0700 Subject: [PATCH 4/9] fix more text models --- python/sglang/srt/models/mixtral.py | 25 +++++++------------------ python/sglang/srt/models/qwen.py | 18 ++++++------------ python/sglang/srt/models/qwen2.py | 16 ++++------------ python/sglang/srt/models/stablelm.py | 20 ++++++-------------- 4 files changed, 23 insertions(+), 56 deletions(-) diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 48e8d37fc2..94f0ed393e 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -1,7 +1,7 @@ # Adapted from -# 
https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/mixtral.py#L1 +# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1 """Inference-only Mixtral model.""" -from typing import Optional +from typing import Iterable, Optional, Tuple import numpy as np import torch @@ -25,11 +25,12 @@ ParallelLMHead, VocabParallelEmbedding, ) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata -from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class MixtralMLP(nn.Module): @@ -107,7 +108,7 @@ def __init__( ] ) self.gate = ReplicatedLinear( - config.hidden_size, self.num_total_experts, bias=False, linear_method=None + config.hidden_size, self.num_total_experts, bias=False, quant_config=None ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -333,13 +334,7 @@ def forward( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -348,13 +343,7 @@ def load_weights( ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, - cache_dir, - load_format, - revision, - fall_back_to_pt=False, - ): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for param_name, weight_name, shard_id in stacked_params_mapping: diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index a242d013bd..9b4da3c361 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -1,4 +1,6 @@ -from typing import Any, Dict, Optional +# Adapted from +# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1 +from typing import Any, Dict, Optional, Iterable, Tuple import torch from torch import nn @@ -17,11 +19,11 @@ ParallelLMHead, VocabParallelEmbedding, ) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata -from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class QWenMLP(nn.Module): @@ -245,22 +247,14 @@ def forward( ) return next_tokens - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "w2", 0), ("gate_up_proj", "w1", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision - ): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue for param_name, weight_name, shard_id in 
stacked_params_mapping: diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index f0ad5d9bf7..843d91a94f 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -1,7 +1,7 @@ # Adapted from llama2.py # Modify details for the adaptation of Qwen2 model. """Inference-only Qwen2 model compatible with HuggingFace weights.""" -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Iterable import torch from torch import nn @@ -19,11 +19,11 @@ ParallelLMHead, VocabParallelEmbedding, ) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata -from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator Qwen2Config = None @@ -271,13 +271,7 @@ def forward( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -287,9 +281,7 @@ def load_weights( ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision - ): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: continue if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 423e603cd5..5850deb261 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -1,8 +1,8 @@ -# This code is based on: -# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/stablelm.py +# Adapted from: +# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1 """Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b) model compatible with HuggingFace weights.""" -from typing import Optional, Tuple +from typing import Optional, Tuple, Iterable import torch from torch import nn @@ -20,11 +20,11 @@ ParallelLMHead, VocabParallelEmbedding, ) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata -from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class StablelmMLP(nn.Module): @@ -245,13 +245,7 @@ def forward( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -261,9 +255,7 @@ def load_weights( ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in 
hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision - ): + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: From 9df6d68ac831af8e76cef3d48bfed701df4410f1 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 21 May 2024 08:21:00 -0700 Subject: [PATCH 5/9] add hint --- python/sglang/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/utils.py b/python/sglang/utils.py index d1fa241e92..332551aa0a 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -141,7 +141,7 @@ def encode_frame(frame): def encode_video_base64(video_path, num_frames=16): - import cv2 + import cv2 # pip install opencv-python-headless cap = cv2.VideoCapture(video_path) if not cap.isOpened(): From c696359ffe490e4ef09fa1d8e74126b613603c3e Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 21 May 2024 08:22:04 -0700 Subject: [PATCH 6/9] fix yivl --- python/sglang/srt/models/yivl.py | 34 +++++++++++--------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/models/yivl.py b/python/sglang/srt/models/yivl.py index 3b1b99c8d6..0d5d70bc7d 100644 --- a/python/sglang/srt/models/yivl.py +++ b/python/sglang/srt/models/yivl.py @@ -1,40 +1,33 @@ """Inference-only Yi-VL model.""" -import os -from typing import List, Optional +from typing import Tuple, Iterable import torch import torch.nn as nn from transformers import CLIPVisionModel, LlavaConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llava import ( LlavaLlamaForCausalLM, - clip_vision_embed_forward, monkey_path_clip_vision_embed_forward, ) -from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class YiVLForCausalLM(LlavaLlamaForCausalLM): - def __init__(self, *args, **kwargs): - self.config = kwargs["config"] - super().__init__(self.config) + def __init__( + self, config, quant_config = None, + ) -> None: + super().__init__(config, quant_config) self.multi_modal_projector = YiVLMultiModalProjector(self.config) self.vision_tower_subfolder = self.config.mm_vision_tower.replace( "./", "" ) # Everything after "./" - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B) self.vision_tower = CLIPVisionModel.from_pretrained( - model_name_or_path, + self.config._name_or_path, torch_dtype=torch.float16, subfolder=self.vision_tower_subfolder, ).cuda() @@ -68,9 +61,8 @@ def load_weights( "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). 
} params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision - ): + weights = list(weights) + for name, loaded_weight in weights: if "projector" in name or "vision_tower" in name: for weight_name, param_name in projector_weights.items(): if weight_name in name: @@ -80,9 +72,7 @@ def load_weights( weight_loader(param, loaded_weight) # load language model - self.language_model.load_weights( - model_name_or_path, cache_dir, load_format, revision - ) + self.language_model.load_weights(weights) monkey_path_clip_vision_embed_forward() @@ -103,7 +93,7 @@ def __init__(self, config: LlavaConfig): def forward(self, image_features): hidden_states = self.linear_1(image_features) - hidden_state = self.ln_1(hidden_states) + hidden_states = self.ln_1(hidden_states) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) hidden_states = self.ln_2(hidden_states) From 3ec2e50c10ba1ac2933dcef760ce9ef98398a7f9 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 21 May 2024 08:23:52 -0700 Subject: [PATCH 7/9] update --- examples/quick_start/srt_example_yi_vl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/quick_start/srt_example_yi_vl.py b/examples/quick_start/srt_example_yi_vl.py index e4f6ef16db..359aacac31 100644 --- a/examples/quick_start/srt_example_yi_vl.py +++ b/examples/quick_start/srt_example_yi_vl.py @@ -1,5 +1,7 @@ """ Usage: python3 srt_example_yi_vl.py + +Requirements: transformers==4.38 """ import sglang as sgl From 7bb6972ab2f0407837bb82935d059a5013a436da Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 21 May 2024 09:00:48 -0700 Subject: [PATCH 8/9] Fix llava vid --- .../srt/managers/router/model_runner.py | 10 +- python/sglang/srt/model_config.py | 7 +- python/sglang/srt/models/llavavid.py | 25 +- python/sglang/srt/weight_utils.py | 417 ------------------ 4 files changed, 17 insertions(+), 442 deletions(-) delete mode 100644 python/sglang/srt/weight_utils.py diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py index 33882fad72..a74b1d10fd 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/router/model_runner.py @@ -9,7 +9,8 @@ import numpy as np import torch import torch.nn as nn -from vllm.config import ModelConfig, DeviceConfig, LoadConfig +from vllm.config import DeviceConfig, LoadConfig +from vllm.config import ModelConfig as VllmModelConfig from vllm.distributed import initialize_model_parallel from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import ModelRegistry @@ -256,7 +257,7 @@ def load_model(self): device_config = DeviceConfig() load_config = LoadConfig() - model_config = ModelConfig( + vllm_model_config = VllmModelConfig( model=self.model_config.path, tokenizer=None, tokenizer_mode=None, @@ -266,8 +267,11 @@ def load_model(self): revision=self.model_config.revision, skip_tokenizer_init=True, ) + if self.model_config.model_overide_args is not None: + vllm_model_config.hf_config.update(self.model_config.model_overide_args) + self.model = get_model( - model_config=model_config, + model_config=vllm_model_config, device_config=device_config, load_config=load_config, lora_config=None, diff --git a/python/sglang/srt/model_config.py b/python/sglang/srt/model_config.py index e675697f0b..dfeac0a255 100644 --- a/python/sglang/srt/model_config.py +++ b/python/sglang/srt/model_config.py @@ -15,10 +15,9 
@@ def __init__( self.path = path self.trust_remote_code = trust_remote_code self.revision = revision - self.hf_config = get_config(self.path, trust_remote_code, revision) - - if model_overide_args is not None: - self.hf_config.update(model_overide_args) + self.model_overide_args = model_overide_args + self.hf_config = get_config(self.path, trust_remote_code, revision, + model_overide_args=model_overide_args) if context_length is not None: self.context_len = context_length diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index e4205e509d..0afc3f0d6b 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -1,14 +1,14 @@ """Inference-only LLaVa video model compatible with HuggingFace weights.""" -import os -from typing import List, Optional +from typing import List, Iterable, Optional, Tuple import numpy as np import torch from torch import nn -from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig +from transformers import CLIPVisionModel, LlavaConfig from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.managers.router.infer_batch import ForwardMode from sglang.srt.managers.router.model_runner import InputMetadata @@ -18,7 +18,6 @@ unpad_image_shape, ) from sglang.srt.models.llama2 import LlamaForCausalLM -from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class LlavaVidForCausalLM(nn.Module): @@ -65,7 +64,6 @@ def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None): pad_ids = pad_value * ( (new_image_feature_len + len(pad_value)) // len(pad_value) ) - # print(input_ids) offset = input_ids.index(self.config.image_token_index) # old_len + pad_len - 1, because we need to remove image_token_id new_input_ids = ( @@ -200,13 +198,7 @@ def forward( elif input_metadata.forward_mode == ForwardMode.DECODE: return self.language_model(input_ids, positions, input_metadata) - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # load clip vision model by cfg['mm_vision_tower']: # huggingface_name or path_of_clip_relative_to_llava_model_dir vision_path = self.config.mm_vision_tower @@ -244,9 +236,8 @@ def load_weights( "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). } params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision - ): + weights = list(weights) + for name, loaded_weight in weights: # FIXME: why projector weights read two times? 
if "projector" in name or "vision_tower" in name: for weight_name, param_name in projector_weights.items(): @@ -261,9 +252,7 @@ def load_weights( weight_loader(param, loaded_weight) # load language model - self.language_model.load_weights( - model_name_or_path, cache_dir, load_format, revision - ) + self.language_model.load_weights(weights) monkey_path_clip_vision_embed_forward() diff --git a/python/sglang/srt/weight_utils.py b/python/sglang/srt/weight_utils.py deleted file mode 100644 index 1170c6cfe6..0000000000 --- a/python/sglang/srt/weight_utils.py +++ /dev/null @@ -1,417 +0,0 @@ -# The PR(https://github.com/vllm-project/vllm/pull/4097) of vllm borken the sglang code. -# In order to adapt to the latest code without modifying too much code, -# copied the previous vllm/model_executor/weight_utils.py -# Copied in https://github.com/vllm-project/vllm/blob/05434764cd99990035779cf9a4ed86623b528825/vllm/model_executor/weight_utils.py - -"""Utilities for downloading and initializing model weights.""" -import fnmatch -import glob -import hashlib -import json -import os -from collections import defaultdict -from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union - -import filelock -import huggingface_hub.constants -import numpy as np -import torch -from huggingface_hub import HfFileSystem, snapshot_download -from safetensors.torch import load_file, safe_open, save_file -from tqdm.auto import tqdm -from vllm.config import ModelConfig -from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import ( - QuantizationConfig, - get_quantization_config, -) -from vllm.model_executor.layers.quantization.schema import QuantParamSchema - -logger = init_logger(__name__) - -# use system-level temp directory for file locks, so that multiple users -# can share the same lock without error. 
-# lock files in the temp directory will be automatically deleted when the -# system reboots, so users will not complain about annoying lock files -temp_dir = ( - os.environ.get("TMPDIR") - or os.environ.get("TEMP") - or os.environ.get("TMP") - or "/tmp/" -) - - -def enable_hf_transfer(): - """automatically activates hf_transfer""" - if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: - try: - # enable hf hub transfer if available - import hf_transfer # type: ignore # noqa - - huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True - except ImportError: - pass - - -enable_hf_transfer() - - -class Disabledtqdm(tqdm): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs, disable=True) - - -def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None): - lock_dir = cache_dir or temp_dir - os.makedirs(os.path.dirname(lock_dir), exist_ok=True) - model_name = model_name_or_path.replace("/", "-") - hash_name = hashlib.sha256(model_name.encode()).hexdigest() - # add hash to avoid conflict with old users' lock files - lock_file_name = hash_name + model_name + ".lock" - # mode 0o666 is required for the filelock to be shared across users - lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666) - return lock - - -def _shared_pointers(tensors): - ptrs = defaultdict(list) - for k, v in tensors.items(): - ptrs[v.data_ptr()].append(k) - failing = [] - for _, names in ptrs.items(): - if len(names) > 1: - failing.append(names) - return failing - - -def convert_bin_to_safetensor_file( - pt_filename: str, - sf_filename: str, -) -> None: - loaded = torch.load(pt_filename, map_location="cpu") - if "state_dict" in loaded: - loaded = loaded["state_dict"] - shared = _shared_pointers(loaded) - for shared_weights in shared: - for name in shared_weights[1:]: - loaded.pop(name) - - # For tensors to be contiguous - loaded = {k: v.contiguous() for k, v in loaded.items()} - - dirname = os.path.dirname(sf_filename) - os.makedirs(dirname, exist_ok=True) - save_file(loaded, sf_filename, metadata={"format": "pt"}) - - # check file size - sf_size = os.stat(sf_filename).st_size - pt_size = os.stat(pt_filename).st_size - if (sf_size - pt_size) / pt_size > 0.01: - raise RuntimeError( - f"""The file size different is more than 1%: - - {sf_filename}: {sf_size} - - {pt_filename}: {pt_size} - """ - ) - - # check if the tensors are the same - reloaded = load_file(sf_filename) - for k in loaded: - pt_tensor = loaded[k] - sf_tensor = reloaded[k] - if not torch.equal(pt_tensor, sf_tensor): - raise RuntimeError(f"The output tensors do not match for key {k}") - - -# TODO(woosuk): Move this to other place. -def get_quant_config(model_config: ModelConfig) -> QuantizationConfig: - quant_cls = get_quantization_config(model_config.quantization) - # Read the quantization config from the HF model config, if available. - hf_quant_config = getattr(model_config.hf_config, "quantization_config", None) - if hf_quant_config is not None: - return quant_cls.from_config(hf_quant_config) - model_name_or_path = model_config.model - is_local = os.path.isdir(model_name_or_path) - if not is_local: - # Download the config files. 
- with get_lock(model_name_or_path, model_config.download_dir): - hf_folder = snapshot_download( - model_name_or_path, - revision=model_config.revision, - allow_patterns="*.json", - cache_dir=model_config.download_dir, - tqdm_class=Disabledtqdm, - ) - else: - hf_folder = model_name_or_path - config_files = glob.glob(os.path.join(hf_folder, "*.json")) - - quant_config_files = [ - f - for f in config_files - if any(f.endswith(x) for x in quant_cls.get_config_filenames()) - ] - if len(quant_config_files) == 0: - raise ValueError(f"Cannot find the config file for {model_config.quantization}") - if len(quant_config_files) > 1: - raise ValueError( - f"Found multiple config files for {model_config.quantization}: " - f"{quant_config_files}" - ) - - quant_config_file = quant_config_files[0] - with open(quant_config_file, "r") as f: - config = json.load(f) - return quant_cls.from_config(config) - - -def prepare_hf_model_weights( - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - fall_back_to_pt: bool = True, - revision: Optional[str] = None, -) -> Tuple[str, List[str], bool]: - # Download model weights from huggingface. - is_local = os.path.isdir(model_name_or_path) and load_format != "tensorizer" - use_safetensors = False - # Some quantized models use .pt files for storing the weights. - if load_format == "auto": - allow_patterns = ["*.safetensors", "*.bin"] - elif load_format == "safetensors": - use_safetensors = True - allow_patterns = ["*.safetensors"] - elif load_format == "pt": - allow_patterns = ["*.pt"] - elif load_format == "npcache": - allow_patterns = ["*.bin"] - elif load_format == "tensorizer": - allow_patterns = ["*.tensors"] - else: - raise ValueError(f"Unknown load_format: {load_format}") - - if fall_back_to_pt: - allow_patterns += ["*.pt"] - - if not is_local and load_format != "tensorizer": - # Before we download we look at that is available: - fs = HfFileSystem() - file_list = fs.ls(model_name_or_path, detail=False, revision=revision) - - # depending on what is available we download different things - for pattern in allow_patterns: - matching = fnmatch.filter(file_list, pattern) - if len(matching) > 0: - allow_patterns = [pattern] - break - - logger.info(f"Using model weights format {allow_patterns}") - # Use file lock to prevent multiple processes from - # downloading the same model weights at the same time. - with get_lock(model_name_or_path, cache_dir): - hf_folder = snapshot_download( - model_name_or_path, - allow_patterns=allow_patterns, - cache_dir=cache_dir, - tqdm_class=Disabledtqdm, - revision=revision, - ) - else: - hf_folder = model_name_or_path - hf_weights_files: List[str] = [] - for pattern in allow_patterns: - hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) - if len(hf_weights_files) > 0: - if pattern == "*.safetensors": - use_safetensors = True - break - if not use_safetensors: - # Exclude files that are not needed for inference. 
- # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 - blacklist = [ - "training_args.bin", - "optimizer.bin", - "optimizer.pt", - "scheduler.pt", - "scaler.pt", - ] - hf_weights_files = [ - f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist) - ] - - if load_format == "tensorizer": - return hf_folder, hf_weights_files, use_safetensors - - if len(hf_weights_files) == 0: - raise RuntimeError(f"Cannot find any model weights with `{model_name_or_path}`") - - return hf_folder, hf_weights_files, use_safetensors - - -def hf_model_weights_iterator( - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: Union[Tuple, str] = "auto", - revision: Optional[str] = None, - fall_back_to_pt: Optional[bool] = True, -) -> Iterator[Tuple[str, torch.Tensor]]: - hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights( - model_name_or_path, - cache_dir=cache_dir, - load_format=load_format, - fall_back_to_pt=fall_back_to_pt, - revision=revision, - ) - - if load_format == "npcache": - # Currently np_cache only support *.bin checkpoints - assert use_safetensors is False - - # Convert the model weights from torch tensors to numpy arrays for - # faster loading. - np_folder = os.path.join(hf_folder, "np") - os.makedirs(np_folder, exist_ok=True) - weight_names_file = os.path.join(np_folder, "weight_names.json") - # Use file lock to prevent multiple processes from - # dumping the same model weights to numpy at the same time. - with get_lock(model_name_or_path, cache_dir): - if not os.path.exists(weight_names_file): - weight_names = [] - for bin_file in hf_weights_files: - state = torch.load(bin_file, map_location="cpu") - for name, param in state.items(): - param_path = os.path.join(np_folder, name) - with open(param_path, "wb") as f: - np.save(f, param.cpu().detach().numpy()) - weight_names.append(name) - with open(weight_names_file, "w") as f: - json.dump(weight_names, f) - - with open(weight_names_file, "r") as f: - weight_names = json.load(f) - - for name in weight_names: - param_path = os.path.join(np_folder, name) - with open(param_path, "rb") as f: - param = np.load(f) - yield name, torch.from_numpy(param) - elif load_format == "tensorizer": - from vllm.model_executor.tensorizer_loader import ( - TensorDeserializer, - open_stream, - tensorizer_warning, - ) - - tensorizer_args = load_format.params - tensorizer_warning( - "Deserializing HuggingFace models is not optimized for " - "loading on vLLM, as tensorizer is forced to load to CPU. " - "Consider deserializing a vLLM model instead for faster " - "load times. See the examples/tensorize_vllm_model.py example " - "script for serializing vLLM models." 
- ) - - deserializer_args = tensorizer_args.deserializer_params - stream_params = tensorizer_args.stream_params - stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) - with TensorDeserializer(stream, **deserializer_args, device="cpu") as state: - for name, param in state.items(): - yield name, param - del state - elif use_safetensors: - for st_file in hf_weights_files: - with safe_open(st_file, framework="pt") as f: - for name in f.keys(): # noqa: SIM118 - param = f.get_tensor(name) - yield name, param - else: - for bin_file in hf_weights_files: - state = torch.load(bin_file, map_location="cpu") - for name, param in state.items(): - yield name, param - del state - torch.cuda.empty_cache() - - -def kv_cache_scales_loader( - filename: str, - tp_rank: int, - tp_size: int, - num_hidden_layers: int, - model_type: Optional[str], -) -> Iterable[Tuple[int, float]]: - """ - A simple utility to read in KV cache scaling factors that have been - previously serialized to disk. Used by the model to populate the appropriate - KV cache scaling factors. The serialization should represent a dictionary - whose keys are the TP ranks and values are another dictionary mapping layers - to their KV cache scaling factors. - Keep this function in sync with the output of examples/fp8/extract_scales.py - """ - try: - with open(filename) as f: - context = { - "model_type": model_type, - "num_hidden_layers": num_hidden_layers, - "tp_rank": tp_rank, - "tp_size": tp_size, - } - schema_dct = json.load(f) - schema = QuantParamSchema.model_validate(schema_dct, context=context) - layer_scales_map = schema.kv_cache.scaling_factor[tp_rank] - return layer_scales_map.items() - - except FileNotFoundError: - logger.error(f"File or directory '{filename}' not found.") - except json.JSONDecodeError: - logger.error(f"Error decoding JSON in file '{filename}'.") - except Exception as e: - logger.error(f"An error occurred while reading '{filename}': {e}") - # This section is reached if and only if any of the excepts are hit - # Return an empty iterable (list) => no KV cache scales are loaded - # which ultimately defaults to 1.0 scales - logger.warning( - "Defaulting to KV cache scaling factors = 1.0 " - f"for all layers in TP rank {tp_rank} " - "as an error occurred during loading." - ) - return [] - - -def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: - """convert PySafeSlice object from safetensors to torch.Tensor - - PySafeSlice object supports indexing, which is done before loading the - actual tensor and can reduce the amount of memory being read into the - memory. However, it does not support more advanced functionalities - like `.view()` or `.t()`. Therefore, if we need to modify the loaded - tensor with these more complicated operators, we need to convert to - tensor first. - """ - if not isinstance(x, torch.Tensor): - x = x[:] - return x - - -def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: - """Default weight loader.""" - assert param.size() == loaded_weight.size() - param.data.copy_(loaded_weight) - - -def initialize_dummy_weights( - model: torch.nn.Module, - low: float = -1e-3, - high: float = 1e-3, -) -> None: - """Initialize model weights with random values. - - The model weights must be randomly initialized for accurate performance - measurements. Additionally, the model weights should not cause NaNs in the - forward pass. We empirically found that initializing the weights with - values between -1e-3 and 1e-3 works well for most models. 
- """ - for param in model.state_dict().values(): - if torch.is_floating_point(param): - param.data.uniform_(low, high) From de5c4a49e068928cb97424e6f2df8d669a7240b6 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 21 May 2024 09:01:40 -0700 Subject: [PATCH 9/9] fix llava_mistral/qwen --- python/sglang/srt/models/llava_mistral.py | 21 ++++++--------------- python/sglang/srt/models/llava_qwen.py | 21 ++++++--------------- 2 files changed, 12 insertions(+), 30 deletions(-) diff --git a/python/sglang/srt/models/llava_mistral.py b/python/sglang/srt/models/llava_mistral.py index 2a42e2b5e4..10531f84fb 100644 --- a/python/sglang/srt/models/llava_mistral.py +++ b/python/sglang/srt/models/llava_mistral.py @@ -1,6 +1,6 @@ """Inference-only LLaVa model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import List, Iterable, Optional, Tuple import numpy as np import torch @@ -8,6 +8,7 @@ from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, MistralConfig from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.managers.router.infer_batch import ForwardMode from sglang.srt.managers.router.model_runner import InputMetadata @@ -17,7 +18,6 @@ unpad_image_shape, ) from sglang.srt.models.mistral import MistralForCausalLM -from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class LlavaMistralForCausalLM(nn.Module): @@ -246,13 +246,7 @@ def forward( elif input_metadata.forward_mode == ForwardMode.DECODE: return self.language_model(input_ids, positions, input_metadata) - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # load clip vision model by cfg['mm_vision_tower']: # huggingface_name or path_of_clip_relative_to_llava_model_dir vision_path = self.config.mm_vision_tower @@ -285,9 +279,8 @@ def load_weights( "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). } params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision - ): + weights = list(weights) + for name, loaded_weight in weights: # FIXME: why projector weights read two times? 
if "projector" in name or "vision_tower" in name: for weight_name, param_name in projector_weights.items(): @@ -298,9 +291,7 @@ def load_weights( weight_loader(param, loaded_weight) # load language model - self.language_model.load_weights( - model_name_or_path, cache_dir, load_format, revision - ) + self.language_model.load_weights(weights) monkey_path_clip_vision_embed_forward() diff --git a/python/sglang/srt/models/llava_qwen.py b/python/sglang/srt/models/llava_qwen.py index 2c60c5ef91..b524ef037b 100644 --- a/python/sglang/srt/models/llava_qwen.py +++ b/python/sglang/srt/models/llava_qwen.py @@ -1,6 +1,6 @@ """Inference-only LLaVa model compatible with HuggingFace weights.""" -from typing import List, Optional +from typing import List, Iterable, Optional, Tuple import numpy as np import torch @@ -8,6 +8,7 @@ from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, Qwen2Config from transformers.models.llava.modeling_llava import LlavaMultiModalProjector from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.managers.router.infer_batch import ForwardMode from sglang.srt.managers.router.model_runner import InputMetadata @@ -17,7 +18,6 @@ unpad_image_shape, ) from sglang.srt.models.qwen2 import Qwen2ForCausalLM -from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator class LlavaQwenForCausalLM(nn.Module): @@ -246,13 +246,7 @@ def forward( elif input_metadata.forward_mode == ForwardMode.DECODE: return self.language_model(input_ids, positions, input_metadata) - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # load clip vision model by cfg['mm_vision_tower']: # huggingface_name or path_of_clip_relative_to_llava_model_dir vision_path = self.config.mm_vision_tower @@ -285,9 +279,8 @@ def load_weights( "model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned). } params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision - ): + weights = list(weights) + for name, loaded_weight in weights: # FIXME: why projector weights read two times? if "projector" in name or "vision_tower" in name: for weight_name, param_name in projector_weights.items(): @@ -298,9 +291,7 @@ def load_weights( weight_loader(param, loaded_weight) # load language model - self.language_model.load_weights( - model_name_or_path, cache_dir, load_format, revision - ) + self.language_model.load_weights(weights) monkey_path_clip_vision_embed_forward()