diff --git a/python/sglang/srt/models/gemma2_reward.py b/python/sglang/srt/models/gemma2_reward.py
index 5faadf67ff..9aab3ce18e 100644
--- a/python/sglang/srt/models/gemma2_reward.py
+++ b/python/sglang/srt/models/gemma2_reward.py
@@ -58,43 +58,10 @@ def forward(
         ), "Gemma2ForSequenceClassification is only used for embedding"
 
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        scores = self.score(hidden_states)
+        last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
+        scores = self.score(last_token_hidden)
 
-        return self.pooler(scores, forward_batch)
-
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
-            for param_name, shard_name, shard_id in stacked_params_mapping:
-                if shard_name not in name:
-                    continue
-                name = name.replace(shard_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # lm_head is not used in vllm as it is tied with embed_token.
-                # To prevent errors, skip loading lm_head.weight.
-                if "lm_head.weight" in name:
-                    continue
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+        return EmbeddingPoolerOutput(scores)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         Gemma2ForCausalLM.load_weights(self, weights)
diff --git a/python/sglang/srt/models/llama_reward.py b/python/sglang/srt/models/llama_reward.py
index 5b68d1d321..e285ad6921 100644
--- a/python/sglang/srt/models/llama_reward.py
+++ b/python/sglang/srt/models/llama_reward.py
@@ -59,22 +59,13 @@ def forward(
         ), "LlamaForSequenceClassification is only used for embedding"
 
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        scores = self.score(hidden_states)
+        last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
+        scores = self.score(last_token_hidden)
 
-        return self.pooler(scores, forward_batch)
+        return EmbeddingPoolerOutput(scores)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        params_dict = dict(self.named_parameters())
-
-        for name, loaded_weight in weights:
-            if "classification_head" in name:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
-            elif "lm_head" in name:
-                continue
-            else:
-                LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
+        return LlamaForCausalLM.load_weights(self, weights)
 
 
 class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassification):
@@ -127,17 +118,7 @@ def forward(
         return EmbeddingPoolerOutput(scores)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        params_dict = dict(self.named_parameters())
-
-        for name, loaded_weight in weights:
-            if "classification_head" in name:
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
-            elif "lm_head" in name:
-                continue
-            else:
-                LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
+        return super().load_weights(weights)
 
 
 EntryClass = [