-
Notifications
You must be signed in to change notification settings - Fork 3
/
rms_norm.py
35 lines (26 loc) · 1.23 KB
/
rms_norm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import sys
import torch
from execution import runner
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization with a learnable per-feature scale.

    The input is divided by the square root of the mean of its squared
    values (taken over dimension index 2) plus a small epsilon, and the
    result is multiplied element-wise by ``weight``.

    Args:
        hidden_size: Number of elements in the learnable ``weight`` vector.
        dtype: Data type used to create ``weight``.
        device: Device on which ``weight`` is created.
        eps: Value added to the mean square before the reciprocal square
            root is taken. Defaults to ``1e-6``, the value that was
            previously hard-coded.
    """

    def __init__(self, hidden_size, dtype, device, eps: float = 1e-6):
        super().__init__()
        # The scale is drawn from a standard normal distribution
        # (torch.randn) rather than being initialised to ones.
        self.weight = torch.nn.Parameter(
            torch.randn(hidden_size, dtype=dtype, device=device),
            requires_grad=True,
        )
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor):
        """Return ``weight * x / sqrt(mean(x * x, dim=2) + eps)``.

        The reduction runs over dimension index 2, so ``x`` needs at least
        three dimensions; ``weight`` is broadcast against the result.
        """
        # Specifying -1 for reduction dimension causes an error in TorchScript
        variance = (x * x).mean(2, keepdim=True)
        x_hat = x * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * x_hat
def input_func(steps, dtype, device):
    """Build one randomly filled input batch per step.

    Returns a list with ``steps`` entries. Each entry is itself a
    one-element list holding a ``torch.randn`` tensor of shape
    ``(128, 128, 1024)`` created with the given ``dtype`` on ``device``.
    """
    shape = (128, 128, 1024)
    batches = []
    for _ in range(steps):
        sample = torch.randn(*shape, dtype=dtype, device=device)
        # Each step's inputs are wrapped in their own list.
        batches.append([sample])
    return batches
def grad_func(steps, dtype, device):
    """Build one randomly filled gradient tensor per step.

    Returns a flat list of ``steps`` tensors, each produced by
    ``torch.randn`` with shape ``(128, 128, 1024)``, the given ``dtype``
    and ``device``.
    """
    shape = (128, 128, 1024)
    grads = []
    for _ in range(steps):
        grads.append(torch.randn(*shape, dtype=dtype, device=device))
    return grads
class TestModule(torch.nn.Module):
    """Wrapper module that applies a single ``RMSNorm`` to its input.

    The norm is built with a hidden size of 1024, ``torch.float`` weights
    and the ``'cuda'`` device.
    """

    def __init__(self):
        super().__init__()
        self.norm = RMSNorm(hidden_size=1024, dtype=torch.float, device='cuda')

    def forward(self, inputs):
        """Normalize ``inputs`` and return the result as a one-element tuple."""
        return (self.norm(inputs),)
from components.dummy_optimizer import optim_func
if __name__ == "__main__":
    # Script entry point: pass the command-line arguments, the benchmark
    # label, a fresh module instance and the optimizer/input/gradient
    # factories to the project runner.
    runner.run(
        sys.argv,
        'RMSNorm',
        TestModule(),
        optim_func,
        input_func,
        grad_func,
    )