Update modelopt layer spec for Mixtral (#10660)
Signed-off-by: Jan Lasek <janek.lasek@gmail.com>
janekl authored Sep 27, 2024
1 parent fdaf607 commit 4f59502
Showing 1 changed file with 16 additions and 8 deletions.
@@ -20,7 +20,8 @@
 from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.identity_op import IdentityOp
 from megatron.core.transformer.mlp import MLP, MLPSubmodules
-from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
+from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

@@ -48,6 +49,7 @@ def get_gpt_layer_modelopt_spec(num_experts: int = None) -> ModuleSpec:
     if not HAVE_MEGATRON_CORE:
         raise IMPORT_ERROR
 
+    mlp = _get_mlp_module_spec(num_experts=num_experts)
     return ModuleSpec(
         module=TransformerLayer,
         submodules=TransformerLayerSubmodules(
@@ -65,7 +67,7 @@ def get_gpt_layer_modelopt_spec(num_experts: int = None) -> ModuleSpec:
             ),
             self_attn_bda=get_bias_dropout_add,
             pre_mlp_layernorm=TENorm,
-            mlp=_get_mlp_module_spec(num_experts=num_experts),
+            mlp=mlp,
             mlp_bda=get_bias_dropout_add,
             # Map TE-layernorm-fusion keys back
             sharded_state_dict_keys_map={
@@ -77,7 +79,7 @@ def get_gpt_layer_modelopt_spec(num_experts: int = None) -> ModuleSpec:
 
 
 # Helper function to get module spec for MLP/MoE
-def _get_mlp_module_spec(num_experts: int = None, moe_grouped_gemm: bool = False) -> ModuleSpec:
+def _get_mlp_module_spec(num_experts: int = None) -> ModuleSpec:
     if num_experts is None:
         # Dense MLP w/ or w/o TE modules.
         return ModuleSpec(
@@ -91,12 +93,18 @@ def _get_mlp_module_spec(num_experts: int = None, moe_grouped_gemm: bool = False) -> ModuleSpec:
         # Mixture of experts with modules in megatron core.
         return ModuleSpec(
             module=MoELayer,
-            submodules=(
-                MLPSubmodules(
+            submodules=MoESubmodules(
+                experts=MLPSubmodules(
                     linear_fc1=ColumnParallelLinear,
                     linear_fc2=RowParallelLinear,
-                )
-                if not moe_grouped_gemm
-                else None
+                ),
+                shared_experts=ModuleSpec(
+                    module=SharedExpertMLP,
+                    params={"gate": False},
+                    submodules=MLPSubmodules(
+                        linear_fc1=ColumnParallelLinear,
+                        linear_fc2=RowParallelLinear,
+                    ),
+                ),
             ),
         )
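
For context, a minimal usage sketch (not part of the commit) of how the updated MoE path could be exercised. It assumes Megatron-Core is installed and that get_gpt_layer_modelopt_spec is imported from the file changed here (whose path is not shown above); the expert count of 8 is only an illustrative Mixtral-style value.

# Hypothetical check of the new MoESubmodules wiring (illustrative only).
from megatron.core.tensor_parallel import ColumnParallelLinear
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP

# get_gpt_layer_modelopt_spec is assumed to be imported from the modified module.
layer_spec = get_gpt_layer_modelopt_spec(num_experts=8)  # takes the MoE branch
mlp_spec = layer_spec.submodules.mlp

assert mlp_spec.module is MoELayer
# Routed experts keep plain column/row-parallel linears ...
assert mlp_spec.submodules.experts.linear_fc1 is ColumnParallelLinear
# ... and the shared-expert branch is a gateless SharedExpertMLP.
assert mlp_spec.submodules.shared_experts.module is SharedExpertMLP
assert mlp_spec.submodules.shared_experts.params == {"gate": False}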
