Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Suppport qwen model and solve some problems #75

Merged
merged 8 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
- Mixtral
- LLaVA
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000`
- Qwen
- AWQ quantization

## Benchmark And Performance
Expand Down
1 change: 0 additions & 1 deletion python/sglang/srt/layers/radix_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def prefill_forward_triton(self, q, k, v, input_metadata: InputMetadata):
def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
o = torch.empty_like(q)
self.store_kv_cache(k, v, input_metadata)

extend_attention_fwd(
q.view(-1, self.tp_q_head_num, self.head_dim),
k.contiguous(),
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/managers/detokenizer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ async def handle_loop(self):
first_token = self.tokenizer.convert_ids_to_tokens(
int(output_tokens[i][0])
)
first_token = first_token.decode("utf-8")
if first_token.startswith("▁"):
output_strs[i] = " " + output_strs[i]

Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/managers/router/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def load_model(self):
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.models.llava import LlavaLlamaForCausalLM
from sglang.srt.models.mixtral import MixtralForCausalLM
from sglang.srt.models.qwen import QWenLMHeadModel

# Select model class
architectures = getattr(self.model_config.hf_config, "architectures", [])
Expand All @@ -258,6 +259,8 @@ def load_model(self):
if arch == "MixtralForCausalLM":
model_class = MixtralForCausalLM
break
if arch == "QWenLMHeadModel":
model_class = QWenLMHeadModel
if model_class is None:
raise ValueError(f"Unsupported architectures: {architectures}")

Expand Down
5 changes: 4 additions & 1 deletion python/sglang/srt/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@ def __init__(
# Unify the config keys for hf_config
self.context_len = get_context_length(self.hf_config)
self.head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads
self.num_key_value_heads = self.hf_config.num_key_value_heads
self.num_attention_heads = self.hf_config.num_attention_heads
try:
self.num_key_value_heads = self.hf_config.num_key_value_heads
except Exception as e:
merrymercy marked this conversation as resolved.
Show resolved Hide resolved
self.num_key_value_heads = self.num_attention_heads
self.hidden_size = self.hf_config.hidden_size
self.num_hidden_layers = self.hf_config.num_hidden_layers
self.vocab_size = self.hf_config.vocab_size
261 changes: 261 additions & 0 deletions python/sglang/srt/models/qwen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
from typing import Any, Dict, List, Optional, Tuple

import torch
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from torch import nn
from vllm.transformers_utils.configs.qwen import QWenConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)

class QWenMLP(nn.Module):

def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str = "silu",
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
2 * [intermediate_size],
bias=False,
gather_output=False,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()

def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.c_proj(x)
return x

class QWenAttention(nn.Module):

def __init__(self,
hidden_size: int,
num_heads: int,
max_position_embeddings: int,
layer_id: int = 0,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
)
self.total_num_heads = num_heads
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads

# pylint: disable=invalid-name
self.c_attn = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
bias=True
)
self.c_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.scaling = self.head_dim**-0.5
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_heads,
layer_id=layer_id,
)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
output, _ = self.c_proj(attn_output)
return output

class QWenBlock(nn.Module):

def __init__(self, config: QWenConfig,layer_id):
super().__init__()
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
self.attn = QWenAttention(config.hidden_size,
config.num_attention_heads,
config.max_position_embeddings,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
layer_id=layer_id)

self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
hidden_states = self.attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
)
hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states

class QWenModel(nn.Module):

def __init__(self, config:QWenConfig):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size

vocab_size = ((config.vocab_size + 63) // 64) * 64
self.wte = VocabParallelEmbedding(
vocab_size,
config.hidden_size,
)
self.h = nn.ModuleList(
[QWenBlock(config, i) for i in range(config.num_hidden_layers)])
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.h)):
layer = self.h[i]
hidden_states = layer(
positions,
hidden_states,
input_metadata,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states

class QWenLMHeadModel(nn.Module):

def __init__(self, config: QWenConfig,linear_method=None):
super().__init__()
self.config = config
self.transformer = QWenModel(config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ParallelLMHead(
vocab_size,
config.hidden_size
)
self.logits_processor = LogitsProcessor(config)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata
):
hidden_states = self.transformer(input_ids, positions,input_metadata)
next_tokens = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
return next_tokens

_column_parallel_weights = []
_row_parallel_weights = ["c_proj.weight"]

def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "w2", 0),
("gate_up_proj", "w1", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision
):
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
6 changes: 4 additions & 2 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,11 @@ def get_exception_traceback():
def get_int_token_logit_bias(tokenizer, vocab_size):
from transformers import LlamaTokenizer, LlamaTokenizerFast

# a bug when model's vocab size > tokenizer.vocab_size
vocab_size = tokenizer.vocab_size
logit_bias = np.zeros(vocab_size, dtype=np.float32)
for t_id in range(vocab_size):
ss = tokenizer.decode(t_id).strip()
ss = tokenizer.decode([t_id]).strip()
if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
logit_bias[t_id] = -1e5
# else:
Expand Down Expand Up @@ -214,4 +216,4 @@ def load_image(image_file):
else:
image = Image.open(BytesIO(base64.b64decode(image_file)))

return image
return image