Remove layer norm by fine-tuning.
Implements a mostly drop-in replacement for LayerNorm,
called FakeLayerNorm, that eventually just uses a fixed
std value instead of the actual standard deviation.
stefan-apollo committed Nov 17, 2024
1 parent 9755682 commit 1f6c25e
Showing 4 changed files with 305 additions and 18 deletions.
35 changes: 35 additions & 0 deletions config/finetune_openwebtext.py
@@ -0,0 +1,35 @@
import time

eval_interval = 50
eval_iters = 50
log_interval = 1
wandb_log = True
wandb_project = 'stefan_remove_layer_norm'
name = 'gpt2-noLN'
wandb_run_name = name + '-' + str(time.time())
out_dir = f'out/{name}/'
dropout = 0
dataset = 'openwebtext'
init_from = 'gpt2'
remove_layer_norm = True

# Save a checkpoint after every eval, not only when the validation loss improves.
always_save_checkpoint = True

block_size = 1024
# The OpenAI paper uses a batch size of 0.5M tokens (~2**19):
desired_batch_size = 2**19 / block_size
# What fits comfortably on my A100 without random crashes:
batch_size = 48
# Use gradient accumulation steps to match the effective batch size.
gradient_accumulation_steps = int(desired_batch_size // batch_size)

# Compiling didn't work for me.
compile = False

# We're not actually going all the way to 3000 iters.
learning_rate = 6e-4
warmup_iters = 100
decay_lr = True
lr_decay_iters = 2_000
max_iters = 3_000
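
For reference, the batch-size arithmetic above works out as follows (a small sanity-check sketch; all values are taken from this config):

# Sanity check of the gradient-accumulation arithmetic above (values from this config).
block_size = 1024
desired_batch_size = 2**19 / block_size                              # 512 sequences ≈ 0.5M tokens
batch_size = 48                                                      # sequences per micro-batch
gradient_accumulation_steps = int(desired_batch_size // batch_size)  # 10
tokens_per_step = gradient_accumulation_steps * batch_size * block_size
print(gradient_accumulation_steps, tokens_per_step)                  # 10 491520 (~0.47M tokens per optimizer step)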
102 changes: 86 additions & 16 deletions model.py
@@ -15,16 +15,47 @@
import torch.nn as nn
from torch.nn import functional as F

class LayerNorm(nn.Module):
""" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
def __init__(self, ndim, bias):

from std_dicts import std_dict, std_bos_dict

class FakeLayerNorm(nn.Module):
"""LayerNorm using a fixed std instead of the actual standard deviation."""
def __init__(self, ndim, layer, bias):
super().__init__()
self.weight = nn.Parameter(torch.ones(ndim))
self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
# Flags for whether the LayerNorm is enabled ("real") or disabled ("fake"); attn_v_mode covers the V path.
self.mode = 'real'
self.attn_v_mode = 'real'
self.average_std = torch.ones(ndim, device='cuda') * std_dict[layer]
self.average_std[0] = std_bos_dict[layer]
self.bos_std = torch.ones(ndim, device='cuda') * std_bos_dict[layer]

def forward(self, input, std_type='avg', attn_v=False):
# All the enable / disable state lives in this class, but the same instance is used for
# both the QK and V attention paths. The attn_v flag in the call is True only for the V
# path, so the separate `mode` and `attn_v_mode` flags can enable / disable the LN for
# the QK and V paths independently.
mode = self.attn_v_mode if attn_v else self.mode
if mode == 'fake':
# Which std values to use: (1) the average std, which for most of training is a vector of
# length n_ctx* of the form [a, b, b, ...], where a is the average std at the first (BOS)
# position and b is the average std at every other position; or (2) the BOS std
# [a, a, a, ...] at every position, which we use when the input token is EOT. (We could
# distinguish EOT from BOS, but that wasn't needed here.)
# *At the end, disable_eot_std makes (2) equal to (1), and disable_bos_std turns both
# into [b, b, b, ...], i.e. effectively scalars.
assert std_type in ['avg', 'bos']
std = self.average_std if std_type == 'avg' else self.bos_std
return (
(input - input.mean(-1, keepdim=True)) / std * self.weight + self.bias
if self.bias is not None
else (input - input.mean(-1, keepdim=True)) / std * self.weight
)
elif mode == 'real':
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
else:
raise ValueError(f'Unknown mode {mode}')

def forward(self, input):
return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
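
For intuition, a standalone sketch of what the 'fake' mode above computes, using made-up tensor values rather than the class itself: the input is still centred, but it is divided by a frozen std instead of its own per-example std.

import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 768)                      # (batch, positions, n_embd)
weight, bias = torch.ones(768), torch.zeros(768)
fixed_std = torch.full((768,), 2.5)             # stands in for FakeLayerNorm.average_std

real = F.layer_norm(x, (768,), weight, bias, 1e-5)                 # what 'real' mode returns
fake = (x - x.mean(-1, keepdim=True)) / fixed_std * weight + bias  # what 'fake' mode returns

# The two agree only insofar as the frozen value matches the per-example std of x.
print((real - fake).abs().max())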

class CausalSelfAttention(nn.Module):

@@ -49,11 +80,14 @@ def __init__(self, config):
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))

def forward(self, x):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
def forward(self, x_qk, x_v):
B, T, C = x_qk.size() # batch size, sequence length, embedding dimensionality (n_embd)

# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
q, k, _ = self.c_attn(x_qk).split(self.n_embd, dim=2)
del _
_, _, v = self.c_attn(x_v).split(self.n_embd, dim=2)
del _
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
@@ -93,15 +127,24 @@ def forward(self, x):

class Block(nn.Module):

def __init__(self, config):
def __init__(self, config, layer_number):
super().__init__()
self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
self.ln_1 = FakeLayerNorm(config.n_embd, layer=f'blocks.{layer_number}.hook_resid_pre', bias=config.bias)
self.attn = CausalSelfAttention(config)
self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
self.ln_2 = FakeLayerNorm(config.n_embd, layer=f'blocks.{layer_number}.hook_resid_mid', bias=config.bias)
self.mlp = MLP(config)

def forward(self, x):
x = x + self.attn(self.ln_1(x))
def forward(self, x, eot_mask=None):
# Calculate LN'd (or FakeLN'd) x for Q and K
x_qk = self.ln_1(x)
# Calculate LN'd (or FakeLN'd) x for V (its LN is switched to fake mode at a different
# step of the fine-tuning schedule than the QK path)
x_v = self.ln_1(x, attn_v=True)
# Calculate LN'd (or FakeLN'd) x for V at EOT positions. The attention V path seemed
# particularly sensitive to the EOT std (this special case is disabled at the end).
x_v_eot = self.ln_1(x, std_type='bos', attn_v=True)
if eot_mask is not None:
    x_v[eot_mask] = x_v_eot[eot_mask]
del x_v_eot
x = x + self.attn(x_qk, x_v)
x = x + self.mlp(self.ln_2(x))
return x

@@ -127,8 +170,8 @@ def __init__(self, config):
wte = nn.Embedding(config.vocab_size, config.n_embd),
wpe = nn.Embedding(config.block_size, config.n_embd),
drop = nn.Dropout(config.dropout),
h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
ln_f = LayerNorm(config.n_embd, bias=config.bias),
h = nn.ModuleList([Block(config, layer_number=layer_number) for layer_number in range(config.n_layer)]),
ln_f = FakeLayerNorm(config.n_embd, layer=f'blocks.{config.n_layer - 1}.hook_resid_post', bias=config.bias),
))
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# with weight tying when using torch.compile() some warnings get generated:
@@ -173,12 +216,14 @@ def forward(self, idx, targets=None):
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

# Mask for EOT tokens to (temporarily) treat them differently when removing the LN.
eot_mask = idx == 50256  # 50256 is GPT-2's <|endoftext|> token id
# forward the GPT model itself
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
x = self.transformer.drop(tok_emb + pos_emb)
for block in self.transformer.h:
x = block(x)
x = block(x, eot_mask=eot_mask)
x = self.transformer.ln_f(x)

if targets is not None:
@@ -328,3 +373,28 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
idx = torch.cat((idx, idx_next), dim=1)

return idx

def disable_ln_2(self, block_index):
self.transformer.h[block_index].ln_2.mode = 'fake'
print(f"disabled ln_2 for block {block_index}")

def disable_ln_1qk(self, block_index):
self.transformer.h[block_index].ln_1.mode = 'fake'
print(f"disabled ln_1qk for block {block_index}")

def disable_ln_1v(self, block_index):
self.transformer.h[block_index].ln_1.attn_v_mode = 'fake'
print(f"disabled ln_1v for block {block_index}")

def disable_ln_f(self):
self.transformer.ln_f.mode = 'fake'
print("disabled ln_f")

def disable_eot_std(self, block_index):
self.transformer.h[block_index].ln_1.bos_std = self.transformer.h[block_index].ln_1.average_std
print(f"disabled eot std for block {block_index}")

def disable_bos_std(self, block_index):
self.transformer.h[block_index].ln_1.average_std[0] = self.transformer.h[block_index].ln_1.average_std[1]
self.transformer.h[block_index].ln_1.bos_std[0] = self.transformer.h[block_index].ln_1.bos_std[1]
print(f"disabled bos std for block {block_index}")
77 changes: 77 additions & 0 deletions std_dicts.py
@@ -0,0 +1,77 @@
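# Keys are TransformerLens-style hook names for the residual stream that feeds each
# LayerNorm; values are the measured average stds (std_bos_dict holds the corresponding
# stds at the BOS/EOT position).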

std_dict = {
'blocks.0.hook_resid_pre': 0.16597528755664825,
'blocks.0.hook_resid_mid': 1.2066363096237183,
'blocks.0.hook_resid_post': 1.8281980752944946,
'blocks.1.hook_resid_pre': 1.8281980752944946,
'blocks.1.hook_resid_mid': 1.9805375337600708,
'blocks.1.hook_resid_post': 2.0672824382781982,
'blocks.2.hook_resid_pre': 2.0672824382781982,
'blocks.2.hook_resid_mid': 2.1439383029937744,
'blocks.2.hook_resid_post': 2.21846604347229,
'blocks.3.hook_resid_pre': 2.21846604347229,
'blocks.3.hook_resid_mid': 2.3077552318573,
'blocks.3.hook_resid_post': 2.411747694015503,
'blocks.4.hook_resid_pre': 2.411747694015503,
'blocks.4.hook_resid_mid': 2.4973394870758057,
'blocks.4.hook_resid_post': 2.590416193008423,
'blocks.5.hook_resid_pre': 2.590416193008423,
'blocks.5.hook_resid_mid': 2.7218010425567627,
'blocks.5.hook_resid_post': 2.8397440910339355,
'blocks.6.hook_resid_pre': 2.8397440910339355,
'blocks.6.hook_resid_mid': 3.0089480876922607,
'blocks.6.hook_resid_post': 3.1704721450805664,
'blocks.7.hook_resid_pre': 3.1704721450805664,
'blocks.7.hook_resid_mid': 3.444382667541504,
'blocks.7.hook_resid_post': 3.697692632675171,
'blocks.8.hook_resid_pre': 3.697692632675171,
'blocks.8.hook_resid_mid': 4.003428936004639,
'blocks.8.hook_resid_post': 4.340140342712402,
'blocks.9.hook_resid_pre': 4.340140342712402,
'blocks.9.hook_resid_mid': 4.885100841522217,
'blocks.9.hook_resid_post': 5.460267066955566,
'blocks.10.hook_resid_pre': 5.460267066955566,
'blocks.10.hook_resid_mid': 6.5338945388793945,
'blocks.10.hook_resid_post': 7.905750751495361,
'blocks.11.hook_resid_pre': 7.905750751495361,
'blocks.11.hook_resid_mid': 14.712532043457031,
'blocks.11.hook_resid_post': 16.716890335083008,
}
std_bos_dict = {
'blocks.0.hook_resid_pre': 0.37046778202056885,
'blocks.0.hook_resid_mid': 1.0648562908172607,
'blocks.0.hook_resid_post': 5.1225905418396,
'blocks.1.hook_resid_pre': 5.1225905418396,
'blocks.1.hook_resid_mid': 6.165660381317139,
'blocks.1.hook_resid_post': 22.7987060546875,
'blocks.2.hook_resid_pre': 22.7987060546875,
'blocks.2.hook_resid_mid': 22.831886291503906,
'blocks.2.hook_resid_post': 92.71648406982422,
'blocks.3.hook_resid_pre': 92.71648406982422,
'blocks.3.hook_resid_mid': 92.6193618774414,
'blocks.3.hook_resid_post': 99.36549377441406,
'blocks.4.hook_resid_pre': 99.36549377441406,
'blocks.4.hook_resid_mid': 99.33074188232422,
'blocks.4.hook_resid_post': 104.73529815673828,
'blocks.5.hook_resid_pre': 104.73529815673828,
'blocks.5.hook_resid_mid': 104.69013214111328,
'blocks.5.hook_resid_post': 108.10015869140625,
'blocks.6.hook_resid_pre': 108.10015869140625,
'blocks.6.hook_resid_mid': 108.12938690185547,
'blocks.6.hook_resid_post': 110.11644744873047,
'blocks.7.hook_resid_pre': 110.11644744873047,
'blocks.7.hook_resid_mid': 110.13978576660156,
'blocks.7.hook_resid_post': 111.3284912109375,
'blocks.8.hook_resid_pre': 111.3284912109375,
'blocks.8.hook_resid_mid': 111.39237976074219,
'blocks.8.hook_resid_post': 112.07467651367188,
'blocks.9.hook_resid_pre': 112.07467651367188,
'blocks.9.hook_resid_mid': 112.13053131103516,
'blocks.9.hook_resid_post': 112.44728088378906,
'blocks.10.hook_resid_pre': 112.44728088378906,
'blocks.10.hook_resid_mid': 112.48235321044922,
'blocks.10.hook_resid_post': 112.49034881591797,
'blocks.11.hook_resid_pre': 112.49034881591797,
'blocks.11.hook_resid_mid': 17.505901336669922,
'blocks.11.hook_resid_post': 13.783772468566895,
}
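
The commit does not include the script that produced these numbers. As a hedged sketch of one plausible way to measure such per-layer stds, assuming TransformerLens (whose hook names match the dictionary keys) and some sample text standing in for the real corpus:

from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")
tokens = model.to_tokens("Some sample text standing in for OpenWebText.")  # placeholder input
_, cache = model.run_with_cache(tokens)

for name in ["blocks.0.hook_resid_pre", "blocks.0.hook_resid_mid", "blocks.0.hook_resid_post"]:
    resid = cache[name]              # (batch, pos, d_model)
    std = resid.std(dim=-1)          # per-position std across the hidden dimension
    # Position 0 (BOS) as a stand-in for the BOS/EOT std, the rest for the generic std.
    print(name, std[:, 1:].mean().item(), std[:, 0].mean().item())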
