
Commit

precommit
Quentin-Anthony committed Oct 8, 2024
1 parent afeff03 commit 98f0388
Showing 5 changed files with 195 additions and 114 deletions.
7 changes: 4 additions & 3 deletions megatron/model/positional_embeddings.py
@@ -37,7 +37,8 @@ def forward(self, x, seq_dim=1):

class RotaryEmbedding(torch.nn.Module):
def __init__(
self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False):
self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False
):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=save_inv_freqs)
@@ -79,7 +80,7 @@ def _prepare_cache(self, seq_len, precision, base):

def get_emb(self):
return self.emb.to(self.precision).cuda()

def forward(self, x, seq_dim=0, seq_len=None):
if seq_len is None:
seq_len = x.shape[seq_dim]
@@ -253,4 +254,4 @@ def forward(self, x):
a.shape[0], 1, a.shape[2]
) # seq_len_k - 1 points to the last token index in the current inference batch.

return x + a
return x + a
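
(For reference, a minimal standalone sketch of the inverse-frequency buffer computed in the RotaryEmbedding.__init__ touched above. Only PyTorch is assumed; the helper name rotary_inv_freq is illustrative and not part of the repository.)

    import torch

    def rotary_inv_freq(dim: int, base: int = 10000) -> torch.Tensor:
        # Mirrors the expression in RotaryEmbedding.__init__ above:
        # one frequency per pair of channels, i.e. 1 / base ** (2i / dim).
        return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

    inv_freq = rotary_inv_freq(dim=64)
    print(inv_freq.shape)  # torch.Size([32]) -- dim // 2 frequencies
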
38 changes: 17 additions & 21 deletions megatron/model/transformer.py
@@ -414,7 +414,6 @@ def __init__(
bias=neox_args.use_bias_in_attn_linear,
)


coeff = None
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
if self.apply_query_key_layer_scaling:
@@ -860,7 +859,7 @@ def gqa_project(self, hidden_states, attention_mask, layer_past=None):
return query_layer, key_layer, value_layer

def forward(self, hidden_states, attention_mask, layer_past=None):

# hidden_states: [sq, b, h]

# =====================
@@ -934,7 +933,6 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
query_layer = torch.cat((query_layer, query_pass), dim=-1)
key_layer = torch.cat((key_layer, key_pass), dim=-1)


# ==================================
# Cache key and value for inference
# ==================================
@@ -1030,16 +1028,17 @@ def __init__(
# Self attention.
if neox_args.te_mha or neox_args.te_fp8_mha:
from megatron.model.transformer_engine import TEMultiheadAttention

self.attention = TEMultiheadAttention(
neox_args=neox_args,
attention_mask_func=attention_mask_func,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
rpe=rpe,
use_cache=self.use_cache,
rotary=rotary,
parallel_output=self.gpt_j_residual,
neox_args=neox_args,
attention_mask_func=attention_mask_func,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
rpe=rpe,
use_cache=self.use_cache,
rotary=rotary,
parallel_output=self.gpt_j_residual,
)

else:
@@ -1073,6 +1072,7 @@ def get_mlp(**kw):

def get_te_lnmlp(**kw):
from megatron.model.transformer_engine import TELayerNormMLP

return TELayerNormMLP(
neox_args=neox_args,
init_method=init_method,
@@ -1201,18 +1201,16 @@ def forward(self, x, attention_mask, layer_past=None):
bias_dropout_fn = self._get_bias_dropout()
moe_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)
# x: [b, s, h]


#Enable delayedscaling if TransformerEngine's FP8 is used for MHA layer.

# Enable delayedscaling if TransformerEngine's FP8 is used for MHA layer.
if self.neox_args.te_fp8_mha:
from megatron.model.transformer_engine import TEDelayedScaling

fp8_recipe = TEDelayedScaling(
neox_args=self.neox_args
)
fp8_recipe = TEDelayedScaling(neox_args=self.neox_args)
fp8_context = fp8_recipe.get_context()
else:
from contextlib import nullcontext

fp8_context = nullcontext()

with fp8_context:
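
(Aside: the hunk above selects between TransformerEngine's delayed-scaling FP8 context and a no-op context. A minimal, generic sketch of that pattern follows; fp8_autocast_stub and get_forward_context are illustrative stand-ins, not the gpt-neox or TransformerEngine API.)

    from contextlib import contextmanager, nullcontext

    @contextmanager
    def fp8_autocast_stub():
        # Stand-in for the context returned by TEDelayedScaling.get_context()
        # in the diff above; a real recipe would configure FP8 scaling here.
        yield

    def get_forward_context(te_fp8_mha: bool):
        # Use the real context manager when FP8 attention is enabled,
        # otherwise fall back to nullcontext so the caller's `with` block
        # needs no second code path.
        return fp8_autocast_stub() if te_fp8_mha else nullcontext()

    with get_forward_context(te_fp8_mha=False):
        pass  # attention/MLP forward pass would run here
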
@@ -1319,9 +1317,7 @@ def forward(self, x, attention_mask, layer_past=None):
else:
if self.moe_type == "deepspeed":
mlp_output, moe_loss, _ = self.mlp(layernorm_output)
mlp_bias = (
None # deepspeed.moe.layer.MoE.forward ignores the bias term
)
mlp_bias = None # deepspeed.moe.layer.MoE.forward ignores the bias term
elif self.moe_type == "megablocks":
mlp_output, mlp_bias = self.mlp(layernorm_output)
else:
