Commit

Merge remote-tracking branch 'paddlenlp/develop' into dev_20250126_add_pipeline_for_moe
DrownFish19 committed Feb 1, 2025
2 parents 3fcf2c1 + 96856bd commit 3a320cb
Showing 3 changed files with 34 additions and 7 deletions.
README.md (1 addition, 0 deletions)
@@ -95,6 +95,7 @@
| [ChatGLM3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm3-6b |
| [DeepSeekV2](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V2, deepseek-ai/DeepSeek-V2-Chat, deepseek-ai/DeepSeek-V2-Lite, deepseek-ai/DeepSeek-V2-Lite-Chat, deepseek-ai/DeepSeek-Coder-V2-Base, deepseek-ai/DeepSeek-Coder-V2-Instruct, deepseek-ai/DeepSeek-Coder-V2-Lite-Base, deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct |
| [DeepSeekV3](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V3, deepseek-ai/DeepSeek-V3-Base |
+ | [DeepSeek-R1](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-R1, deepseek-ai/DeepSeek-R1-Zero, deepseek-ai/DeepSeek-R1-Distill-Llama-70B, deepseek-ai/DeepSeek-R1-Distill-Llama-8B, deepseek-ai/DeepSeek-R1-Distill-Qwen-14B, deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, deepseek-ai/DeepSeek-R1-Distill-Qwen-32B, deepseek-ai/DeepSeek-R1-Distill-Qwen-7B |
| [Gemma](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/gemma) | google/gemma-7b, google/gemma-7b-it, google/gemma-2b, google/gemma-2b-it |
| [Mistral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mistral) | mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-7B-v0.1 |
| [Mixtral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mixtral) | mistralai/Mixtral-8x7B-Instruct-v0.1 |
paddlenlp/transformers/conversion_utils.py (3 additions, 3 deletions)
@@ -1311,15 +1311,15 @@ def _get_tensor_parallel_mappings(cls, config: PretrainedConfig, is_split=True)
    def _resolve_prefix_keys(state_keys_base, state_keys_real, ignore_error=False):
        # state_keys_map base to real
        state_keys_map = {}

        state_keys_base = set(state_keys_base)
+        # sorted by length, match from long to short for A.key B.key ...
+        state_keys_base = sorted(state_keys_base, key=lambda x: len(x), reverse=True)
        state_keys_real = set(state_keys_real)

        for key in state_keys_base:
            for x in state_keys_real:
                if x.endswith(key):
                    state_keys_map[key] = x
-                    # break # remove break for match A.key B.key ...
+                    break
            if key not in state_keys_map:
                if not ignore_error:
                    logger.debug(f"tensor parallel conversion: could not find name {key} in loaded state dict!")
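
A note on the change above: it appears to re-enable the early `break`, but only after sorting the base keys from longest to shortest, so that a more specific suffix such as `shared_experts.up_proj.weight` is claimed before a shorter one it contains (`up_proj.weight`). A minimal, standalone sketch of that matching behavior is shown below; the function name `resolve_prefix_keys` and the sample key names are illustrative, not taken from the commit.

```python
def resolve_prefix_keys(state_keys_base, state_keys_real):
    """Map each base key to the real checkpoint key whose name ends with it.

    Sorting base keys from longest to shortest lets the most specific suffix
    win first, so an early break no longer hides a longer match that shares
    the same tail.
    """
    state_keys_map = {}
    for key in sorted(set(state_keys_base), key=len, reverse=True):
        for real_key in state_keys_real:
            if real_key.endswith(key):
                state_keys_map[key] = real_key
                break  # longer suffixes were already handled, so stopping is safe
    return state_keys_map


# Illustrative keys only; real DeepSeek-V2 checkpoints use longer names.
base = ["up_proj.weight", "shared_experts.up_proj.weight"]
real = [
    "model.layers.0.mlp.experts.0.up_proj.weight",
    "model.layers.0.mlp.shared_experts.up_proj.weight",
]
print(resolve_prefix_keys(base, real))
# {'shared_experts.up_proj.weight': 'model.layers.0.mlp.shared_experts.up_proj.weight',
#  'up_proj.weight': 'model.layers.0.mlp.experts.0.up_proj.weight'}
```
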
paddlenlp/transformers/deepseek_v2/modeling.py (30 additions, 4 deletions)
@@ -575,7 +575,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):


class DeepseekV2MLP(nn.Layer):
-    def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None):
+    def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None, is_moe=False):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
@@ -588,7 +588,7 @@ def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size
        ColumnParallelLinear = linear_utils.ColumnParallelLinear
        RowParallelLinear = linear_utils.RowParallelLinear

-        if config.tensor_parallel_degree > 1:
+        if config.tensor_parallel_degree > 1 and not is_moe:
            self.gate_proj = ColumnParallelLinear(
                self.hidden_size,
                self.intermediate_size,
@@ -712,7 +712,7 @@ def __init__(self, config: DeepseekV2Config):
        self.alpha = config.aux_loss_alpha

        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size)
+            self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size, is_moe=True)

    def forward(self, hidden_states):
        final_hidden_states, l_aux, l_zloss = super().forward(hidden_states)
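
Taken together, the hunks above thread an `is_moe` flag through `DeepseekV2MLP`: expert MLPs (including the shared experts built here) skip the tensor-parallel branch, presumably falling back to plain linear layers so that their parallelism is handled by the MoE path rather than by splitting each projection again. A framework-free sketch of that selection logic follows; `DeepseekV2MLPSketch`, the stand-in layer classes, and the size values are illustrative only.

```python
# A simplified sketch of the is_moe gating added above. The three small classes
# stand in for paddle's plain and tensor-parallel linear layers; only the
# selection logic is the point here.

class PlainLinear:
    def __init__(self, in_features, out_features):
        self.shape = (in_features, out_features)

class ColumnParallelLinear(PlainLinear):
    pass

class RowParallelLinear(PlainLinear):
    pass

class DeepseekV2MLPSketch:
    def __init__(self, hidden_size, intermediate_size, tensor_parallel_degree, is_moe=False):
        # Expert MLPs (is_moe=True) keep plain linears even under tensor
        # parallelism; presumably their distribution is handled by the MoE /
        # expert-parallel machinery instead of per-projection splitting.
        use_tp = tensor_parallel_degree > 1 and not is_moe
        Column = ColumnParallelLinear if use_tp else PlainLinear
        Row = RowParallelLinear if use_tp else PlainLinear
        self.gate_proj = Column(hidden_size, intermediate_size)
        self.up_proj = Column(hidden_size, intermediate_size)
        self.down_proj = Row(intermediate_size, hidden_size)

dense_mlp = DeepseekV2MLPSketch(5120, 12288, tensor_parallel_degree=8)
shared_expert = DeepseekV2MLPSketch(5120, 1536, tensor_parallel_degree=8, is_moe=True)
print(type(dense_mlp.gate_proj).__name__)      # ColumnParallelLinear
print(type(shared_expert.gate_proj).__name__)  # PlainLinear
```
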
@@ -1110,7 +1110,8 @@ def _get_name_mappings(cls, config: DeepseekV2Config) -> list[StateDictNameMappi
["embed_tokens.weight"],
["norm.weight"],
]
for layer_index in range(config.num_hidden_layers):
# last one layer contains MTP (eagle) parameters for inference
for layer_index in range(config.num_hidden_layers + config.num_nextn_predict_layers):
layer_mappings = [
[f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"],
[f"layers.{layer_index}.self_attn.q_a_proj.weight", None, "transpose"],
@@ -1130,6 +1131,7 @@ def _get_name_mappings(cls, config: DeepseekV2Config) -> list[StateDictNameMappi

            # MoE parameters
            model_mappings.append([f"layers.{layer_index}.mlp.gate.weight", None, "transpose"])
+            model_mappings.append([f"layers.{layer_index}.mlp.gate.e_score_correction_bias"])
            for expert_idx in range(config.n_routed_experts):
                expert_mappings = [
                    [f"layers.{layer_index}.mlp.experts.{expert_idx}.gate_proj.weight", None, "transpose"],
@@ -1141,6 +1143,15 @@ def _get_name_mappings(cls, config: DeepseekV2Config) -> list[StateDictNameMappi
model_mappings.append([f"layers.{layer_index}.mlp.shared_experts.up_proj.weight", None, "transpose"])
model_mappings.append([f"layers.{layer_index}.mlp.shared_experts.down_proj.weight", None, "transpose"])

+            # MTP (eagle) parameters for inference
+            if layer_index >= config.num_hidden_layers:
+                model_mappings.append([f"layers.{layer_index}.embed_tokens.weight"])
+                model_mappings.append([f"layers.{layer_index}.enorm.weight"])
+                model_mappings.append([f"layers.{layer_index}.hnorm.weight"])
+                model_mappings.append([f"layers.{layer_index}.eh_proj.weight", None, "transpose"])
+                model_mappings.append([f"layers.{layer_index}.shared_head.norm.weight"])
+                model_mappings.append([f"layers.{layer_index}.shared_head.head.weight", None, "transpose"])

        init_name_mappings(mappings=model_mappings)
        if cls.base_model_class.__name__ not in config.architectures:
            for mapping in model_mappings:
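
The extra loop iterations above register name mappings for the multi-token-prediction (MTP, "eagle") parameters that DeepSeek-V3 and DeepSeek-R1 checkpoints store after the regular decoder layers. The sketch below simply enumerates the parameter names this produces for those extra indices; the layer counts are illustrative (DeepSeek-V3 reportedly uses 61 decoder layers plus one MTP layer), and the real mappings also carry the transpose actions shown in the hunk.

```python
# Enumerate the MTP parameter names generated for layer indices beyond
# num_hidden_layers. Counts are illustrative, not read from a config file.
num_hidden_layers = 61
num_nextn_predict_layers = 1

mtp_suffixes = [
    "embed_tokens.weight",
    "enorm.weight",
    "hnorm.weight",
    "eh_proj.weight",
    "shared_head.norm.weight",
    "shared_head.head.weight",
]

mtp_names = [
    f"layers.{layer_index}.{suffix}"
    for layer_index in range(num_hidden_layers, num_hidden_layers + num_nextn_predict_layers)
    for suffix in mtp_suffixes
]
print(mtp_names[:3])
# ['layers.61.embed_tokens.weight', 'layers.61.enorm.weight', 'layers.61.hnorm.weight']
```
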
@@ -1203,6 +1214,21 @@ def get_tensor_parallel_split_mappings(num_layers):
                        final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
                final_actions[key] = action

+            # MTP (eagle) parameters for inference
+            base_actions.pop("embed_tokens.weight")
+            base_actions.pop("lm_head.weight")
+            base_actions["layers.0.embed_tokens.weight"] = partial(fn, is_column=False)
+            base_actions["layers.0.eh_proj.weight"] = partial(fn, is_column=True)
+            base_actions["layers.0.shared_head.head.weight"] = partial(fn, is_column=True)
+            for key, action in base_actions.items():
+                if "layers.0." in key:
+                    for i in range(
+                        config.num_hidden_layers, config.num_hidden_layers + config.num_nextn_predict_layers
+                    ):
+                        final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
+                else:
+                    final_actions[key] = action

            return final_actions

        mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
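
This second expansion pass reuses the same template trick, but only for the MTP layer indices: the global `embed_tokens.weight` and `lm_head.weight` actions are dropped, and per-layer entries are added for the MTP embedding, `eh_proj`, and `shared_head.head` projections with the row/column choices shown above. A compact sketch of that expansion follows; the string tags stand in for the real `partial(fn, is_column=...)` split functions, and the layer counts are again illustrative.

```python
# Sketch of expanding per-layer tensor-parallel split actions onto the MTP
# layer indices. String tags replace the real split functions.
num_hidden_layers = 61          # illustrative
num_nextn_predict_layers = 1    # illustrative

base_actions = {
    "layers.0.embed_tokens.weight": "split_rows",         # is_column=False
    "layers.0.eh_proj.weight": "split_columns",           # is_column=True
    "layers.0.shared_head.head.weight": "split_columns",  # is_column=True
}

final_actions = {}
for key, action in base_actions.items():
    if "layers.0." in key:
        # Rewrite the template key onto the MTP layer indices only.
        for i in range(num_hidden_layers, num_hidden_layers + num_nextn_predict_layers):
            final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
    else:
        final_actions[key] = action

for name, action in final_actions.items():
    print(name, "->", action)
# layers.61.embed_tokens.weight -> split_rows
# layers.61.eh_proj.weight -> split_columns
# layers.61.shared_head.head.weight -> split_columns
```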
