From 33fc64eeb8970c4597bc1eb1abb89829107891c6 Mon Sep 17 00:00:00 2001 From: "Yang Zheng(SW)(Alex)" Date: Tue, 26 Nov 2024 03:49:46 +0000 Subject: [PATCH 01/11] Support GGUF format --- docs/requirements.txt | 1 + python/sglang/srt/hf_transformers_utils.py | 25 +++++++++- .../srt/layers/vocab_parallel_embedding.py | 8 ++- .../sglang/srt/model_executor/model_runner.py | 4 +- python/sglang/srt/models/llama.py | 3 +- python/sglang/srt/server_args.py | 17 ++++++- python/sglang/srt/utils.py | 1 + test/srt/test_gguf.py | 49 +++++++++++++++++++ 8 files changed, 99 insertions(+), 9 deletions(-) create mode 100644 test/srt/test_gguf.py diff --git a/docs/requirements.txt b/docs/requirements.txt index 171d60e0a9..948d5427ce 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -15,3 +15,4 @@ sphinx-copybutton sphinx-tabs sphinxcontrib-mermaid urllib3<2.0.0 +gguf>=0.10.0 diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index a5b566fb3d..e3c8d16c39 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -16,6 +16,7 @@ import contextlib import os import warnings +from pathlib import Path from typing import Dict, Optional, Type, Union from huggingface_hub import snapshot_download @@ -27,6 +28,9 @@ PreTrainedTokenizer, PreTrainedTokenizerFast, ) +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + +from sglang.srt.utils import get_gguf_file_if_exist try: from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig @@ -60,15 +64,29 @@ def get_config( trust_remote_code: bool, revision: Optional[str] = None, model_override_args: Optional[dict] = None, + **kwargs, ): + gguf_file = get_gguf_file_if_exist(model) + if gguf_file is not None: + kwargs["gguf_file"] = gguf_file + config = AutoConfig.from_pretrained( - model, trust_remote_code=trust_remote_code, revision=revision + model, trust_remote_code=trust_remote_code, revision=revision, **kwargs ) if config.model_type in _CONFIG_REGISTRY: config_class = _CONFIG_REGISTRY[config.model_type] config = config_class.from_pretrained(model, revision=revision) if model_override_args: config.update(model_override_args) + + # Special architecture mapping check for GGUF models + if gguf_file is not None: + if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: + raise RuntimeError( + f"Can't get gguf config for {config.model_type}.") + model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] + config.update({"architectures": [model_type]}) + return config @@ -108,7 +126,6 @@ def get_context_length(config): # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. 
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" - def get_tokenizer( tokenizer_name: str, *args, @@ -123,6 +140,10 @@ def get_tokenizer( raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False + gguf_file = get_gguf_file_if_exist(tokenizer_name) + if gguf_file is not None: + kwargs["gguf_file"] = gguf_file + try: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index a2d15fc781..f20eeadc72 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -12,6 +12,9 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding as VllmVocabParallelEmbedding, +) from vllm.model_executor.parameter import BasevLLMParameter from sglang.srt.layers.quantization.base_config import ( @@ -171,7 +174,7 @@ def get_masked_input_and_mask( return input_, ~vocab_mask -class VocabParallelEmbedding(torch.nn.Module): +class VocabParallelEmbedding(VllmVocabParallelEmbedding): """Embedding parallelized in the vocabulary dimension. Adapted from torch.nn.Embedding, note that we pad the vocabulary size to @@ -221,7 +224,8 @@ def __init__( prefix: str = "", enable_tp: bool = True, ): - super().__init__() + super().__init__(num_embeddings, embedding_dim, params_dtype, + org_num_embeddings, padding_size, quant_config, prefix) self.enable_tp = enable_tp if self.enable_tp: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index da311c7ec2..e143007f0e 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -262,8 +262,9 @@ def setup_model(self): def get_model_config_params(self): sig = inspect.signature(VllmModelConfig.__init__) + model = self.server_args.gguf_file if self.server_args.load_format == "gguf" else self.server_args.model_path params = { - "model": self.server_args.model_path, + "model": model, "quantization": self.server_args.quantization, "tokenizer": None, "tokenizer_mode": None, @@ -293,7 +294,6 @@ def load_model(self): self.server_args.dtype = "float16" if torch.cuda.get_device_capability()[1] < 5: raise RuntimeError("SGLang only supports sm75 and above.") - # Prepare the vllm model config self.load_config = LoadConfig( load_format=self.server_args.load_format, diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 7e9fd0f726..8d898d9489 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -255,6 +255,7 @@ def __init__( self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + quant_config=quant_config, ) self.layers = make_layers( config.num_hidden_layers, @@ -302,7 +303,7 @@ def __init__( self.quant_config = quant_config self.torchao_config = global_server_args_dict["torchao_config"] self.model = LlamaModel(config, quant_config=quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 
be470dac35..f293bf5dac 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -22,6 +22,7 @@ from sglang.srt.utils import ( get_amdgpu_memory_capacity, + get_gguf_file_if_exist, get_nvgpu_memory_capacity, is_flashinfer_available, is_hip, @@ -136,6 +137,9 @@ class ServerArgs: num_continuous_decode_steps: int = 1 delete_ckpt_after_loading: bool = False + # GGUF + gguf_file: Optional[str] = None + def __post_init__(self): # Set missing default values if self.tokenizer_path is None: @@ -204,6 +208,13 @@ def __post_init__(self): "Overlap schedule is disabled." ) + # GGUF + if self.load_format == "auto" or self.load_format == "gguf": + self.gguf_file = get_gguf_file_if_exist(self.model_path) + if self.gguf_file is not None: + self.quantization = self.load_format = "gguf" + + @staticmethod def add_cli_args(parser: argparse.ArgumentParser): # Model and port args @@ -243,7 +254,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--load-format", type=str, default=ServerArgs.load_format, - choices=["auto", "pt", "safetensors", "npcache", "dummy"], + choices=["auto", "pt", "safetensors", "npcache", "dummy", "gguf"], help="The format of the model weights to load. " '"auto" will try to load the weights in the safetensors format ' "and fall back to the pytorch bin format if safetensors format " @@ -253,7 +264,8 @@ def add_cli_args(parser: argparse.ArgumentParser): '"npcache" will load the weights in pytorch format and store ' "a numpy cache to speed up the loading. " '"dummy" will initialize the weights with random values, ' - "which is mainly for profiling.", + "which is mainly for profiling." + '"gguf" will load the weights in the gguf format. ', ) parser.add_argument( "--trust-remote-code", @@ -293,6 +305,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "gptq_marlin", "awq_marlin", "bitsandbytes", + "gguf", ], help="The quantization method.", ) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 46b4db8e88..035c2e9962 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -32,6 +32,7 @@ import warnings from importlib.metadata import PackageNotFoundError, version from io import BytesIO +from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union import numpy as np diff --git a/test/srt/test_gguf.py b/test/srt/test_gguf.py new file mode 100644 index 0000000000..3591968b42 --- /dev/null +++ b/test/srt/test_gguf.py @@ -0,0 +1,49 @@ +import unittest +from pathlib import Path + +from huggingface_hub import hf_hub_download + +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.server_args import ServerArgs + + +class TestGGUF(unittest.TestCase): + def test_load_model(self): + model_path = str(Path(hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", + filename="tinyllama-1.1b-chat-v1.0.Q2_K.gguf")).parent) + + server_args = ServerArgs(model_path=model_path, random_seed=42, disable_radix_cache=True, load_format="auto") + self.assertEqual(server_args.model_path, model_path) + self.assertEqual(server_args.load_format, "gguf") + self.assertIsNotNone(server_args.gguf_file) + + model_config = ModelConfig( + server_args.model_path, + trust_remote_code=server_args.trust_remote_code, + context_length=server_args.context_length, + model_override_args=server_args.json_model_override_args, + ) + model_runner = ModelRunner( + 
model_config=model_config, + mem_fraction_static=server_args.mem_fraction_static, + gpu_id=0, + tp_rank=0, + tp_size=server_args.tp_size, + nccl_port=8000, + server_args=server_args, + ) + self.assertEqual(model_runner.vllm_model_config.model, server_args.gguf_file) + self.assertEqual(model_runner.vllm_model_config.quantization, "gguf") + + tokenizer = get_tokenizer( + server_args.tokenizer_path, + tokenizer_mode=server_args.tokenizer_mode, + trust_remote_code=server_args.trust_remote_code, + ) + self.assertIsNotNone(tokenizer.vocab_file) + self.assertEqual(Path(tokenizer.vocab_file).suffix, ".gguf") + +if __name__ == "__main__": + unittest.main() From ec2dd82a46ae48c0639a0413d77eee0bc7e54a5c Mon Sep 17 00:00:00 2001 From: "Yang Zheng(SW)(Alex)" Date: Wed, 27 Nov 2024 07:25:16 +0000 Subject: [PATCH 02/11] Update test --- test/srt/run_suite.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index c04a1671ed..d23d52516c 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -38,6 +38,7 @@ "test_update_weights.py", "test_vision_openai_server.py", "test_session_control.py", + "test_gguf.py", ], "sampling/penaltylib": glob.glob( "sampling/penaltylib/**/test_*.py", recursive=True From 28328bf035d7df8a28f8c8ab1263abda86eff84f Mon Sep 17 00:00:00 2001 From: zhengy001 Date: Wed, 27 Nov 2024 08:07:22 +0000 Subject: [PATCH 03/11] Format --- python/sglang/srt/hf_transformers_utils.py | 4 ++-- .../srt/layers/vocab_parallel_embedding.py | 11 +++++++++-- .../sglang/srt/model_executor/model_runner.py | 6 +++++- python/sglang/srt/models/llama.py | 4 +++- python/sglang/srt/server_args.py | 1 - test/srt/test_gguf.py | 18 +++++++++++++++--- 6 files changed, 34 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index e3c8d16c39..8d4d8770f7 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -82,8 +82,7 @@ def get_config( # Special architecture mapping check for GGUF models if gguf_file is not None: if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - raise RuntimeError( - f"Can't get gguf config for {config.model_type}.") + raise RuntimeError(f"Can't get gguf config for {config.model_type}.") model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] config.update({"architectures": [model_type]}) @@ -126,6 +125,7 @@ def get_context_length(config): # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. 
_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" + def get_tokenizer( tokenizer_name: str, *args, diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index f20eeadc72..da98081bbe 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -224,8 +224,15 @@ def __init__( prefix: str = "", enable_tp: bool = True, ): - super().__init__(num_embeddings, embedding_dim, params_dtype, - org_num_embeddings, padding_size, quant_config, prefix) + super().__init__( + num_embeddings, + embedding_dim, + params_dtype, + org_num_embeddings, + padding_size, + quant_config, + prefix, + ) self.enable_tp = enable_tp if self.enable_tp: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index e143007f0e..09e1935412 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -262,7 +262,11 @@ def setup_model(self): def get_model_config_params(self): sig = inspect.signature(VllmModelConfig.__init__) - model = self.server_args.gguf_file if self.server_args.load_format == "gguf" else self.server_args.model_path + model = ( + self.server_args.gguf_file + if self.server_args.load_format == "gguf" + else self.server_args.model_path + ) params = { "model": model, "quantization": self.server_args.quantization, diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 8d898d9489..6f09f33c06 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -303,7 +303,9 @@ def __init__( self.quant_config = quant_config self.torchao_config = global_server_args_dict["torchao_config"] self.model = LlamaModel(config, quant_config=quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f293bf5dac..27cffa9bdd 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -214,7 +214,6 @@ def __post_init__(self): if self.gguf_file is not None: self.quantization = self.load_format = "gguf" - @staticmethod def add_cli_args(parser: argparse.ArgumentParser): # Model and port args diff --git a/test/srt/test_gguf.py b/test/srt/test_gguf.py index 3591968b42..d917a775bf 100644 --- a/test/srt/test_gguf.py +++ b/test/srt/test_gguf.py @@ -11,10 +11,21 @@ class TestGGUF(unittest.TestCase): def test_load_model(self): - model_path = str(Path(hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", - filename="tinyllama-1.1b-chat-v1.0.Q2_K.gguf")).parent) + model_path = str( + Path( + hf_hub_download( + "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", + filename="tinyllama-1.1b-chat-v1.0.Q2_K.gguf", + ) + ).parent + ) - server_args = ServerArgs(model_path=model_path, random_seed=42, disable_radix_cache=True, load_format="auto") + server_args = ServerArgs( + model_path=model_path, + random_seed=42, + disable_radix_cache=True, + load_format="auto", + ) self.assertEqual(server_args.model_path, model_path) self.assertEqual(server_args.load_format, "gguf") self.assertIsNotNone(server_args.gguf_file) @@ -45,5 +56,6 @@ def test_load_model(self): 
self.assertIsNotNone(tokenizer.vocab_file) self.assertEqual(Path(tokenizer.vocab_file).suffix, ".gguf") + if __name__ == "__main__": unittest.main() From 8f42879c44fad3eb69330abeca58ae9f5a7793af Mon Sep 17 00:00:00 2001 From: zhengy001 Date: Wed, 27 Nov 2024 13:03:41 +0000 Subject: [PATCH 04/11] ServerArgs: remove gguf_file --- python/sglang/srt/hf_transformers_utils.py | 19 ++++++++++--------- .../sglang/srt/model_executor/model_runner.py | 8 ++------ python/sglang/srt/server_args.py | 14 ++++++-------- python/sglang/srt/utils.py | 1 - test/srt/test_gguf.py | 13 +++---------- 5 files changed, 21 insertions(+), 34 deletions(-) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 8d4d8770f7..4b5ffbb502 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -29,8 +29,7 @@ PreTrainedTokenizerFast, ) from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES - -from sglang.srt.utils import get_gguf_file_if_exist +from vllm.transformers_utils.utils import check_gguf_file try: from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig @@ -66,9 +65,10 @@ def get_config( model_override_args: Optional[dict] = None, **kwargs, ): - gguf_file = get_gguf_file_if_exist(model) - if gguf_file is not None: - kwargs["gguf_file"] = gguf_file + is_gguf = check_gguf_file(model) + if is_gguf: + kwargs["gguf_file"] = model + model = Path(model).parent config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code, revision=revision, **kwargs @@ -80,7 +80,7 @@ def get_config( config.update(model_override_args) # Special architecture mapping check for GGUF models - if gguf_file is not None: + if is_gguf: if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: raise RuntimeError(f"Can't get gguf config for {config.model_type}.") model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] @@ -140,9 +140,10 @@ def get_tokenizer( raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False - gguf_file = get_gguf_file_if_exist(tokenizer_name) - if gguf_file is not None: - kwargs["gguf_file"] = gguf_file + is_gguf = check_gguf_file(tokenizer_name) + if is_gguf: + kwargs["gguf_file"] = tokenizer_name + tokenizer_name = Path(tokenizer_name).parent try: tokenizer = AutoTokenizer.from_pretrained( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 09e1935412..da311c7ec2 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -262,13 +262,8 @@ def setup_model(self): def get_model_config_params(self): sig = inspect.signature(VllmModelConfig.__init__) - model = ( - self.server_args.gguf_file - if self.server_args.load_format == "gguf" - else self.server_args.model_path - ) params = { - "model": model, + "model": self.server_args.model_path, "quantization": self.server_args.quantization, "tokenizer": None, "tokenizer_mode": None, @@ -298,6 +293,7 @@ def load_model(self): self.server_args.dtype = "float16" if torch.cuda.get_device_capability()[1] < 5: raise RuntimeError("SGLang only supports sm75 and above.") + # Prepare the vllm model config self.load_config = LoadConfig( load_format=self.server_args.load_format, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 27cffa9bdd..d7b2f1ba6e 100644 --- a/python/sglang/srt/server_args.py +++ 
b/python/sglang/srt/server_args.py @@ -20,9 +20,10 @@ import tempfile from typing import List, Optional +from vllm.transformers_utils.utils import check_gguf_file + from sglang.srt.utils import ( get_amdgpu_memory_capacity, - get_gguf_file_if_exist, get_nvgpu_memory_capacity, is_flashinfer_available, is_hip, @@ -137,9 +138,6 @@ class ServerArgs: num_continuous_decode_steps: int = 1 delete_ckpt_after_loading: bool = False - # GGUF - gguf_file: Optional[str] = None - def __post_init__(self): # Set missing default values if self.tokenizer_path is None: @@ -209,10 +207,10 @@ def __post_init__(self): ) # GGUF - if self.load_format == "auto" or self.load_format == "gguf": - self.gguf_file = get_gguf_file_if_exist(self.model_path) - if self.gguf_file is not None: - self.quantization = self.load_format = "gguf" + if ( + self.load_format == "auto" or self.load_format == "gguf" + ) and check_gguf_file(self.model_path): + self.quantization = self.load_format = "gguf" @staticmethod def add_cli_args(parser: argparse.ArgumentParser): diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 035c2e9962..46b4db8e88 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -32,7 +32,6 @@ import warnings from importlib.metadata import PackageNotFoundError, version from io import BytesIO -from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union import numpy as np diff --git a/test/srt/test_gguf.py b/test/srt/test_gguf.py index d917a775bf..9096182039 100644 --- a/test/srt/test_gguf.py +++ b/test/srt/test_gguf.py @@ -11,13 +11,9 @@ class TestGGUF(unittest.TestCase): def test_load_model(self): - model_path = str( - Path( - hf_hub_download( - "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", - filename="tinyllama-1.1b-chat-v1.0.Q2_K.gguf", - ) - ).parent + model_path = hf_hub_download( + "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", + filename="tinyllama-1.1b-chat-v1.0.Q2_K.gguf", ) server_args = ServerArgs( @@ -26,9 +22,7 @@ def test_load_model(self): disable_radix_cache=True, load_format="auto", ) - self.assertEqual(server_args.model_path, model_path) self.assertEqual(server_args.load_format, "gguf") - self.assertIsNotNone(server_args.gguf_file) model_config = ModelConfig( server_args.model_path, @@ -45,7 +39,6 @@ def test_load_model(self): nccl_port=8000, server_args=server_args, ) - self.assertEqual(model_runner.vllm_model_config.model, server_args.gguf_file) self.assertEqual(model_runner.vllm_model_config.quantization, "gguf") tokenizer = get_tokenizer( From 8bf1890248cee4de16219504800c52ca6c949b5e Mon Sep 17 00:00:00 2001 From: zhengy001 Date: Wed, 27 Nov 2024 14:08:41 +0000 Subject: [PATCH 05/11] Pass lm_head to LogitsProcessor --- python/sglang/srt/layers/logits_processor.py | 4 +++- python/sglang/srt/models/baichuan.py | 2 +- python/sglang/srt/models/chatglm.py | 2 +- python/sglang/srt/models/dbrx.py | 2 +- python/sglang/srt/models/deepseek.py | 2 +- python/sglang/srt/models/deepseek_v2.py | 2 +- python/sglang/srt/models/exaone.py | 2 +- python/sglang/srt/models/gpt2.py | 2 +- python/sglang/srt/models/gpt_bigcode.py | 2 +- python/sglang/srt/models/grok.py | 2 +- python/sglang/srt/models/llama.py | 2 +- python/sglang/srt/models/mixtral.py | 2 +- python/sglang/srt/models/mixtral_quant.py | 2 +- python/sglang/srt/models/olmo.py | 2 +- python/sglang/srt/models/olmoe.py | 2 +- python/sglang/srt/models/phi3_small.py | 2 +- python/sglang/srt/models/qwen.py | 2 +- python/sglang/srt/models/qwen2.py | 2 +- python/sglang/srt/models/qwen2_moe.py 
| 2 +- python/sglang/srt/models/qwen2_vl.py | 2 +- python/sglang/srt/models/stablelm.py | 2 +- python/sglang/srt/models/torch_native_llama.py | 2 +- python/sglang/srt/models/xverse.py | 2 +- python/sglang/srt/models/xverse_moe.py | 2 +- 24 files changed, 26 insertions(+), 24 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index eedd7fe01d..4cfde0bbc0 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -23,6 +23,7 @@ tensor_model_parallel_all_gather, ) +from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode @@ -163,7 +164,7 @@ def forward( self, input_ids, hidden_states, - weight, + lm_head: VocabParallelEmbedding, logits_metadata: Union[LogitsMetadata, ForwardBatch], ): if isinstance(logits_metadata, ForwardBatch): @@ -178,6 +179,7 @@ def forward( last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 last_hidden = hidden_states[last_index] + weight = lm_head.weight if hasattr(lm_head, "weight") else lm_head.qweight last_logits = torch.matmul(last_hidden, weight.T) if self.do_tensor_parallel_all_gather: last_logits = tensor_model_parallel_all_gather(last_logits) diff --git a/python/sglang/srt/models/baichuan.py b/python/sglang/srt/models/baichuan.py index 0e5e3b9ade..b5b45f2b24 100644 --- a/python/sglang/srt/models/baichuan.py +++ b/python/sglang/srt/models/baichuan.py @@ -353,7 +353,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 05ce17a6b1..ced6859c7a 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -378,7 +378,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index b8dad0248a..e9b4ff1417 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -390,7 +390,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index cdebafa2ff..43dfc50a47 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -394,7 +394,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git 
a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 85467c12c9..55a458c205 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -763,7 +763,7 @@ def forward( hidden_states = self.model(input_ids, positions, forward_batch) if not forward_batch.forward_mode.is_idle(): return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py index c097e00ad2..8c244419fe 100644 --- a/python/sglang/srt/models/exaone.py +++ b/python/sglang/srt/models/exaone.py @@ -314,7 +314,7 @@ def forward( input_ids, positions, forward_batch, input_embeds ) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/gpt2.py b/python/sglang/srt/models/gpt2.py index 8d988fe8ea..6fbfe9edd7 100644 --- a/python/sglang/srt/models/gpt2.py +++ b/python/sglang/srt/models/gpt2.py @@ -247,7 +247,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py index 03597fa734..5af1273202 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -271,7 +271,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 1e49eb59aa..d5c303d139 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -304,7 +304,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 6f09f33c06..d610f975c4 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -321,7 +321,7 @@ def forward( hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) if not get_embedding: return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) else: return self.pooler(hidden_states, forward_batch) diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 98d5ab332a..b2e895f56d 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -310,7 +310,7 @@ def forward( ) -> torch.Tensor: hidden_states = 
self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index d15a389a84..8dba2b722a 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -343,7 +343,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/olmo.py b/python/sglang/srt/models/olmo.py index 80fd64a53a..ead15f5a0f 100644 --- a/python/sglang/srt/models/olmo.py +++ b/python/sglang/srt/models/olmo.py @@ -306,7 +306,7 @@ def forward( input_embeds=input_embeds, ) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py index 407eb98cb3..549e2d032e 100644 --- a/python/sglang/srt/models/olmoe.py +++ b/python/sglang/srt/models/olmoe.py @@ -321,7 +321,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py index 6561153212..b99e0bad40 100644 --- a/python/sglang/srt/models/phi3_small.py +++ b/python/sglang/srt/models/phi3_small.py @@ -422,7 +422,7 @@ def forward( if not get_embedding: return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) else: diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index 4c18290265..fb4b67ff50 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -260,7 +260,7 @@ def forward( ): hidden_states = self.transformer(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 634ce1cf16..def9370604 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -292,7 +292,7 @@ def forward( hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) if not get_embedding: return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) else: return self.pooler(hidden_states, forward_batch) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index febd6d7484..2569932693 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ 
b/python/sglang/srt/models/qwen2_moe.py @@ -376,7 +376,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index dc58383eee..ab4eb69365 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -668,7 +668,7 @@ def forward( if not get_embedding: return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) else: return self.pooler(hidden_states, forward_batch) diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 9fa2ab3433..38f2be13a4 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -261,7 +261,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index b9451d5915..87146b33a1 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -413,7 +413,7 @@ def forward( ) -> LogitsProcessorOutput: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def get_hidden_dim(self, module_name): diff --git a/python/sglang/srt/models/xverse.py b/python/sglang/srt/models/xverse.py index fb7e14a0ef..42f51a7fac 100644 --- a/python/sglang/srt/models/xverse.py +++ b/python/sglang/srt/models/xverse.py @@ -315,7 +315,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights( diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py index c6458f7f50..3a8b9a9e43 100644 --- a/python/sglang/srt/models/xverse_moe.py +++ b/python/sglang/srt/models/xverse_moe.py @@ -390,7 +390,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch) return self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, forward_batch + input_ids, hidden_states, self.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): From 02a56b03b1b0e64e55fef977f1caf47b6179068b Mon Sep 17 00:00:00 2001 From: zhengy001 Date: Thu, 28 Nov 2024 02:51:35 +0000 Subject: [PATCH 06/11] Fix CI --- python/sglang/srt/models/commandr.py | 2 +- python/sglang/srt/models/gemma.py | 2 +- python/sglang/srt/models/gemma2.py | 2 +- python/sglang/srt/models/internlm2.py | 2 +- python/sglang/srt/models/minicpm.py | 8 +++----- python/sglang/srt/models/minicpm3.py | 8 +++----- python/sglang/srt/models/mllama.py | 2 +- 
python/sglang/srt/models/phi3_small.py | 5 ++++- 8 files changed, 15 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index d4018be88a..8769d49db0 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -339,7 +339,7 @@ def forward( forward_batch, ) return self.logits_processor( - input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch + input_ids, hidden_states, self.model.embed_tokens, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index a53fad9580..f6d3015468 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -298,7 +298,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch + input_ids, hidden_states, self.model.embed_tokens, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 0fa6a53935..104205648b 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -363,7 +363,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch + input_ids, hidden_states, self.model.embed_tokens, forward_batch ) def get_attention_sliding_window_size(self): diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index 59ff6d1e2d..d217fd71ff 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -270,7 +270,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) return self.logits_processor( - input_ids, hidden_states, self.output.weight, forward_batch + input_ids, hidden_states, self.output, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index 239cfb6fcc..0d668fe5d1 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -308,12 +308,10 @@ def forward( hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) hidden_states = hidden_states / self.scale_width if self.config.tie_word_embeddings: - lm_head_weight = self.model.embed_tokens.weight + lm_head = self.model.embed_tokens else: - lm_head_weight = self.lm_head.weight - return self.logits_processor( - input_ids, hidden_states, lm_head_weight, forward_batch - ) + lm_head = self.lm_head + return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index 6f53f2974f..e6bf118ed2 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -585,12 +585,10 @@ def forward( hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) hidden_states = hidden_states / self.scale_width if self.config.tie_word_embeddings: - lm_head_weight = 
self.model.embed_tokens.weight + lm_head = self.model.embed_tokens else: - lm_head_weight = self.lm_head.weight - return self.logits_processor( - input_ids, hidden_states, lm_head_weight, forward_batch - ) + lm_head = self.lm_head + return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index 63bbfdb7eb..2a0cf4ea31 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -966,7 +966,7 @@ def forward( skip_cross_attention=skip_cross_attention, ) return self.logits_processor( - input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch + input_ids, hidden_states, self.language_model.lm_head, forward_batch ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py index b99e0bad40..e310dfcea0 100644 --- a/python/sglang/srt/models/phi3_small.py +++ b/python/sglang/srt/models/phi3_small.py @@ -397,10 +397,13 @@ def get_decoder(self): def compute_logits( self, + input_ids: torch.LongTensor, hidden_states: torch.Tensor, sampling_metadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) + logits = self.logits_processor( + input_ids, self.lm_head, hidden_states, sampling_metadata + ) if self.dummy_token_indices is not None and logits is not None: logits.index_fill_(-1, self.dummy_token_indices, -torch.inf) return logits From f1ffb303a6640bee284319f2d9b3908638e18f04 Mon Sep 17 00:00:00 2001 From: zhengy001 Date: Sat, 30 Nov 2024 04:07:23 +0000 Subject: [PATCH 07/11] Fix CI and cleanup tie_word_embeddings --- python/sglang/srt/hf_transformers_utils.py | 14 +++++- python/sglang/srt/layers/logits_processor.py | 17 +++++-- .../srt/layers/vocab_parallel_embedding.py | 16 ++---- .../sglang/srt/model_executor/model_runner.py | 3 ++ python/sglang/srt/models/baichuan.py | 9 ++-- python/sglang/srt/models/llama.py | 24 +++------ python/sglang/srt/models/olmo.py | 5 -- python/sglang/srt/models/qwen2.py | 14 +++--- python/sglang/srt/models/qwen2_vl.py | 2 - .../sglang/srt/models/torch_native_llama.py | 13 ++--- python/sglang/srt/server_args.py | 3 +- python/sglang/srt/utils.py | 23 +++++++++ test/srt/test_gguf.py | 50 ++++--------------- 13 files changed, 91 insertions(+), 102 deletions(-) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 4b5ffbb502..ac475cf34c 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -29,7 +29,6 @@ PreTrainedTokenizerFast, ) from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES -from vllm.transformers_utils.utils import check_gguf_file try: from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig @@ -217,3 +216,16 @@ def attach_additional_stop_token_ids(tokenizer): ) else: tokenizer.additional_stop_token_ids = None + + +def check_gguf_file(model: Union[str, os.PathLike]) -> bool: + """Check if the file is a GGUF model.""" + model = Path(model) + if not model.is_file(): + return False + elif model.suffix == ".gguf": + return True + + with open(model, "rb") as f: + header = f.read(4) + return header == b"GGUF" diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py 
index 4cfde0bbc0..7a3665c8f8 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -179,8 +179,7 @@ def forward( last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 last_hidden = hidden_states[last_index] - weight = lm_head.weight if hasattr(lm_head, "weight") else lm_head.qweight - last_logits = torch.matmul(last_hidden, weight.T) + last_logits = self._get_logits(last_hidden, lm_head) if self.do_tensor_parallel_all_gather: last_logits = tensor_model_parallel_all_gather(last_logits) last_logits = last_logits[:, : self.config.vocab_size].float() @@ -231,7 +230,7 @@ def forward( # Compute the logits and logprobs for all required tokens states = torch.cat(states, dim=0) - all_logits = torch.matmul(states, weight.T) + all_logits = self._get_logits(states, lm_head) if self.do_tensor_parallel_all_gather: all_logits = tensor_model_parallel_all_gather(all_logits) all_logits = all_logits[:, : self.config.vocab_size].float() @@ -278,6 +277,18 @@ def forward( output_top_logprobs=output_top_logprobs, ) + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if lm_head.quant_config and lm_head.quant_config.get_name() == "gguf": + logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias) + else: + logits = torch.matmul(hidden_states, lm_head.weight.T) + return logits + def test(): all_logprobs = torch.tensor( diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index da98081bbe..effea1c6c9 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -12,9 +12,6 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding as VllmVocabParallelEmbedding, -) from vllm.model_executor.parameter import BasevLLMParameter from sglang.srt.layers.quantization.base_config import ( @@ -174,7 +171,7 @@ def get_masked_input_and_mask( return input_, ~vocab_mask -class VocabParallelEmbedding(VllmVocabParallelEmbedding): +class VocabParallelEmbedding(torch.nn.Module): """Embedding parallelized in the vocabulary dimension. 
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to @@ -224,15 +221,8 @@ def __init__( prefix: str = "", enable_tp: bool = True, ): - super().__init__( - num_embeddings, - embedding_dim, - params_dtype, - org_num_embeddings, - padding_size, - quant_config, - prefix, - ) + super().__init__() + self.quant_config = quant_config self.enable_tp = enable_tp if self.enable_tp: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index da311c7ec2..04975ffe8b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -62,6 +62,7 @@ enable_show_time_cost, get_available_gpu_memory, is_hip, + monkey_patch_vllm_gguf_config, monkey_patch_vllm_model_config, monkey_patch_vllm_p2p_access_check, set_cpu_offload_max_bytes, @@ -300,6 +301,8 @@ def load_model(self): download_dir=self.server_args.download_dir, ) monkey_patch_vllm_model_config() + if self.server_args.load_format == "gguf": + monkey_patch_vllm_gguf_config() self.vllm_model_config = VllmModelConfig(**self.get_model_config_params()) if self.model_config.model_override_args is not None: self.vllm_model_config.hf_config.update( diff --git a/python/sglang/srt/models/baichuan.py b/python/sglang/srt/models/baichuan.py index b5b45f2b24..d3b0fd9ae4 100644 --- a/python/sglang/srt/models/baichuan.py +++ b/python/sglang/srt/models/baichuan.py @@ -338,11 +338,12 @@ def __init__( self.quant_config = quant_config self.model = BaiChuanModel(config, position_embedding, quant_config) - self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config - ) if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) self.logits_processor = LogitsProcessor(config) def forward( diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index d610f975c4..fa0de57f0e 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -303,9 +303,12 @@ def __init__( self.quant_config = quant_config self.torchao_config = global_server_args_dict["torchao_config"] self.model = LlamaModel(config, quant_config=quant_config) - self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config - ) + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) @@ -383,12 +386,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters()) - load_tie_word_embeddings = ( - hasattr(self.config, "tie_word_embeddings") - and self.config.tie_word_embeddings - and "lm_head.weight" in params_dict - ) - for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: continue @@ -421,15 +418,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - if load_tie_word_embeddings and name == "model.embed_tokens.weight": - embed_tokens_weight = loaded_weight - - if load_tie_word_embeddings: - # Tie output embedding layer to 
input embedding layer, to solve issues where lm_head.weight is missing - param = self.lm_head.weight - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, embed_tokens_weight) - apply_torchao_config_(self, params_dict, set(["proj.weight"])) diff --git a/python/sglang/srt/models/olmo.py b/python/sglang/srt/models/olmo.py index ead15f5a0f..2ef6532cec 100644 --- a/python/sglang/srt/models/olmo.py +++ b/python/sglang/srt/models/olmo.py @@ -326,11 +326,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue - # With tie_word_embeddings, we can skip lm_head.weight - # The weight might appear unnecessarily in the files if the model is - # processed with quantization, LoRA, fine-tuning, etc. - if self.config.tie_word_embeddings and "lm_head.weight" in name: - continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index def9370604..4c8ddd4b9e 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -230,6 +230,7 @@ def __init__( self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + quant_config=quant_config, ) self.layers = make_layers( config.num_hidden_layers, @@ -276,7 +277,12 @@ def __init__( self.config = config self.quant_config = quant_config self.model = Qwen2Model(config, quant_config=quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) @@ -306,6 +312,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: @@ -335,11 +342,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - if ( - self.config.tie_word_embeddings - and name == "model.embed_tokens.weight" - ): - weight_loader(params_dict["lm_head.weight"], loaded_weight) EntryClass = Qwen2ForCausalLM diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index ab4eb69365..155bde015c 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -686,8 +686,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - if self.config.tie_word_embeddings and "lm_head.weight" in name: - continue for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 87146b33a1..68982eebff 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -396,7 +396,10 @@ def __init__( self.torchao_config = 
global_server_args_dict["torchao_config"] self.supports_torch_tp = True self.model = LlamaModel(config, quant_config=quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) # turning off autotune for fp8dq since it doesn't give speedup and @@ -501,14 +504,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - if ( - hasattr(self.config, "tie_word_embeddings") - and self.config.tie_word_embeddings - ): - # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing - param = self.lm_head.weight - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, self.model.embed_tokens.weight) apply_torchao_config_(self, params_dict, set(["proj.weight"])) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index d7b2f1ba6e..e52350490e 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -20,8 +20,7 @@ import tempfile from typing import List, Optional -from vllm.transformers_utils.utils import check_gguf_file - +from sglang.srt.hf_transformers_utils import check_gguf_file from sglang.srt.utils import ( get_amdgpu_memory_capacity, get_nvgpu_memory_capacity, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 46b4db8e88..89044c8b2a 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -557,6 +557,29 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: setattr(GroupCoordinator, "all_gather", all_gather) +def monkey_patch_vllm_gguf_config(): + from vllm.model_executor.layers.linear import LinearBase + from vllm.model_executor.layers.quantization.gguf import ( + GGUFConfig, + GGUFEmbeddingMethod, + GGUFLinearMethod, + ) + + from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding + + def get_quant_method_with_embedding_replaced( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + return GGUFLinearMethod(self) + elif isinstance(layer, VocabParallelEmbedding): + # patch to own VocabParallelEmbedding + return GGUFEmbeddingMethod(self) + return None + + setattr(GGUFConfig, "get_quant_method", get_quant_method_with_embedding_replaced) + + def maybe_set_triton_cache_manager() -> None: """Set environment variable to tell Triton to use a custom cache manager""" diff --git a/test/srt/test_gguf.py b/test/srt/test_gguf.py index 9096182039..89572c45f1 100644 --- a/test/srt/test_gguf.py +++ b/test/srt/test_gguf.py @@ -1,53 +1,25 @@ import unittest -from pathlib import Path from huggingface_hub import hf_hub_download -from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.model_executor.model_runner import ModelRunner -from sglang.srt.server_args import ServerArgs +import sglang as sgl class TestGGUF(unittest.TestCase): - def test_load_model(self): - model_path = hf_hub_download( - "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", - filename="tinyllama-1.1b-chat-v1.0.Q2_K.gguf", - ) + def test_models(self): + prompt = "Today is a sunny day and I like" + sampling_params = {"temperature": 0, 
"max_new_tokens": 8} - server_args = ServerArgs( - model_path=model_path, - random_seed=42, - disable_radix_cache=True, - load_format="auto", + model_path = hf_hub_download( + "Qwen/Qwen2-1.5B-Instruct-GGUF", + filename="qwen2-1_5b-instruct-q4_k_m.gguf", ) - self.assertEqual(server_args.load_format, "gguf") - model_config = ModelConfig( - server_args.model_path, - trust_remote_code=server_args.trust_remote_code, - context_length=server_args.context_length, - model_override_args=server_args.json_model_override_args, - ) - model_runner = ModelRunner( - model_config=model_config, - mem_fraction_static=server_args.mem_fraction_static, - gpu_id=0, - tp_rank=0, - tp_size=server_args.tp_size, - nccl_port=8000, - server_args=server_args, - ) - self.assertEqual(model_runner.vllm_model_config.quantization, "gguf") + engine = sgl.Engine(model_path=model_path, random_seed=42) + outputs = engine.generate(prompt, sampling_params)["text"] + engine.shutdown() - tokenizer = get_tokenizer( - server_args.tokenizer_path, - tokenizer_mode=server_args.tokenizer_mode, - trust_remote_code=server_args.trust_remote_code, - ) - self.assertIsNotNone(tokenizer.vocab_file) - self.assertEqual(Path(tokenizer.vocab_file).suffix, ".gguf") + self.assertEqual(outputs, " it. I have a lot of work") if __name__ == "__main__": From f027e7dcffeec4cd315b34c9a0454e482b29a9b8 Mon Sep 17 00:00:00 2001 From: zhengy001 Date: Sat, 30 Nov 2024 04:14:17 +0000 Subject: [PATCH 08/11] Update run_suite.py order --- test/srt/run_suite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index d23d52516c..71b201ac4b 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -15,6 +15,7 @@ "test_double_sparsity.py", "test_embedding_openai_server.py", "test_eval_accuracy_mini.py", + "test_gguf.py", "test_input_embeddings.py", "test_json_constrained.py", "test_large_max_new_tokens.py", @@ -38,7 +39,6 @@ "test_update_weights.py", "test_vision_openai_server.py", "test_session_control.py", - "test_gguf.py", ], "sampling/penaltylib": glob.glob( "sampling/penaltylib/**/test_*.py", recursive=True From 8f8ceba32a71027eebc85c11deb667406f56ae24 Mon Sep 17 00:00:00 2001 From: zhengy001 Date: Sat, 30 Nov 2024 05:04:26 +0000 Subject: [PATCH 09/11] Check quant_config properity --- python/sglang/srt/layers/logits_processor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 7a3665c8f8..fec9d42cca 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -283,7 +283,11 @@ def _get_logits( lm_head: VocabParallelEmbedding, embedding_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if lm_head.quant_config and lm_head.quant_config.get_name() == "gguf": + if ( + hasattr(lm_head, "quant_config") + and lm_head.quant_config + and lm_head.quant_config.get_name() == "gguf" + ): logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias) else: logits = torch.matmul(hidden_states, lm_head.weight.T) From b3eaf4930015627028b485bad9ecabb9edaa6522 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 30 Nov 2024 00:07:35 -0800 Subject: [PATCH 10/11] Update llama.py --- python/sglang/srt/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index fab66d8730..ba5cfc90e0 100644 --- a/python/sglang/srt/models/llama.py +++ 
b/python/sglang/srt/models/llama.py @@ -422,7 +422,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - apply_torchao_config_(self, params_dict, set(["proj.weight"])) def get_weights_by_name( From 7d7917239ef6922e20fb8d5018989cad7b020e5a Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 30 Nov 2024 00:10:37 -0800 Subject: [PATCH 11/11] Apply suggestions from code review --- python/sglang/srt/layers/logits_processor.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index fec9d42cca..274c4c311e 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -283,14 +283,11 @@ def _get_logits( lm_head: VocabParallelEmbedding, embedding_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if ( - hasattr(lm_head, "quant_config") - and lm_head.quant_config - and lm_head.quant_config.get_name() == "gguf" - ): - logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias) - else: + if hasattr(lm_head, "weight"): logits = torch.matmul(hidden_states, lm_head.weight.T) + else: + # GGUF models + logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias) return logits
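
A minimal end-to-end sketch of the feature this series adds, assuming the Qwen2 GGUF checkpoint used in test/srt/test_gguf.py (any GGUF file of a supported architecture should behave the same way): pointing model_path at a .gguf file is enough, because ServerArgs.__post_init__ detects the file with check_gguf_file and switches both load_format and quantization to "gguf", while get_config/get_tokenizer forward the file to transformers through the gguf_file kwarg and pass its parent directory as the model path.

    # Sketch: offline Engine API on a GGUF checkpoint (mirrors the test added in patch 07).
    from huggingface_hub import hf_hub_download

    import sglang as sgl

    # Download a single quantized GGUF file; model_path points at the file itself.
    model_path = hf_hub_download(
        "Qwen/Qwen2-1.5B-Instruct-GGUF",
        filename="qwen2-1_5b-instruct-q4_k_m.gguf",
    )

    # load_format defaults to "auto"; the .gguf file flips load_format and quantization to "gguf".
    engine = sgl.Engine(model_path=model_path, random_seed=42)
    output = engine.generate(
        "Today is a sunny day and I like",
        {"temperature": 0, "max_new_tokens": 8},
    )["text"]
    print(output)
    engine.shutdown()

The same path can be given to the HTTP server entry point (for example, python -m sglang.launch_server --model-path /path/to/model.gguf), since both routes go through the ServerArgs logic patched above.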