# drop_connect.py (forked from cxy1997/MNIST-baselines)
from __future__ import division, print_function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class drop_connect_layer(nn.Module):
    """Fully connected layer with DropConnect: dropout is applied to the
    weights (and bias), not to the activations."""

    def __init__(self, in_features, out_features, prob=0.5, bias=True):
        super(drop_connect_layer, self).__init__()
        # Glorot/Xavier uniform initialisation of the weight matrix.
        self.weight = Parameter(torch.zeros(out_features, in_features), requires_grad=True)
        w_bound = np.sqrt(6. / (out_features + in_features))
        self.weight.data.uniform_(-w_bound, w_bound)
        # nn.Dropout zeroes individual weight entries during training and
        # rescales the survivors by 1 / (1 - prob).
        self.weight_dropout = nn.Dropout(p=prob)
        if bias:
            self.bias = Parameter(torch.zeros(out_features), requires_grad=True)
            self.bias_dropout = nn.Dropout(p=prob)
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        # Sample a fresh weight mask on every forward pass (identity in eval mode).
        weight = self.weight_dropout(self.weight)
        bias = self.bias_dropout(self.bias) if self.bias is not None else None
        return F.linear(x, weight, bias)
class drop_connect_net(nn.Module):
    """Small MLP built from DropConnect layers for 2025-dim inputs
    (e.g. flattened 45x45 images).

    Note: an unresolved merge previously left two variants of this class in
    the file. The DropConnect variant is kept here; the alternative was a
    plain two-layer nn.Linear baseline (in_features -> 600 -> classes) with a
    ReLU in between.
    """

    def __init__(self, in_features=2025, classes=10, prob=0.5, bias=True):
        super(drop_connect_net, self).__init__()
        self.dc1 = drop_connect_layer(in_features, 400, prob=prob, bias=bias)
        self.dc3 = drop_connect_layer(400, 100, prob=prob, bias=bias)
        self.dc4 = drop_connect_layer(100, 60, prob=prob, bias=bias)
        self.dc5 = drop_connect_layer(60, classes, prob=prob, bias=bias)

    def forward(self, x):
        x = torch.tanh(self.dc1(x))
        x = torch.tanh(self.dc3(x))
        x = torch.tanh(self.dc4(x))
        return self.dc5(x)  # raw logits; pair with nn.CrossEntropyLoss
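

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): builds the
    # network and runs forward passes on random data shaped like flattened 45x45
    # images. The batch size and the train/eval toggling below are illustrative.
    net = drop_connect_net(in_features=2025, classes=10, prob=0.5)

    dummy = torch.randn(8, 2025)   # batch of 8 flattened inputs
    net.train()                    # DropConnect masks are sampled
    logits_train = net(dummy)
    net.eval()                     # dropout layers become the identity
    logits_eval = net(dummy)
    print(logits_train.shape, logits_eval.shape)  # torch.Size([8, 10]) twice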