From 775ca5bfbcb4f2c467074e1d509974976145b14e Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 23 Jan 2024 05:09:18 +0000 Subject: [PATCH 1/2] Fix after QWen support --- .../srt/managers/detokenizer_manager.py | 3 +- python/sglang/srt/models/qwen.py | 79 +++++++++---------- python/sglang/srt/utils.py | 2 +- 3 files changed, 42 insertions(+), 42 deletions(-) diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index b572585af6..6dd3d79a70 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -55,7 +55,8 @@ async def handle_loop(self): first_token = self.tokenizer.convert_ids_to_tokens( int(output_tokens[i][0]) ) - first_token = first_token.decode("utf-8") + if not isinstance(first_token, str): + first_token = first_token.decode("utf-8") if first_token.startswith("▁"): output_strs[i] = " " + output_strs[i] diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index e89bfb48c1..ba59d5bb64 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -5,7 +5,6 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.managers.router.model_runner import InputMetadata from torch import nn -from vllm.transformers_utils.configs.qwen import QWenConfig from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( @@ -26,9 +25,10 @@ default_weight_loader, hf_model_weights_iterator, ) +from vllm.transformers_utils.configs.qwen import QWenConfig -class QWenMLP(nn.Module): +class QWenMLP(nn.Module): def __init__( self, hidden_size: int, @@ -49,8 +49,10 @@ def __init__( input_is_parallel=True, ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -59,31 +61,28 @@ def forward(self, x): x, _ = self.c_proj(x) return x -class QWenAttention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - max_position_embeddings: int, - layer_id: int = 0, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None): +class QWenAttention(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + max_position_embeddings: int, + layer_id: int = 0, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + ): super().__init__() self.hidden_size = hidden_size - tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( - ) + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tensor_model_parallel_world_size == 0 - self.num_heads = (self.total_num_heads // - tensor_model_parallel_world_size) + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size self.head_dim = hidden_size // self.total_num_heads # pylint: disable=invalid-name self.c_attn = QKVParallelLinear( - hidden_size, - self.head_dim, - self.total_num_heads, - bias=True + hidden_size, self.head_dim, self.total_num_heads, bias=True ) self.c_proj = RowParallelLinear( self.total_num_heads * self.head_dim, @@ -120,20 +119,22 @@ def forward( output, _ = self.c_proj(attn_output) return output -class QWenBlock(nn.Module): - def __init__(self, config: QWenConfig,layer_id): +class QWenBlock(nn.Module): + def __init__(self, config: QWenConfig, layer_id): super().__init__() self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - self.attn = QWenAttention(config.hidden_size, - config.num_attention_heads, - config.max_position_embeddings, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - layer_id=layer_id) + self.attn = QWenAttention( + config.hidden_size, + config.num_attention_heads, + config.max_position_embeddings, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + layer_id=layer_id, + ) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -161,10 +162,10 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states return hidden_states - -class QWenModel(nn.Module): - def __init__(self, config:QWenConfig): + +class QWenModel(nn.Module): + def __init__(self, config: QWenConfig): super().__init__() self.config = config self.vocab_size = config.vocab_size @@ -175,7 +176,8 @@ def __init__(self, config:QWenConfig): config.hidden_size, ) self.h = nn.ModuleList( - [QWenBlock(config, i) for i in range(config.num_hidden_layers)]) + [QWenBlock(config, i) for i in range(config.num_hidden_layers)] + ) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def forward( @@ -195,26 +197,23 @@ def forward( hidden_states = self.ln_f(hidden_states) return hidden_states -class QWenLMHeadModel(nn.Module): - def __init__(self, config: QWenConfig,linear_method=None): +class QWenLMHeadModel(nn.Module): + def __init__(self, config: QWenConfig, linear_method=None): super().__init__() self.config = config self.transformer = QWenModel(config) vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.lm_head = ParallelLMHead( - vocab_size, - config.hidden_size - ) + self.lm_head = ParallelLMHead(vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) def forward( 
self, input_ids: torch.Tensor, positions: torch.Tensor, - input_metadata: InputMetadata + input_metadata: InputMetadata, ): - hidden_states = self.transformer(input_ids, positions,input_metadata) + hidden_states = self.transformer(input_ids, positions, input_metadata) next_tokens = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 6822e95211..8c58766025 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -216,4 +216,4 @@ def load_image(image_file): else: image = Image.open(BytesIO(base64.b64decode(image_file))) - return image \ No newline at end of file + return image From 0209346f4d680c5a6d198802dfc2ade08de56696 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 23 Jan 2024 05:16:07 +0000 Subject: [PATCH 2/2] update --- python/sglang/lang/chat_template.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 579cc845b2..5ea9786b88 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -168,7 +168,10 @@ def match_llama2_chat(model_path: str): @register_chat_template_matching_function def match_chat_ml(model_path: str): - if "tinyllama" in model_path.lower(): + model_path = model_path.lower() + if "tinyllama" in model_path: + return get_chat_template("chatml") + if "qwen" in model_path and "chat" in model_path: return get_chat_template("chatml")
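
For reference, the two behavioral changes above (the isinstance guard in the detokenizer and the broadened match_chat_ml matching) can be exercised in isolation with the short Python sketch below. The helper names needs_leading_space and resolve_chat_template are illustrative stand-ins and not part of sglang; the returned "chatml" string stands in for get_chat_template("chatml"), and the model paths are only examples.

# Standalone sketch (assumed names, not part of the patch or the sglang API):
# it mirrors the two behavioral changes in these patches so they can be run in isolation.
from typing import Optional, Union


def needs_leading_space(first_token: Union[str, bytes]) -> bool:
    # Mirrors the detokenizer fix: convert_ids_to_tokens may hand back either
    # str or bytes depending on the tokenizer, so decode only when needed
    # before checking for the SentencePiece word-boundary marker.
    if not isinstance(first_token, str):
        first_token = first_token.decode("utf-8")
    return first_token.startswith("▁")


def resolve_chat_template(model_path: str) -> Optional[str]:
    # Mirrors the updated match_chat_ml: lowercase once, then route both
    # TinyLlama and QWen chat checkpoints to the chatml template.
    model_path = model_path.lower()
    if "tinyllama" in model_path:
        return "chatml"
    if "qwen" in model_path and "chat" in model_path:
        return "chatml"
    return None  # no match; the real matcher simply falls through


if __name__ == "__main__":
    assert needs_leading_space("▁hello")
    assert not needs_leading_space(b"hello")
    assert resolve_chat_template("Qwen/Qwen-7B-Chat") == "chatml"
    assert resolve_chat_template("Qwen/Qwen-7B") is None
    print("sketch checks passed")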