Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MoE] fix expert parallel #9760

Merged
merged 3 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions llm/run_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ def main():
quantization_config=quantization_config,
)

if "Qwen2Moe" in str(model_config.architectures) and training_args.data_parallel_degree > 1:
training_args.use_expert_parallel = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不太好吧,万一用户的dp跟 expert_parallel degree对不上怎么办?

Copy link
Contributor Author

@DesmonDay DesmonDay Jan 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moe layer写法就默认了是这个逻辑,expert_parallel_degree=dp_degree。觉得不好的话我把原来的逻辑改掉?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我感觉只能默认和data_parallel_degree进行对齐,做all-to-all时是应该在数据并行组内进行


LlmMetaConfig.set_llm_config(model_config, training_args)
model_config.use_fast_layer_norm = model_args.use_fast_layer_norm

Expand Down
3 changes: 3 additions & 0 deletions llm/run_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,9 @@ def main():
except:
print("Not register llama pp reshard information.")

if "Qwen2Moe" in str(config.architectures) and training_args.data_parallel_degree > 1:
training_args.use_expert_parallel = True

if model_args.continue_training:
# NOTE(gongenlei): new add
if training_args.autotuner_benchmark:
Expand Down
3 changes: 3 additions & 0 deletions paddlenlp/transformers/moe_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,14 @@
self.moe_num_experts_per_device = self._parse_moe_expert_parallel(
self.moe_num_experts, self.expert_parallel_degree
)
self.is_dummy_moe = False if self.expert_parallel_degree > 1 else True

Check warning on line 165 in paddlenlp/transformers/moe_layer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/transformers/moe_layer.py#L165

Added line #L165 was not covered by tests
else:
# when moe_group is dummy, we don't need to use all_to_all
self.moe_group = None
self.moe_rank = 0
self.expert_parallel_degree = 1
self.moe_num_experts_per_device = self.moe_num_experts
self.is_dummy_moe = True

self.all_to_all_dropout = all_to_all_dropout
self.enable_recompute = False
Expand All @@ -181,6 +183,7 @@

self.gate = gate
self.gate.group = self.moe_group
self._post_init()

def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree):
assert (
Expand Down