[FEAT] Support GGUF format #2215

Merged · 14 commits · Nov 30, 2024
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -15,3 +15,4 @@ sphinx-copybutton
 sphinx-tabs
 sphinxcontrib-mermaid
 urllib3<2.0.0
+gguf>=0.10.0
36 changes: 35 additions & 1 deletion python/sglang/srt/hf_transformers_utils.py
@@ -16,6 +16,7 @@
 import contextlib
 import os
 import warnings
+from pathlib import Path
 from typing import Dict, Optional, Type, Union
 
 from huggingface_hub import snapshot_download
@@ -27,6 +28,7 @@
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
 )
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 try:
     from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
@@ -60,15 +62,29 @@ def get_config(
     trust_remote_code: bool,
     revision: Optional[str] = None,
     model_override_args: Optional[dict] = None,
+    **kwargs,
 ):
+    is_gguf = check_gguf_file(model)
+    if is_gguf:
+        kwargs["gguf_file"] = model
+        model = Path(model).parent
+
     config = AutoConfig.from_pretrained(
-        model, trust_remote_code=trust_remote_code, revision=revision
+        model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
     )
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = config_class.from_pretrained(model, revision=revision)
     if model_override_args:
         config.update(model_override_args)
+
+    # Special architecture mapping check for GGUF models
+    if is_gguf:
+        if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+            raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
+        model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
+        config.update({"architectures": [model_type]})
 
     return config


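Reviewer note: the effect of the `get_config` change is that a bare `.gguf` path is split into its parent directory (used as the model path) plus a `gguf_file` kwarg that `transformers` resolves against it; `get_tokenizer` below applies the same rewrite. A minimal sketch of the equivalent call, assuming a hypothetical local file `./llama-1b-q4_k_m.gguf`:

```python
from pathlib import Path
from transformers import AutoConfig

ckpt = "./llama-1b-q4_k_m.gguf"  # hypothetical local GGUF checkpoint
# Mirror get_config(): the parent directory becomes the model path and the
# GGUF file is forwarded so the config is read from the file's metadata.
config = AutoConfig.from_pretrained(Path(ckpt).parent, gguf_file=ckpt)
print(config.model_type)  # e.g. "llama"
```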
@@ -123,6 +139,11 @@ def get_tokenizer(
            raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
        kwargs["use_fast"] = False
 
+    is_gguf = check_gguf_file(tokenizer_name)
+    if is_gguf:
+        kwargs["gguf_file"] = tokenizer_name
+        tokenizer_name = Path(tokenizer_name).parent
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
@@ -195,3 +216,16 @@ def attach_additional_stop_token_ids(tokenizer):
         )
     else:
         tokenizer.additional_stop_token_ids = None
+
+
+def check_gguf_file(model: Union[str, os.PathLike]) -> bool:
+    """Check if the file is a GGUF model."""
+    model = Path(model)
+    if not model.is_file():
+        return False
+    elif model.suffix == ".gguf":
+        return True
+
+    with open(model, "rb") as f:
+        header = f.read(4)
+    return header == b"GGUF"
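`check_gguf_file` short-circuits on the `.gguf` suffix and otherwise sniffs the 4-byte `GGUF` magic, so renamed files are still detected. A quick illustration (paths hypothetical, files assumed to exist):

```python
check_gguf_file("models/qwen2-0.5b-q8_0.gguf")  # True: suffix match, no read needed
check_gguf_file("models")                       # False: directories are rejected
check_gguf_file("models/renamed-model.bin")     # True iff the file starts with b"GGUF"
```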
20 changes: 17 additions & 3 deletions python/sglang/srt/layers/logits_processor.py
@@ -23,6 +23,7 @@
     tensor_model_parallel_all_gather,
 )
 
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode


@@ -163,7 +164,7 @@ def forward(
         self,
         input_ids,
         hidden_states,
-        weight,
+        lm_head: VocabParallelEmbedding,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
     ):
         if isinstance(logits_metadata, ForwardBatch):
@@ -178,7 +179,7 @@ def forward(
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
             last_hidden = hidden_states[last_index]
 
-        last_logits = torch.matmul(last_hidden, weight.T)
+        last_logits = self._get_logits(last_hidden, lm_head)
         if self.do_tensor_parallel_all_gather:
             last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()
@@ -229,7 +230,7 @@ def forward(
 
         # Compute the logits and logprobs for all required tokens
         states = torch.cat(states, dim=0)
-        all_logits = torch.matmul(states, weight.T)
+        all_logits = self._get_logits(states, lm_head)
         if self.do_tensor_parallel_all_gather:
             all_logits = tensor_model_parallel_all_gather(all_logits)
         all_logits = all_logits[:, : self.config.vocab_size].float()
@@ -276,6 +277,19 @@ def forward(
             output_top_logprobs=output_top_logprobs,
         )
 
+    def _get_logits(
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: VocabParallelEmbedding,
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if hasattr(lm_head, "weight"):
+            logits = torch.matmul(hidden_states, lm_head.weight.T)
+        else:
+            # GGUF models
+            logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
+        return logits
+
 
 def test():
     all_logprobs = torch.tensor(
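`_get_logits` duck-types on the `weight` attribute: a GGUF-quantized head keeps its parameters packed, so there is no dense `weight` tensor to `matmul` against, and the quantization method has to produce the logits itself. A self-contained sketch of the dispatch, with a toy stand-in for the unquantized module (not the real classes):

```python
import torch

class DenseHead:
    """Stand-in for an unquantized ParallelLMHead: exposes a dense .weight."""
    def __init__(self, vocab_size: int, hidden_size: int):
        self.weight = torch.randn(vocab_size, hidden_size)

def get_logits(hidden_states: torch.Tensor, lm_head) -> torch.Tensor:
    if hasattr(lm_head, "weight"):
        # Ordinary path: one GEMM against the (possibly tied) embedding matrix.
        return torch.matmul(hidden_states, lm_head.weight.T)
    # GGUF path: no dense weight, so delegate to the quant method's kernel,
    # mirroring lm_head.linear_method.apply(...) in the diff above.
    return lm_head.linear_method.apply(lm_head, hidden_states, None)

hidden = torch.randn(2, 16)
print(get_logits(hidden, DenseHead(vocab_size=100, hidden_size=16)).shape)  # torch.Size([2, 100])
```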
1 change: 1 addition & 0 deletions python/sglang/srt/layers/vocab_parallel_embedding.py
@@ -222,6 +222,7 @@ def __init__(
         enable_tp: bool = True,
     ):
         super().__init__()
+        self.quant_config = quant_config
 
         self.enable_tp = enable_tp
         if self.enable_tp:
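Reviewer note: keeping a reference to `quant_config` on the module appears intended to let GGUF-aware code (weight loading and the logits path above) query how the embedding/LM head was quantized; unquantized setups pass `None` and behave as before.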
3 changes: 3 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -59,6 +59,7 @@
     enable_show_time_cost,
     get_available_gpu_memory,
     is_hip,
+    monkey_patch_vllm_gguf_config,
     monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
@@ -297,6 +298,8 @@ def load_model(self):
             download_dir=self.server_args.download_dir,
         )
         monkey_patch_vllm_model_config()
+        if self.server_args.load_format == "gguf":
+            monkey_patch_vllm_gguf_config()
         self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
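With the patch applied, serving a GGUF checkpoint should only require pointing the server at the file, e.g. `python -m sglang.launch_server --model-path ./llama-1b-q4_k_m.gguf --load-format gguf` (path hypothetical; the explicit `--load-format gguf` matches the `load_format` check above).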
11 changes: 6 additions & 5 deletions python/sglang/srt/models/baichuan.py
@@ -338,11 +338,12 @@ def __init__(
 
         self.quant_config = quant_config
         self.model = BaiChuanModel(config, position_embedding, quant_config)
-        self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, quant_config=quant_config
-        )
         if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
         self.logits_processor = LogitsProcessor(config)
 
     def forward(
@@ -353,7 +354,7 @@ def forward(
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/chatglm.py
@@ -378,7 +378,7 @@ def forward(
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/commandr.py
@@ -339,7 +339,7 @@ def forward(
             forward_batch,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
        )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/dbrx.py
@@ -390,7 +390,7 @@ def forward(
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/deepseek.py
@@ -394,7 +394,7 @@ def forward(
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/deepseek_v2.py
@@ -763,7 +763,7 @@ def forward(
         hidden_states = self.model(input_ids, positions, forward_batch)
         if not forward_batch.forward_mode.is_idle():
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/exaone.py
@@ -314,7 +314,7 @@ def forward(
             input_ids, positions, forward_batch, input_embeds
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/gemma.py
@@ -298,7 +298,7 @@ def forward(
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/gemma2.py
@@ -363,7 +363,7 @@ def forward(
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )
 
     def get_attention_sliding_window_size(self):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/gpt2.py
@@ -247,7 +247,7 @@ def forward(
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/gpt_bigcode.py
@@ -271,7 +271,7 @@ def forward(
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/grok.py
@@ -304,7 +304,7 @@ def forward(
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
2 changes: 1 addition & 1 deletion python/sglang/srt/models/internlm2.py
@@ -270,7 +270,7 @@ def forward(
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.output.weight, forward_batch
+            input_ids, hidden_states, self.output, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
27 changes: 8 additions & 19 deletions python/sglang/srt/models/llama.py
@@ -258,6 +258,7 @@ def __init__(
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -305,7 +306,12 @@ def __init__(
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
         self.stacked_params_mapping = [
@@ -329,7 +335,7 @@ def forward(
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -373,7 +379,6 @@ def get_num_params(self):
         return len(params_dict)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        embed_tokens_weight = None
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
@@ -385,12 +390,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 
         params_dict = dict(self.named_parameters())
 
-        load_tie_word_embeddings = (
-            hasattr(self.config, "tie_word_embeddings")
-            and self.config.tie_word_embeddings
-            and "lm_head.weight" in params_dict
-        )
-
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
@@ -423,16 +422,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
-            if load_tie_word_embeddings and name == "model.embed_tokens.weight":
-                embed_tokens_weight = loaded_weight
-
-        if load_tie_word_embeddings:
-            # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
-            param = self.lm_head.weight
-            weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            if embed_tokens_weight is not None:
-                weight_loader(param, embed_tokens_weight)
-
         apply_torchao_config_(self, params_dict, set(["proj.weight"]))
 
     def get_weights_by_name(
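Reviewer note: since `lm_head` is now aliased to `model.embed_tokens` at construction time whenever `tie_word_embeddings` is set, the load-time fallback that copied the embedding weight into `lm_head.weight` became dead code, which is why `load_weights` drops the `load_tie_word_embeddings` bookkeeping above. This matters for GGUF checkpoints in particular, which often ship no separate output-embedding tensor.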
8 changes: 3 additions & 5 deletions python/sglang/srt/models/minicpm.py
@@ -308,12 +308,10 @@ def forward(
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
         else:
-            lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
8 changes: 3 additions & 5 deletions python/sglang/srt/models/minicpm3.py
@@ -585,12 +585,10 @@ def forward(
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
         else:
-            lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [