From be80cf6943813ef918e2a69cf298d371b9f3e830 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 21 Mar 2021 16:05:58 -0700
Subject: [PATCH] use rmsnorm instead

---
 setup.py                          |  2 +-
 transformer_in_transformer/tnt.py | 15 ++++++++-------
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/setup.py b/setup.py
index 4b0e070..d0b028b 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'transformer-in-transformer',
   packages = find_packages(),
-  version = '0.0.9',
+  version = '0.1.0',
   license='MIT',
   description = 'Transformer in Transformer - Pytorch',
   author = 'Phil Wang',
diff --git a/transformer_in_transformer/tnt.py b/transformer_in_transformer/tnt.py
index 0833933..02ab56c 100644
--- a/transformer_in_transformer/tnt.py
+++ b/transformer_in_transformer/tnt.py
@@ -21,20 +21,21 @@ def unfold_output_size(image_size, kernel_size, stride, padding):
 
 # classes
 
-class ScaleNorm(nn.Module):
+class RMSNorm(nn.Module):
     def __init__(self, dim, eps = 1e-5):
         super().__init__()
+        self.scale = dim ** -0.5
         self.eps = eps
-        self.g = nn.Parameter(torch.ones(1))
+        self.g = nn.Parameter(torch.ones(dim))
 
     def forward(self, x):
-        n = torch.norm(x, dim = -1, keepdim = True).clamp(min = self.eps)
-        return x / n * self.g
+        norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
+        return x / norm.clamp(min = self.eps) * self.g
 
 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
-        self.norm = ScaleNorm(dim)
+        self.norm = RMSNorm(dim)
         self.fn = fn
 
     def forward(self, x, **kwargs):
@@ -136,7 +137,7 @@ def __init__(
 
         for _ in range(depth):
             pixel_to_patch = nn.Sequential(
-                ScaleNorm(pixel_dim),
+                RMSNorm(pixel_dim),
                 Rearrange('... n d -> ... (n d)'),
                 nn.Linear(pixel_dim * num_pixels, patch_dim),
             )
@@ -152,7 +153,7 @@ def __init__(
         self.layers = layers
 
         self.mlp_head = nn.Sequential(
-            ScaleNorm(patch_dim),
+            RMSNorm(patch_dim),
             nn.Linear(patch_dim, num_classes)
         )
 
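
For reference, below is a minimal standalone sketch of the RMSNorm module as this patch defines it, followed by an illustrative usage snippet. The dimension (512) and tensor shape used in the example are arbitrary choices for demonstration, not values taken from the patch.

import torch
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.scale = dim ** -0.5                  # 1 / sqrt(dim), turns the L2 norm into a root-mean-square
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))    # per-dimension gain (ScaleNorm used a single scalar)

    def forward(self, x):
        # RMS over the last dimension: ||x||_2 / sqrt(dim), clamped for numerical stability
        norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
        return x / norm.clamp(min = self.eps) * self.g

# illustrative usage (shapes are assumptions for the sketch)
norm = RMSNorm(512)
x = torch.randn(2, 196, 512)
out = norm(x)   # same shape as x; each feature vector is RMS-normalized, then scaled by g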