Update model_loader deps and qqq quantization deps (#2220)
HandH1998 authored Dec 2, 2024
1 parent 33deca8 commit 9255020
Showing 58 changed files with 2,363 additions and 366 deletions.
4 changes: 4 additions & 0 deletions python/sglang/bench_one_batch.py
@@ -111,8 +111,12 @@ def load_model(server_args, port_args, tp_rank):
model_config = ModelConfig(
server_args.model_path,
trust_remote_code=server_args.trust_remote_code,
revision=server_args.revision,
context_length=server_args.context_length,
model_override_args=server_args.json_model_override_args,
is_embedding=server_args.is_embedding,
dtype=server_args.dtype,
quantization=server_args.quantization,
)
model_runner = ModelRunner(
model_config=model_config,
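For reference, a minimal standalone sketch (not part of the diff) of the updated ModelConfig call signature. The model path and argument values below are illustrative only; note that model_override_args is passed as a JSON string because the constructor parses it with json.loads.

from sglang.srt.configs.model_config import ModelConfig

model_config = ModelConfig(
    "meta-llama/Llama-3.1-8B-Instruct",  # illustrative model_path; requires a resolvable HF config
    trust_remote_code=True,
    revision=None,
    context_length=None,          # falls back to the length derived from the HF config
    model_override_args="{}",     # JSON string, parsed with json.loads in __init__
    is_embedding=False,
    dtype="auto",                 # resolved by _get_and_verify_dtype below
    quantization=None,            # validated by _verify_quantization below
)
print(model_config.dtype)         # e.g. torch.float16 or torch.bfloat16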
17 changes: 17 additions & 0 deletions python/sglang/srt/configs/device_config.py
@@ -0,0 +1,17 @@
import logging
from typing import Optional

import torch

logger = logging.getLogger(__name__)


class DeviceConfig:
device: Optional[torch.device]

def __init__(self, device: str = "cuda") -> None:
if device in ["cuda", "xpu", "hpu"]:
self.device_type = device
else:
raise RuntimeError(f"Not supported device type: {device}")
self.device = torch.device(self.device_type)
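A minimal usage sketch (not part of the diff) of the new DeviceConfig; the device strings are the ones accepted above.

from sglang.srt.configs.device_config import DeviceConfig

device_config = DeviceConfig("cuda")   # "xpu" and "hpu" are also accepted
print(device_config.device_type)       # "cuda"
print(device_config.device)            # torch.device usable for tensor placement

# Any other string fails at construction time, e.g.:
# DeviceConfig("tpu")  ->  RuntimeError: Not supported device type: tpu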
84 changes: 84 additions & 0 deletions python/sglang/srt/configs/load_config.py
@@ -0,0 +1,84 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
import enum
import json
import logging
from dataclasses import dataclass, field
from typing import List, Optional, Union

from sglang.srt.utils import is_hip

logger = logging.getLogger(__name__)


class LoadFormat(str, enum.Enum):
AUTO = "auto"
PT = "pt"
SAFETENSORS = "safetensors"
NPCACHE = "npcache"
DUMMY = "dummy"
SHARDED_STATE = "sharded_state"
GGUF = "gguf"
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"


@dataclass
class LoadConfig:
"""
download_dir: Directory to download and load the weights, default to the
default cache directory of huggingface.
load_format: The format of the model weights to load:
"auto" will try to load the weights in the safetensors format and
fall back to the pytorch bin format if safetensors format is
not available.
"pt" will load the weights in the pytorch bin format.
"safetensors" will load the weights in the safetensors format.
"npcache" will load the weights in pytorch format and store
a numpy cache to speed up the loading.
"dummy" will initialize the weights with random values, which is
mainly for profiling.
"bitsandbytes" will load nf4 type weights.
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
"""

load_format: Union[str, LoadFormat] = LoadFormat.AUTO
download_dir: Optional[str] = None
model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
ignore_patterns: Optional[Union[List[str], str]] = None

def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
if isinstance(model_loader_extra_config, str):
self.model_loader_extra_config = json.loads(model_loader_extra_config)
self._verify_load_format()

if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
logger.info(
"Ignoring the following patterns when downloading weights: %s",
self.ignore_patterns,
)
else:
self.ignore_patterns = ["original/**/*"]

def _verify_load_format(self) -> None:
if not isinstance(self.load_format, str):
return

load_format = self.load_format.lower()
self.load_format = LoadFormat(load_format)

rocm_not_supported_load_format: List[str] = []
if is_hip() and load_format in rocm_not_supported_load_format:
rocm_supported_load_format = [
f
for f in LoadFormat.__members__
if (f not in rocm_not_supported_load_format)
]
raise ValueError(
f"load format '{load_format}' is not supported in ROCm. "
f"Supported load formats are "
f"{rocm_supported_load_format}"
)
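A short usage sketch (not part of the diff) showing how __post_init__ normalizes the fields; the extra-config key is hypothetical.

from sglang.srt.configs.load_config import LoadConfig, LoadFormat

cfg = LoadConfig(
    load_format="SafeTensors",                    # lower-cased, then coerced to the LoadFormat enum
    model_loader_extra_config='{"some_key": 1}',  # hypothetical key; JSON string parsed into a dict
)
assert cfg.load_format is LoadFormat.SAFETENSORS
assert cfg.model_loader_extra_config == {"some_key": 1}
assert cfg.ignore_patterns == ["original/**/*"]   # default applied when no patterns are given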
165 changes: 161 additions & 4 deletions python/sglang/srt/configs/model_config.py
@@ -15,12 +15,14 @@
import json
import logging
from enum import IntEnum, auto
from typing import List, Optional
from typing import List, Optional, Union

import torch
from transformers import PretrainedConfig

from sglang.srt.hf_transformers_utils import get_config, get_context_length
from sglang.srt.utils import get_bool_env_var
from sglang.srt.layers.quantization import QUANTIZATION_METHODS
from sglang.srt.utils import get_bool_env_var, is_hip

logger = logging.getLogger(__name__)

@@ -33,17 +35,22 @@ class AttentionArch(IntEnum):
class ModelConfig:
def __init__(
self,
path: str,
model_path: str,
trust_remote_code: bool = True,
revision: Optional[str] = None,
context_length: Optional[int] = None,
model_override_args: Optional[dict] = None,
is_embedding: Optional[bool] = None,
dtype: str = "auto",
quantization: Optional[str] = None,
) -> None:
self.model_path = model_path
self.revision = revision
self.quantization = quantization
# Parse args
self.model_override_args = json.loads(model_override_args)
self.hf_config = get_config(
path,
model_path,
trust_remote_code=trust_remote_code,
revision=revision,
model_override_args=self.model_override_args,
@@ -56,6 +63,7 @@ def __init__(
)
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

# Derive context length
derived_context_len = get_context_length(self.hf_text_config)
@@ -116,6 +124,8 @@ def __init__(
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
self.vocab_size = self.hf_text_config.vocab_size

self._verify_quantization()

# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""
@@ -174,6 +184,86 @@ def get_num_kv_heads(self, tensor_parallel_size) -> int:
# parallel size so each GPU has at least one KV head.
return max(1, total_num_kv_heads // tensor_parallel_size)

# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None)
if quant_cfg is None:
# compressed-tensors uses a "compression_config" key
quant_cfg = getattr(self.hf_config, "compression_config", None)
return quant_cfg

# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = [
"awq",
"gptq",
"fp8",
"compressed_tensors",
"compressed-tensors",
"fbgemm_fp8",
]
optimized_quantization_methods = [
"fp8",
"marlin",
"modelopt",
"gptq_marlin_24",
"gptq_marlin",
"awq_marlin",
"fbgemm_fp8",
"compressed_tensors",
"compressed-tensors",
"experts_int8",
]
if self.quantization is not None:
self.quantization = self.quantization.lower()

# Parse quantization method from the HF model config, if available.
quant_cfg = self._parse_quant_hf_config()

if quant_cfg is not None:
quant_method = quant_cfg.get("quant_method", "").lower()

# Detect which checkpoint it is
for _, method in QUANTIZATION_METHODS.items():
quantization_override = method.override_quantization_method(
quant_cfg, self.quantization
)
if quantization_override:
quant_method = quantization_override
self.quantization = quantization_override
break

# Verify quantization configurations.
if self.quantization is None:
self.quantization = quant_method
elif self.quantization != quant_method:
raise ValueError(
"Quantization method specified in the model config "
f"({quant_method}) does not match the quantization "
f"method specified in the `quantization` argument "
f"({self.quantization})."
)

if self.quantization is not None:
if self.quantization not in supported_quantization:
raise ValueError(
f"Unknown quantization method: {self.quantization}. Must "
f"be one of {supported_quantization}."
)
if is_hip() and self.quantization not in rocm_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not "
f"supported in ROCm."
)
if self.quantization not in optimized_quantization_methods:
logger.warning(
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models.",
self.quantization,
)


def get_hf_text_config(config: PretrainedConfig):
"""Get the "sub" config relevant to llm for multi modal models.
@@ -183,6 +273,9 @@ def get_hf_text_config(config: PretrainedConfig):
if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
# We support non-hf version of llava models, so we do not want to
# read the wrong values from the unused default text_config.
# NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
# `torch.float16` is used by default for image features in `python/sglang/srt/models/llava.py`.
setattr(config, "torch_dtype", torch.float16)
return config

if hasattr(config, "text_config"):
@@ -195,6 +288,70 @@ def get_hf_text_config(config: PretrainedConfig):
return config


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,
"float16": torch.float16,
"float": torch.float32,
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}


# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _get_and_verify_dtype(
config: PretrainedConfig,
dtype: Union[str, torch.dtype],
) -> torch.dtype:
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None)
if config_dtype is None:
config_dtype = torch.float32

if isinstance(dtype, str):
dtype = dtype.lower()
if dtype == "auto":
if config_dtype == torch.float32:
if config.model_type == "gemma2":
logger.info(
"For Gemma 2, we downcast float32 to bfloat16 instead "
"of float16 by default. Please specify `dtype` if you "
"want to use float16."
)
torch_dtype = torch.bfloat16
else:
# Following the common practice, we use float16 for float32
# models.
torch_dtype = torch.float16
else:
torch_dtype = config_dtype
else:
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
raise ValueError(f"Unknown dtype: {dtype}")
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
elif isinstance(dtype, torch.dtype):
torch_dtype = dtype
else:
raise ValueError(f"Unknown dtype: {dtype}")

# Verify the dtype.
if torch_dtype != config_dtype:
if torch_dtype == torch.float32:
# Upcasting to float32 is allowed.
logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
pass
elif config_dtype == torch.float32:
# Downcasting from float32 to float16 or bfloat16 is allowed.
logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
pass
else:
# Casting between float16 and bfloat16 is allowed with a warning.
logger.warning("Casting %s to %s.", config_dtype, torch_dtype)

return torch_dtype


def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
# We have two ways to determine whether a model is a generative model.
# 1. Check the model architecture
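To make the new dtype resolution concrete, an illustrative sketch (not part of the diff) using the module-internal helper _get_and_verify_dtype, assuming a float32 checkpoint config for a non-Gemma-2 model:

import torch
from transformers import PretrainedConfig
from sglang.srt.configs.model_config import _get_and_verify_dtype

config = PretrainedConfig(model_type="llama", torch_dtype=torch.float32)  # illustrative config

assert _get_and_verify_dtype(config, "auto") == torch.float16       # float32 checkpoints default to half
assert _get_and_verify_dtype(config, "bfloat16") == torch.bfloat16  # explicit override wins (downcast is logged)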
13 changes: 5 additions & 8 deletions python/sglang/srt/configs/qwen2vl.py
@@ -121,13 +121,10 @@ def __init__(
self.attention_dropout = attention_dropout
self.rope_scaling = rope_scaling

# NOTE: the following section from original transformers config
# for Qwen2-VL is commented out to address rope config loading issue
#
# if self.rope_scaling is not None and "type" in self.rope_scaling:
# if self.rope_scaling["type"] == "mrope":
# self.rope_scaling["type"] = "default"
# self.rope_scaling["rope_type"] = self.rope_scaling["type"]
# rope_config_validation(self)
# NOTE(HandH1998): This is necessary for configuring the `rope_type` of qwen2vl models after removing dependencies on vllm.
if self.rope_scaling is not None and "type" in self.rope_scaling:
if self.rope_scaling["type"] == "mrope":
self.rope_scaling["type"] = "default"
self.rope_scaling["rope_type"] = self.rope_scaling["type"]

super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
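For clarity, the effect of the restored normalization on a typical Qwen2-VL rope_scaling dict, shown as a standalone sketch (not part of the diff); the mrope_section values are illustrative.

rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}

if rope_scaling is not None and "type" in rope_scaling:
    if rope_scaling["type"] == "mrope":
        rope_scaling["type"] = "default"
    rope_scaling["rope_type"] = rope_scaling["type"]

assert rope_scaling == {"type": "default", "mrope_section": [16, 24, 24], "rope_type": "default"}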
2 changes: 2 additions & 0 deletions python/sglang/srt/hf_transformers_utils.py
@@ -75,6 +75,8 @@ def get_config(
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model, revision=revision)
# NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
setattr(config, "_name_or_path", model)
if model_override_args:
config.update(model_override_args)

1 change: 1 addition & 0 deletions python/sglang/srt/layers/linear.py
@@ -42,6 +42,7 @@
"Fp8LinearMethod",
"MarlinLinearMethod",
"GPTQLinearMethod",
"QQQLinearMethod",
]


2 changes: 1 addition & 1 deletion python/sglang/srt/lora/lora.py
@@ -31,7 +31,6 @@
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.loader import DefaultModelLoader

from sglang.srt.layers.linear import (
ColumnParallelLinear,
@@ -40,6 +39,7 @@
RowParallelLinear,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_loader.loader import DefaultModelLoader


class BaseLayerWithLoRA(nn.Module):