diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py
index 080984133f3c..046e032093b1 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py
@@ -20,7 +20,8 @@
     from megatron.core.transformer.enums import AttnMaskType
     from megatron.core.transformer.identity_op import IdentityOp
     from megatron.core.transformer.mlp import MLP, MLPSubmodules
-    from megatron.core.transformer.moe.moe_layer import MoELayer
+    from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
+    from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
     from megatron.core.transformer.spec_utils import ModuleSpec
     from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

@@ -48,6 +49,7 @@ def get_gpt_layer_modelopt_spec(num_experts: int = None) -> ModuleSpec:
     if not HAVE_MEGATRON_CORE:
         raise IMPORT_ERROR

+    mlp = _get_mlp_module_spec(num_experts=num_experts)
     return ModuleSpec(
         module=TransformerLayer,
         submodules=TransformerLayerSubmodules(
@@ -65,7 +67,7 @@ def get_gpt_layer_modelopt_spec(num_experts: int = None) -> ModuleSpec:
             ),
             self_attn_bda=get_bias_dropout_add,
             pre_mlp_layernorm=TENorm,
-            mlp=_get_mlp_module_spec(num_experts=num_experts),
+            mlp=mlp,
             mlp_bda=get_bias_dropout_add,
             # Map TE-layernorm-fusion keys back
             sharded_state_dict_keys_map={
@@ -77,7 +79,7 @@ def get_gpt_layer_modelopt_spec(num_experts: int = None) -> ModuleSpec:


 # Helper function to get module spec for MLP/MoE
-def _get_mlp_module_spec(num_experts: int = None, moe_grouped_gemm: bool = False) -> ModuleSpec:
+def _get_mlp_module_spec(num_experts: int = None) -> ModuleSpec:
     if num_experts is None:
         # Dense MLP w/ or w/o TE modules.
         return ModuleSpec(
@@ -91,12 +93,18 @@ def _get_mlp_module_spec(num_experts: int = None, moe_grouped_gemm: bool = False
     # Mixture of experts with modules in megatron core.
     return ModuleSpec(
         module=MoELayer,
-        submodules=(
-            MLPSubmodules(
+        submodules=MoESubmodules(
+            experts=MLPSubmodules(
                 linear_fc1=ColumnParallelLinear,
                 linear_fc2=RowParallelLinear,
-            )
-            if not moe_grouped_gemm
-            else None
+            ),
+            shared_experts=ModuleSpec(
+                module=SharedExpertMLP,
+                params={"gate": False},
+                submodules=MLPSubmodules(
+                    linear_fc1=ColumnParallelLinear,
+                    linear_fc2=RowParallelLinear,
+                ),
+            ),
         ),
     )
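
Below is a minimal usage sketch for reviewers, not part of the patch. It only exercises names that appear in the diff (get_gpt_layer_modelopt_spec, MoELayer, MoESubmodules, SharedExpertMLP), assumes NeMo and Megatron-Core are installed, and uses num_experts=8 purely as an arbitrary example value.

# Sketch only (assumes a working NeMo + Megatron-Core environment); the import
# path matches the file modified above.
from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import (
    get_gpt_layer_modelopt_spec,
)

# Dense path: with num_experts=None the mlp slot stays a plain MLP spec.
dense_layer_spec = get_gpt_layer_modelopt_spec()

# MoE path: the mlp spec is now built once and reused; its MoESubmodules carry
# routed experts plus a non-gated SharedExpertMLP (params={"gate": False}).
moe_layer_spec = get_gpt_layer_modelopt_spec(num_experts=8)  # 8 is an arbitrary example
moe_mlp_spec = moe_layer_spec.submodules.mlp
print(moe_mlp_spec.module)                            # -> MoELayer
print(moe_mlp_spec.submodules.shared_experts.module)  # -> SharedExpertMLP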