-
Notifications
You must be signed in to change notification settings - Fork 55
/
model_mlp.py
35 lines (27 loc) · 874 Bytes
/
model_mlp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import math
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
    """Multilayer perceptron with bias-free linear layers and ReLU activations.

    Hidden layers are registered as ``fc1`` .. ``fc{k-1}``; the final
    (prediction) layer is named just ``fc``. These names are part of the
    public state_dict interface and are preserved exactly.

    Args:
        n_units: layer widths, e.g. ``[784, 256, 10]`` — input size,
            hidden sizes, output size.
        init_scale: multiplier applied to the Xavier/Glorot standard
            deviation used to initialize each weight matrix.
    """

    def __init__(self, n_units, init_scale=1.0):
        super().__init__()
        # Copy so later caller-side mutation of n_units can't desync us.
        self._n_units = copy.copy(n_units)
        self._layers = []
        for i, (fan_in, fan_out) in enumerate(zip(n_units[:-1], n_units[1:]), start=1):
            layer = nn.Linear(fan_in, fan_out, bias=False)
            # Xavier/Glorot initialization. NOTE: this quantity is the
            # standard deviation, not the variance (the original variable
            # name `variance` was misleading).
            std = math.sqrt(2.0 / (fan_in + fan_out))
            layer.weight.data.normal_(0.0, init_scale * std)
            self._layers.append(layer)
            # The last layer is the prediction head and is just called 'fc'.
            name = 'fc' if i == len(n_units) - 1 else 'fc%d' % i
            self.add_module(name, layer)

    def forward(self, x):
        """Flatten ``x`` to ``(-1, n_units[0])`` and run it through the MLP.

        ReLU is applied between layers only — the final layer's output is
        returned raw (logits).
        """
        x = x.view(-1, self._n_units[0])
        out = self._layers[0](x)
        for layer in self._layers[1:]:
            out = F.relu(out)
            out = layer(out)
        return out