Cherry pick Updating modelopt spec for Mixtral (10660) into r2.0.0 #10664

Open · wants to merge 1 commit into r2.0.0
@@ -20,7 +20,8 @@
 from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.identity_op import IdentityOp
 from megatron.core.transformer.mlp import MLP, MLPSubmodules
-from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
+from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
 
@@ -48,6 +49,7 @@ def get_gpt_layer_modelopt_spec(num_experts: int = None) -> ModuleSpec:
     if not HAVE_MEGATRON_CORE:
         raise IMPORT_ERROR
 
+    mlp = _get_mlp_module_spec(num_experts=num_experts)
     return ModuleSpec(
         module=TransformerLayer,
         submodules=TransformerLayerSubmodules(
@@ -65,7 +67,7 @@ def get_gpt_layer_modelopt_spec(num_experts: int = None) -> ModuleSpec:
             ),
             self_attn_bda=get_bias_dropout_add,
             pre_mlp_layernorm=TENorm,
-            mlp=_get_mlp_module_spec(num_experts=num_experts),
+            mlp=mlp,
             mlp_bda=get_bias_dropout_add,
             # Map TE-layernorm-fusion keys back
             sharded_state_dict_keys_map={
@@ -77,7 +79,7 @@ def get_gpt_layer_modelopt_spec(num_experts: int = None) -> ModuleSpec:
 
 
 # Helper function to get module spec for MLP/MoE
-def _get_mlp_module_spec(num_experts: int = None, moe_grouped_gemm: bool = False) -> ModuleSpec:
+def _get_mlp_module_spec(num_experts: int = None) -> ModuleSpec:
     if num_experts is None:
         # Dense MLP w/ or w/o TE modules.
         return ModuleSpec(
@@ -91,12 +93,18 @@ def _get_mlp_module_spec(num_experts: int = None, moe_grouped_gemm: bool = False
         # Mixture of experts with modules in megatron core.
         return ModuleSpec(
            module=MoELayer,
-            submodules=(
-                MLPSubmodules(
+            submodules=MoESubmodules(
+                experts=MLPSubmodules(
                     linear_fc1=ColumnParallelLinear,
                     linear_fc2=RowParallelLinear,
-                )
-                if not moe_grouped_gemm
-                else None
+                ),
+                shared_experts=ModuleSpec(
+                    module=SharedExpertMLP,
+                    params={"gate": False},
+                    submodules=MLPSubmodules(
+                        linear_fc1=ColumnParallelLinear,
+                        linear_fc2=RowParallelLinear,
+                    ),
+                ),
             ),
         )
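
For review context, a minimal sketch of how the updated spec can be exercised for a Mixtral-style MoE configuration. This is illustrative only and not part of the PR: it assumes Megatron-Core is importable, the num_experts=8 value is a placeholder, and get_gpt_layer_modelopt_spec is the function modified above (its import path is omitted).

```python
# Illustrative sketch (not part of the diff above). Assumes Megatron-Core is
# installed; get_gpt_layer_modelopt_spec is the helper modified in this PR.
from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules

layer_spec = get_gpt_layer_modelopt_spec(num_experts=8)  # e.g. a Mixtral-style model
mlp_spec = layer_spec.submodules.mlp

# The MLP slot is a MoELayer spec whose submodules are now a MoESubmodules
# container rather than a bare MLPSubmodules (or None for grouped GEMM).
assert mlp_spec.module is MoELayer
assert isinstance(mlp_spec.submodules, MoESubmodules)

# Routed experts keep plain Column/RowParallelLinear projections.
print(mlp_spec.submodules.experts)

# A shared-expert spec is attached as well; {"gate": False} matches the diff.
print(mlp_spec.submodules.shared_experts.params)
```

The shared_experts entry mirrors how Megatron-Core's own layer specs populate MoESubmodules; whether a shared-expert MLP is actually instantiated at runtime depends on the model's TransformerConfig, which is outside this diff.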