Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(mlp): enhance mlp_layer_fusion #382

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions internlm/core/context/parallel_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ def update(self, config):
self._add_item(k, v)
return self

def __delattr__(self, key):
if key in self:
super().__delitem__(key)
else:
raise AttributeError(f"{key} does not exist")

@staticmethod
def from_file(filename: str):
"""Reads a python file and constructs a corresponding :class:`Config` object.
Expand Down
8 changes: 5 additions & 3 deletions internlm/initialize/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,9 @@ def args_sanity_check():
"Please make sure you are using flash attention in cuda device."
)

if "mlp_layer_fusion" not in model:
model._add_item("mlp_layer_fusion", False)

if "MoE" in gpc.config.get("model_type", ModelType.INTERNLM.name):
if "num_experts" not in model:
model._add_item("num_experts", 1)
Expand All @@ -375,9 +378,8 @@ def args_sanity_check():
model._add_item("moe_type", "GShard")
if "moe_layer_kwargs" not in model:
model.moe_layer_kwargs = {}

if "mlp_layer_fusion" not in model:
model._add_item("mlp_layer_fusion", False)
if model.mlp_layer_fusion is False:
logger.warning("The config 'mlp_layer_fusion' is False, we recommend it should be set True when use MoE.")

# qk_interleaved config
if "qk_interleaved" not in gpc.config.model:
Expand Down
21 changes: 10 additions & 11 deletions internlm/model/modules/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,20 +106,19 @@ def __init__(
"w2", hidden_features, out_features, bias, device=device, dtype=dtype, is_expert=is_expert
)

def forward(self, x):
if not self.mlp_layer_fusion:
w1_o = self.w1(x)
w3_o = self.w3(x)
else:
fussed_out = self.fused_w1_w3(x)
w1_o, w3_o = torch.split(fussed_out, fussed_out.shape[-1] // 2, dim=-1)

if self.activation_type is ActivationType.swiglu.name:
out = self.w2(Silu(w1_o, w3_o))
self.activation_fn = Silu
else:
out = self.w2(Gelu(w1_o, w3_o))
self.activation_fn = Gelu

return out
def forward(self, x):
if self.mlp_layer_fusion:
fused_out = self.fused_w1_w3(x)
w1_o, w3_o = torch.split(fused_out, fused_out.shape[-1] // 2, dim=-1)
else:
w1_o = self.w1(x)
w3_o = self.w3(x)
return self.w2(self.activation_fn(w1_o, w3_o))


class GroupedFeedForward(nn.Module):
Expand Down
Loading