
Fix after QWen support #82

Merged · 2 commits · Jan 23, 2024
5 changes: 4 additions & 1 deletion python/sglang/lang/chat_template.py
@@ -168,7 +168,10 @@ def match_llama2_chat(model_path: str):

 @register_chat_template_matching_function
 def match_chat_ml(model_path: str):
-    if "tinyllama" in model_path.lower():
+    model_path = model_path.lower()
+    if "tinyllama" in model_path:
         return get_chat_template("chatml")
+    if "qwen" in model_path and "chat" in model_path:
+        return get_chat_template("chatml")

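For reference, the matching rule this hunk introduces, reduced to a standalone sketch (the helper name and model paths below are illustrative, not part of sglang's API): the path is lower-cased once, TinyLlama still maps to ChatML, and any path containing both "qwen" and "chat" now maps to ChatML as well.

```python
# Standalone sketch of the new matching logic; not sglang code.
def resolves_to_chatml(model_path: str) -> bool:
    model_path = model_path.lower()
    if "tinyllama" in model_path:
        return True
    return "qwen" in model_path and "chat" in model_path


assert resolves_to_chatml("Qwen/Qwen-7B-Chat")
assert resolves_to_chatml("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
assert not resolves_to_chatml("Qwen/Qwen-7B")  # base model path has no "chat" substring
```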
3 changes: 2 additions & 1 deletion python/sglang/srt/managers/detokenizer_manager.py
@@ -55,7 +55,8 @@ async def handle_loop(self):
                 first_token = self.tokenizer.convert_ids_to_tokens(
                     int(output_tokens[i][0])
                 )
-                first_token = first_token.decode("utf-8")
+                if not isinstance(first_token, str):
+                    first_token = first_token.decode("utf-8")
                 if first_token.startswith("▁"):
                     output_strs[i] = " " + output_strs[i]

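The `isinstance` guard matters because `convert_ids_to_tokens` returns plain `str` for some tokenizers but `bytes` for byte-level vocabularies (presumably what the Qwen tokenizer surfaced, given the PR title), and calling `.decode("utf-8")` on a `str` raises `AttributeError`. A minimal sketch of the same logic outside the manager, with a hypothetical `tokenizer` object:

```python
# Hypothetical standalone version of the guarded decode above.
def first_token_as_str(tokenizer, token_id: int) -> str:
    token = tokenizer.convert_ids_to_tokens(token_id)
    if not isinstance(token, str):
        # Byte-level vocabs hand back bytes that still need decoding.
        token = token.decode("utf-8")
    return token  # callers then check for the SentencePiece "▁" space marker
```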
79 changes: 39 additions & 40 deletions python/sglang/srt/models/qwen.py
@@ -5,7 +5,6 @@
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
-from vllm.transformers_utils.configs.qwen import QWenConfig
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -26,9 +25,10 @@
     default_weight_loader,
     hf_model_weights_iterator,
 )
+from vllm.transformers_utils.configs.qwen import QWenConfig
+

-class QWenMLP(nn.Module):
-
+class QWenMLP(nn.Module):
     def __init__(
         self,
         hidden_size: int,
@@ -49,8 +49,10 @@ def __init__(
             input_is_parallel=True,
         )
         if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
         self.act_fn = SiluAndMul()

     def forward(self, x):
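QWenMLP only accepts the "silu" activation because it relies on vLLM's fused `SiluAndMul` op. As a reference for what that op computes (a plain-PyTorch sketch of the math, not the fused kernel): the merged gate/up projection output is split in half, the first half is passed through SiLU and multiplied elementwise with the second half.

```python
import torch
import torch.nn.functional as F


def silu_and_mul(gate_up: torch.Tensor) -> torch.Tensor:
    # Reference math for the fused op: silu(gate) * up on the two halves.
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up
```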
@@ -59,31 +61,28 @@ def forward(self, x):
         x, _ = self.c_proj(x)
         return x

-class QWenAttention(nn.Module):
-
-    def __init__(self,
-                 hidden_size: int,
-                 num_heads: int,
-                 max_position_embeddings: int,
-                 layer_id: int = 0,
-                 rope_theta: float = 10000,
-                 rope_scaling: Optional[Dict[str, Any]] = None):
+
+class QWenAttention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        max_position_embeddings: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+    ):
         super().__init__()
         self.hidden_size = hidden_size
-        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
-        )
+        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = (self.total_num_heads //
-                          tensor_model_parallel_world_size)
+        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
         self.head_dim = hidden_size // self.total_num_heads

         # pylint: disable=invalid-name
         self.c_attn = QKVParallelLinear(
-            hidden_size,
-            self.head_dim,
-            self.total_num_heads,
-            bias=True
+            hidden_size, self.head_dim, self.total_num_heads, bias=True
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
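The assert in `QWenAttention.__init__` above exists because attention heads are partitioned evenly across tensor-parallel ranks. A quick illustration of that arithmetic (the head and rank counts below are made-up, not Qwen's config):

```python
def heads_per_rank(total_num_heads: int, tp_world_size: int) -> int:
    # Mirrors the check in QWenAttention.__init__: heads must divide evenly.
    assert total_num_heads % tp_world_size == 0
    return total_num_heads // tp_world_size


print(heads_per_rank(32, 4))  # -> 8 heads owned by each of the 4 ranks
print(heads_per_rank(32, 1))  # -> 32, the single-GPU case
```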
@@ -120,20 +119,22 @@ def forward(
         output, _ = self.c_proj(attn_output)
         return output

-class QWenBlock(nn.Module):
-
-    def __init__(self, config: QWenConfig,layer_id):
+
+class QWenBlock(nn.Module):
+    def __init__(self, config: QWenConfig, layer_id):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
-        self.attn = QWenAttention(config.hidden_size,
-                                  config.num_attention_heads,
-                                  config.max_position_embeddings,
-                                  rope_theta=rope_theta,
-                                  rope_scaling=rope_scaling,
-                                  layer_id=layer_id)
+        self.attn = QWenAttention(
+            config.hidden_size,
+            config.num_attention_heads,
+            config.max_position_embeddings,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            layer_id=layer_id,
+        )

         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

@@ -161,10 +162,10 @@ def forward(
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
         return hidden_states

-class QWenModel(nn.Module):
-
-    def __init__(self, config:QWenConfig):
+
+class QWenModel(nn.Module):
+    def __init__(self, config: QWenConfig):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
@@ -175,7 +176,8 @@ def __init__(self, config:QWenConfig):
             config.hidden_size,
         )
         self.h = nn.ModuleList(
-            [QWenBlock(config, i) for i in range(config.num_hidden_layers)])
+            [QWenBlock(config, i) for i in range(config.num_hidden_layers)]
+        )
         self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

     def forward(
@@ -195,26 +197,23 @@ def forward(
         hidden_states = self.ln_f(hidden_states)
         return hidden_states

-class QWenLMHeadModel(nn.Module):
-
-    def __init__(self, config: QWenConfig,linear_method=None):
+
+class QWenLMHeadModel(nn.Module):
+    def __init__(self, config: QWenConfig, linear_method=None):
         super().__init__()
         self.config = config
         self.transformer = QWenModel(config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.lm_head = ParallelLMHead(
-            vocab_size,
-            config.hidden_size
-        )
+        self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata
+        input_metadata: InputMetadata,
     ):
-        hidden_states = self.transformer(input_ids, positions,input_metadata)
+        hidden_states = self.transformer(input_ids, positions, input_metadata)
         next_tokens = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
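One detail in `QWenLMHeadModel.__init__` above: the LM head is sized to the vocabulary rounded up to the next multiple of 64, a common trick to keep the padded vocabulary evenly divisible across tensor-parallel shards. The rounding is plain integer arithmetic (the sizes below are illustrative):

```python
def pad_vocab_size(vocab_size: int, multiple: int = 64) -> int:
    # Same arithmetic as ((config.vocab_size + 63) // 64) * 64 in the diff.
    return ((vocab_size + multiple - 1) // multiple) * multiple


print(pad_vocab_size(32001))   # -> 32064, rounded up to the next multiple of 64
print(pad_vocab_size(151936))  # -> 151936, already a multiple of 64
```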
2 changes: 1 addition & 1 deletion python/sglang/srt/utils.py
@@ -216,4 +216,4 @@ def load_image(image_file):
     else:
         image = Image.open(BytesIO(base64.b64decode(image_file)))

-    return image
\ No newline at end of file
+    return image