
Commit

use rmsnorm instead
lucidrains committed Mar 21, 2021
1 parent 03d0222 commit be80cf6
Showing 2 changed files with 9 additions and 8 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'transformer-in-transformer',
   packages = find_packages(),
-  version = '0.0.9',
+  version = '0.1.0',
   license='MIT',
   description = 'Transformer in Transformer - Pytorch',
   author = 'Phil Wang',
15 changes: 8 additions & 7 deletions transformer_in_transformer/tnt.py
@@ -21,20 +21,21 @@ def unfold_output_size(image_size, kernel_size, stride, padding):

 # classes

-class ScaleNorm(nn.Module):
+class RMSNorm(nn.Module):
     def __init__(self, dim, eps = 1e-5):
         super().__init__()
+        self.scale = dim ** -0.5
         self.eps = eps
-        self.g = nn.Parameter(torch.ones(1))
+        self.g = nn.Parameter(torch.ones(dim))

     def forward(self, x):
-        n = torch.norm(x, dim = -1, keepdim = True).clamp(min = self.eps)
-        return x / n * self.g
+        norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
+        return x / norm.clamp(min = self.eps) * self.g

 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
-        self.norm = ScaleNorm(dim)
+        self.norm = RMSNorm(dim)
         self.fn = fn

     def forward(self, x, **kwargs):
@@ -136,7 +137,7 @@ def __init__(
         for _ in range(depth):

             pixel_to_patch = nn.Sequential(
-                ScaleNorm(pixel_dim),
+                RMSNorm(pixel_dim),
                 Rearrange('... n d -> ... (n d)'),
                 nn.Linear(pixel_dim * num_pixels, patch_dim),
             )
@@ -152,7 +153,7 @@ def __init__(
         self.layers = layers

         self.mlp_head = nn.Sequential(
-            ScaleNorm(patch_dim),
+            RMSNorm(patch_dim),
             nn.Linear(patch_dim, num_classes)
         )
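For reference, the new normalization as a standalone module (a minimal sketch lifted from the diff above; the tensor shapes in the usage lines are hypothetical). The change replaces ScaleNorm, which divided by the raw L2 norm and learned a single scalar gain, with RMSNorm, which scales the L2 norm by dim ** -0.5 (i.e. divides by the root-mean-square of the features) and learns a per-dimension gain g.

# Minimal sketch, copied from the diff above; shapes below are illustrative only.
import torch
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.scale = dim ** -0.5                 # 1 / sqrt(dim), turns the L2 norm into an RMS
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))   # per-dimension gain (ScaleNorm used a single scalar)

    def forward(self, x):
        norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
        return x / norm.clamp(min = self.eps) * self.g

x = torch.randn(2, 16, 512)                      # (batch, tokens, dim) -- hypothetical shapes
out = RMSNorm(512)(x)
print(out.shape)                                 # torch.Size([2, 16, 512])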
