
Commit: Format
zhengy001 committed Nov 27, 2024
1 parent 0181fd3 commit 75d9978
Showing 6 changed files with 34 additions and 10 deletions.
4 changes: 2 additions & 2 deletions python/sglang/srt/hf_transformers_utils.py
@@ -82,8 +82,7 @@ def get_config(
     # Special architecture mapping check for GGUF models
     if gguf_file is not None:
         if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-            raise RuntimeError(
-                f"Can't get gguf config for {config.model_type}.")
+            raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
         model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
         config.update({"architectures": [model_type]})

@@ -126,6 +125,7 @@ def get_context_length(config):
 # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file.
 _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer"
 
+
 def get_tokenizer(
     tokenizer_name: str,
     *args,
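
The hunk above only reflows the error message, but the surrounding check is the heart of GGUF handling in get_config: a GGUF checkpoint carries a bare model_type rather than a Hugging Face "architectures" list, so the type is mapped back to a causal-LM class name. A minimal standalone sketch of that logic, with a hypothetical function name and a one-entry stand-in for MODEL_FOR_CAUSAL_LM_MAPPING_NAMES (which in transformers covers many model types):

# Illustrative sketch only, not the sglang implementation. The real code
# operates on a transformers PretrainedConfig; a plain dict stands in here.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = {"llama": "LlamaForCausalLM"}

def resolve_gguf_architecture(config, gguf_file):
    if gguf_file is not None:
        if config["model_type"] not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
            raise RuntimeError(f"Can't get gguf config for {config['model_type']}.")
        model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config["model_type"]]
        config.update({"architectures": [model_type]})
    return config
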
11 changes: 9 additions & 2 deletions python/sglang/srt/layers/vocab_parallel_embedding.py
@@ -224,8 +224,15 @@ def __init__(
         prefix: str = "",
         enable_tp: bool = True,
     ):
-        super().__init__(num_embeddings, embedding_dim, params_dtype,
-                         org_num_embeddings, padding_size, quant_config, prefix)
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            params_dtype,
+            org_num_embeddings,
+            padding_size,
+            quant_config,
+            prefix,
+        )
 
         self.enable_tp = enable_tp
         if self.enable_tp:
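
For context on the class being touched: VocabParallelEmbedding shards the vocabulary across tensor-parallel ranks, and the enable_tp flag introduced around this hunk gates that sharding. A rough conceptual sketch of a sharded lookup, with hypothetical names and no distributed wiring (not the sglang/vLLM implementation):

# Each rank owns a contiguous vocab shard; ids owned by other ranks yield
# zero rows, and real TP would sum the partial outputs with an all-reduce.
import torch
import torch.nn.functional as F

def vocab_parallel_lookup(input_ids, weight_shard, vocab_start, vocab_end):
    mask = (input_ids >= vocab_start) & (input_ids < vocab_end)
    local_ids = torch.where(mask, input_ids - vocab_start, torch.zeros_like(input_ids))
    out = F.embedding(local_ids, weight_shard)
    out[~mask] = 0.0  # rows for out-of-shard ids contribute nothing here
    return out  # a real implementation would follow with dist.all_reduce(out)

ids = torch.tensor([[1, 5, 9]])
shard = torch.randn(4, 8)  # this rank owns vocab ids 4..7
partial = vocab_parallel_lookup(ids, shard, vocab_start=4, vocab_end=8)
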
6 changes: 5 additions & 1 deletion python/sglang/srt/model_executor/model_runner.py
@@ -256,7 +256,11 @@ def setup_model(self):
 
     def get_model_config_params(self):
         sig = inspect.signature(VllmModelConfig.__init__)
-        model = self.server_args.gguf_file if self.server_args.load_format == "gguf" else self.server_args.model_path
+        model = (
+            self.server_args.gguf_file
+            if self.server_args.load_format == "gguf"
+            else self.server_args.model_path
+        )
         params = {
             "model": model,
             "quantization": self.server_args.quantization,
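
The sig = inspect.signature(VllmModelConfig.__init__) line suggests the params dict is filtered against whatever the installed vLLM version's constructor actually accepts before the call is made. A hedged sketch of that general pattern, with generic names (the real method builds a larger params dict than shown):

# Keep only the kwargs the target __init__ declares, so the call keeps
# working when a dependency adds or removes constructor parameters.
import inspect

def filter_kwargs(cls, candidates):
    sig = inspect.signature(cls.__init__)
    return {k: v for k, v in candidates.items() if k in sig.parameters}

class Config:  # stand-in for VllmModelConfig
    def __init__(self, model, quantization=None):
        self.model, self.quantization = model, quantization

cfg = Config(**filter_kwargs(Config, {"model": "m", "quantization": "gguf", "extra": 1}))
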
4 changes: 3 additions & 1 deletion python/sglang/srt/models/llama.py
@@ -303,7 +303,9 @@ def __init__(
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, quant_config=quant_config
+        )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

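
ParallelLMHead here is the vocab-parallel output projection: conceptually, logits are just the hidden states multiplied by the transposed vocab-sized weight matrix. A toy sketch with made-up shapes (not the sglang class, and ignoring quantization and TP sharding):

# Toy logits computation: project hidden states onto the vocab embedding.
import torch

hidden = torch.randn(4, 2048)               # (num_tokens, hidden_size)
lm_head_weight = torch.randn(32000, 2048)   # (vocab_size, hidden_size)
logits = hidden @ lm_head_weight.t()        # (num_tokens, vocab_size)
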
1 change: 0 additions & 1 deletion python/sglang/srt/server_args.py
@@ -209,7 +209,6 @@ def __post_init__(self):
         if self.gguf_file is not None:
             self.quantization = self.load_format = "gguf"
 
-
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
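
The deleted blank line sits just below the GGUF defaulting in ServerArgs.__post_init__: supplying gguf_file forces both quantization and load_format to "gguf", which is exactly what the test below asserts. A minimal runnable sketch with trimmed-down fields (not the full ServerArgs):

# Passing a GGUF file implies GGUF load format and quantization.
from dataclasses import dataclass
from typing import Optional

@dataclass
class MiniServerArgs:
    gguf_file: Optional[str] = None
    load_format: str = "auto"
    quantization: Optional[str] = None

    def __post_init__(self):
        if self.gguf_file is not None:
            self.quantization = self.load_format = "gguf"

args = MiniServerArgs(gguf_file="tinyllama-1.1b-chat-v1.0.Q2_K.gguf")
assert args.load_format == "gguf" and args.quantization == "gguf"
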
18 changes: 15 additions & 3 deletions test/srt/test_gguf.py
@@ -11,10 +11,21 @@
 
 class TestGGUF(unittest.TestCase):
     def test_load_model(self):
-        model_path = str(Path(hf_hub_download("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
-                                              filename="tinyllama-1.1b-chat-v1.0.Q2_K.gguf")).parent)
+        model_path = str(
+            Path(
+                hf_hub_download(
+                    "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
+                    filename="tinyllama-1.1b-chat-v1.0.Q2_K.gguf",
+                )
+            ).parent
+        )
 
-        server_args = ServerArgs(model_path=model_path, random_seed=42, disable_radix_cache=True, load_format="auto")
+        server_args = ServerArgs(
+            model_path=model_path,
+            random_seed=42,
+            disable_radix_cache=True,
+            load_format="auto",
+        )
         self.assertEqual(server_args.model_path, model_path)
         self.assertEqual(server_args.load_format, "gguf")
         self.assertIsNotNone(server_args.gguf_file)
@@ -45,5 +56,6 @@ def test_load_model(self):
         self.assertIsNotNone(tokenizer.vocab_file)
         self.assertEqual(Path(tokenizer.vocab_file).suffix, ".gguf")
 
+
 if __name__ == "__main__":
     unittest.main()
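
One detail worth noting in the test: hf_hub_download returns the path of the cached .gguf file itself, and the test passes the containing directory as model_path; the earlier assertIsNotNone(server_args.gguf_file) implies the file is then rediscovered from that directory. A small usage sketch (requires huggingface_hub; the resolved path depends on the local cache):

# Download one GGUF shard and derive the directory used as model_path.
from pathlib import Path
from huggingface_hub import hf_hub_download

gguf_path = hf_hub_download(
    "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF",
    filename="tinyllama-1.1b-chat-v1.0.Q2_K.gguf",
)
model_dir = str(Path(gguf_path).parent)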
