fix config
Ying1123 committed Jul 25, 2024
1 parent 97e0f7d commit d558115
Showing 1 changed file with 43 additions and 0 deletions.
python/sglang/srt/managers/controller/model_runner.py (43 additions, 0 deletions)
@@ -9,6 +9,7 @@

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
@@ -22,6 +23,7 @@
    init_distributed_environment,
    initialize_model_parallel,
)
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.models import ModelRegistry

from sglang.global_config import global_config
@@ -38,6 +40,17 @@
logger = logging.getLogger("srt.model_runner")


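# Identify the Llama 3 405B FP8 checkpoint (fbgemm_fp8 quantization) by its
# architecture and config dimensions, so the KV-head count and K/V weights can
# be corrected during model loading below.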
def is_llama3_405b_fp8(model_config):
    if (model_config.hf_config.architectures[0] == "LlamaForCausalLM" and
        model_config.hf_config.hidden_size == 16384 and
        model_config.hf_config.intermediate_size == 53248 and
        model_config.hf_config.num_hidden_layers == 126 and
        model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
    ):
        return True
    return False


class ModelRunner:
    def __init__(
        self,
@@ -118,6 +131,9 @@ def load_model(self):
            seed=42,
            skip_tokenizer_init=True,
        )
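        # The FP8 405B checkpoint config reports duplicated KV heads; use the
        # real count (8) here and let the patched weight loader below drop the
        # duplicate copies.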
        if is_llama3_405b_fp8(self.model_config):
            self.model_config.hf_config.num_key_value_heads = 8
            vllm_model_config.hf_config.num_key_value_heads = 8
        self.dtype = vllm_model_config.dtype
        if self.model_config.model_overide_args is not None:
            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
@@ -370,5 +386,32 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
    return model_arch_name_to_cls[model_arch]


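# Assumes the duplicated K/V projection stores each head twice, back to back,
# along dim 0; keep the first copy of each pair and truncate to the original size.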
def get_original_weight(loaded_weight, head_dim):
    n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
    dim = loaded_weight.shape[1]
    for i in range(n_kv_head):
        loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
            2 * i * head_dim : (2 * i + 1) * head_dim, :
        ]
    original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
    assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
    return original_kv_weight


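# Wrap vLLM's QKVParallelLinear.weight_loader: when a K or V shard arrives with
# twice the expected number of rows, de-duplicate it before the normal load path.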
def get_weight_loader_srt(weight_loader):
    def weight_loader_srt(
        self,
        param: Parameter,
        loaded_weight: torch.Tensor,
        loaded_shard_id: Optional[str] = None,
    ):
        if (
            loaded_shard_id in ["k", "v"]
            and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
        ):
            loaded_weight = get_original_weight(loaded_weight, self.head_size)

        weight_loader(self, param, loaded_weight, loaded_shard_id)

    return weight_loader_srt


# Monkey patch the vLLM model loader and the QKV weight loader
setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
original_weight_loader = QKVParallelLinear.weight_loader
setattr(QKVParallelLinear, "weight_loader", get_weight_loader_srt(original_weight_loader))
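
For reference, a minimal sketch (not part of the commit) of what get_original_weight, defined above, does to a duplicated K/V projection; the tensor sizes are illustrative only (head_dim=2, two real KV heads, input dim 3):

import torch

# Toy duplicated K weight: heads stored as [h0, h0, h1, h1] along dim 0.
head_dim = 2
h0 = torch.zeros(head_dim, 3)
h1 = torch.ones(head_dim, 3)
duplicated = torch.cat([h0, h0, h1, h1], dim=0)  # shape (8, 3)

restored = get_original_weight(duplicated.clone(), head_dim)  # shape (4, 3)
assert torch.equal(restored, torch.cat([h0, h1], dim=0))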
