
Commit

Merge pull request #13 from stanford-crfm/openai-initialization
OpenAI GPT-2 Initialization
siddk authored Sep 15, 2021
2 parents 3a71d8d + 9fd657c commit 5984576
Showing 1 changed file with 15 additions and 15 deletions.
src/transformers/models/gpt2/modeling_gpt2.py: 30 changes (15 additions, 15 deletions)
@@ -15,6 +15,7 @@
 # limitations under the License.
 """PyTorch OpenAI GPT-2 model."""
 
+import math
 import os
 from dataclasses import dataclass
 from typing import Optional, Tuple
@@ -188,13 +189,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
             bsz, num_heads, seq_len, dk = query.size()
 
             # Preallocate attn_weights for `baddbmm`
-            attn_weights = torch.empty(
-                bsz * num_heads,
-                seq_len,
-                seq_len,
-                dtype=torch.float32,
-                device=query.device
-            )
+            attn_weights = torch.empty(bsz * num_heads, seq_len, seq_len, dtype=torch.float32, device=query.device)
 
             # Compute Scale Factor
             scale_factor = 1.0
@@ -207,13 +202,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
             # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
             with autocast(enabled=False):
                 q, k = query.reshape(-1, seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, seq_len)
-                attn_weights = torch.baddbmm(
-                    attn_weights,
-                    q.float(),
-                    k.float(),
-                    beta=0,
-                    alpha=scale_factor
-                )
+                attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
                 attn_weights = attn_weights.reshape(bsz, num_heads, seq_len, seq_len)
 
         else:
@@ -442,6 +431,17 @@ def _init_weights(self, module):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
+        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
+        #
+        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+        for name, p in module.named_parameters():
+            if "c_proj" in name and "weight" in name:
+                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
+
 
 @dataclass
 class GPT2DoubleHeadsModelOutput(ModelOutput):
Expand Down Expand Up @@ -629,7 +629,7 @@ def __init__(self, config):
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

self.drop = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([GPT2Block(config, layer_idx=i+1) for i in range(config.num_hidden_layers)])
self.h = nn.ModuleList([GPT2Block(config, layer_idx=i + 1) for i in range(config.num_hidden_layers)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

self.init_weights()
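Note on the _attn hunks above: the preallocated buffer plus torch.baddbmm(..., beta=0, alpha=scale_factor) computes the same scaled query-key product as a plain matmul, but because the buffer is float32 and the call runs with autocast disabled on q.float()/k.float(), the attention logits stay in fp32 under mixed precision. A minimal sketch of that equivalence follows; the shapes and the 1/sqrt(dk) scale factor are illustrative, not taken from the model config, and everything is kept in fp32 so it runs anywhere:

import torch

# Illustrative shapes only; in the model these come from query.size().
bsz, num_heads, seq_len, dk = 2, 4, 8, 16
query = torch.randn(bsz, num_heads, seq_len, dk)
key = torch.randn(bsz, num_heads, seq_len, dk)
scale_factor = 1.0 / (dk ** 0.5)  # assumes the usual 1/sqrt(dk) attention scaling

# Reference path: scaled dot-product logits via a plain matmul.
ref = torch.matmul(query, key.transpose(-1, -2)) * scale_factor

# Path from the diff: preallocate the fp32 output buffer, flatten batch/head dims,
# and let baddbmm ignore the buffer (beta=0) while folding the scale into alpha.
# In the model this runs under autocast(enabled=False) with q.float()/k.float(),
# which is what forces the logits into fp32 during mixed-precision training.
attn_weights = torch.empty(bsz * num_heads, seq_len, seq_len, dtype=torch.float32, device=query.device)
q = query.reshape(-1, seq_len, dk)
k = key.transpose(-1, -2).reshape(-1, dk, seq_len)
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
attn_weights = attn_weights.reshape(bsz, num_heads, seq_len, seq_len)

print(torch.allclose(ref, attn_weights, atol=1e-5))  # True

In pure fp32 the two paths agree to rounding error; the reordering only matters numerically once autocast would otherwise perform the matmul in fp16.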

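On the _init_weights hunk: every parameter whose name contains "c_proj" and "weight" (the residual-path output projections of the attention and MLP sub-blocks) is re-drawn with std initializer_range / sqrt(2 * n_layer) instead of the plain initializer_range. A rough sketch of the arithmetic and the name-matching rule, assuming the GPT-2 small defaults (initializer_range=0.02, n_layer=12) and a toy nn.Linear stand-in rather than the real Conv1D layers:

import math
import torch.nn as nn

# GPT-2 small defaults; in the model these come from GPT2Config.
initializer_range = 0.02
n_layer = 12

# Factor of 2 because each transformer block has two residual additions
# (one per LayerNorm), so a 12-layer model has 24 residual layers.
scaled_std = initializer_range / math.sqrt(2 * n_layer)
print(f"default std = {initializer_range}, c_proj std = {scaled_std:.5f}")  # ~0.00408

# Toy stand-in for a block's residual output projections (nn.Linear here,
# Conv1D in the real model); only the name-matching rule is the point.
block = nn.ModuleDict({"attn_c_proj": nn.Linear(64, 64), "mlp_c_proj": nn.Linear(256, 64)})
for name, p in block.named_parameters():
    if "c_proj" in name and "weight" in name:  # same rule as the diff
        p.data.normal_(mean=0.0, std=scaled_std)

print(round(block["attn_c_proj"].weight.std().item(), 5))  # close to 0.00408

With 12 layers this puts the c_proj weights at roughly one fifth of the default scale, which is the 1/√N damping of residual contributions described in the quoted GPT-2 passage (here N = 2 * n_layer residual layers).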