layers.py
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['LambdaLayer', 'ScaledStdConv2d', 'HardBinaryScaledStdConv2d', 'LearnableBias', 'BinaryActivation', 'HardBinaryConv']

# Standardize a conv module's weight per output channel (weight standardization).
def get_weight(module):
    std, mean = torch.std_mean(module.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
    weight = (module.weight - mean) / (std + module.eps)
    return weight

# Calculate symmetric padding for a convolution
def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return padding
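
# Illustrative values (not in the original file): with stride=1 this padding keeps the spatial
# size unchanged, e.g. get_padding(3) == 1, get_padding(5) == 2, get_padding(3, dilation=2) == 2.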


class LambdaLayer(nn.Module):
    """Wrap an arbitrary function (e.g. a lambda) so it can be used as an nn.Module."""

    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)
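
# Usage sketch (illustrative, not part of the original file): wrap any tensor function so it
# can sit inside an nn.Sequential, e.g. a parameter-free downsampling shortcut:
#   shortcut = LambdaLayer(lambda t: F.avg_pool2d(t, 2))
#   y = shortcut(torch.randn(1, 16, 8, 8))   # -> shape (1, 16, 4, 4)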


class ScaledStdConv2d(nn.Conv2d):
    """Conv2d layer with Scaled Weight Standardization.

    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
        https://arxiv.org/abs/2101.08692

    NOTE: the operations used in this impl differ slightly from the DeepMind Haiku impl. The impact is minor.
    """

    def __init__(
            self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
            bias=False, gamma=1.0, eps=1e-5, use_layernorm=False):
        if padding is None:
            padding = get_padding(kernel_size, stride, dilation)
        super().__init__(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
            groups=groups, bias=bias)
        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)
        self.eps = eps ** 2 if use_layernorm else eps
        self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory to hijack LN kernel

    def get_weight(self):
        if self.use_layernorm:
            weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
        else:
            std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
            weight = self.scale * (self.weight - mean) / (std + self.eps)
        return self.gain * weight

    def forward(self, x):
        return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
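
# Usage sketch (illustrative, not part of the original file): a drop-in replacement for
# nn.Conv2d in normalizer-free networks; padding defaults to symmetric padding computed
# from kernel size, stride, and dilation.
#   conv = ScaledStdConv2d(64, 128, kernel_size=3, stride=2)
#   y = conv(torch.randn(1, 64, 32, 32))   # -> shape (1, 128, 16, 16)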


class HardBinaryScaledStdConv2d(nn.Module):
    """Binary conv with Scaled Weight Standardization applied to the latent real-valued weights."""

    def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1, gamma=1.0, eps=1e-5, use_layernorm=False):
        super(HardBinaryScaledStdConv2d, self).__init__()
        self.stride = stride
        self.padding = padding
        self.shape = (out_chn, in_chn, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.rand(self.shape) * 0.001, requires_grad=True)
        self.gain = nn.Parameter(torch.ones(out_chn, 1, 1, 1))
        self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)
        self.eps = eps ** 2 if use_layernorm else eps
        self.use_layernorm = use_layernorm

    def get_weight(self):
        # Standardize the latent real-valued weights first.
        if self.use_layernorm:
            weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
        else:
            std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
            weight = self.scale * (self.weight - mean) / (std + self.eps)
        # Per-output-channel scaling factor alpha = mean(|w|), detached so it is treated as a constant.
        scaling_factor = torch.mean(torch.mean(torch.mean(abs(weight), dim=3, keepdim=True), dim=2, keepdim=True), dim=1, keepdim=True)
        scaling_factor = scaling_factor.detach()
        # Binarize to {-alpha, +alpha} in the forward pass; the straight-through estimator below
        # routes gradients through the clamped real-valued weights instead of sign().
        binary_weights_no_grad = scaling_factor * torch.sign(weight)
        clipped_weights = torch.clamp(weight, -1.0, 1.0)
        binary_weights = binary_weights_no_grad.detach() - clipped_weights.detach() + clipped_weights
        return self.gain * binary_weights

    def forward(self, x):
        return F.conv2d(x, self.get_weight(), stride=self.stride, padding=self.padding)
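
# Gradient-flow sketch (illustrative, not part of the original file): the forward pass sees
# weights in {-alpha, +alpha}, while the straight-through estimator routes gradients to the
# latent real-valued weights through the clamp:
#   conv = HardBinaryScaledStdConv2d(16, 32)
#   conv(torch.randn(1, 16, 8, 8)).sum().backward()
#   conv.weight.grad is not None   # True: the latent weights remain trainable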


class LearnableBias(nn.Module):
    """Learnable channel-wise bias added to the input."""

    def __init__(self, out_chn):
        super(LearnableBias, self).__init__()
        self.bias = nn.Parameter(torch.zeros(1, out_chn, 1, 1), requires_grad=True)

    def forward(self, x):
        out = x + self.bias.expand_as(x)
        return out


class BinaryActivation(nn.Module):
    """Sign activation with a piecewise-polynomial gradient surrogate (straight-through estimator)."""

    def __init__(self):
        super(BinaryActivation, self).__init__()

    def forward(self, x):
        out_forward = torch.sign(x)
        # Differentiable surrogate built from the two branches (x^2 + 2x) and (-x^2 + 2x).
        mask1 = x < -1
        mask2 = x < 0
        mask3 = x < 1
        out1 = (-1) * mask1.type(torch.float32) + (x * x + 2 * x) * (1 - mask1.type(torch.float32))
        out2 = out1 * mask2.type(torch.float32) + (-x * x + 2 * x) * (1 - mask2.type(torch.float32))
        out3 = out2 * mask3.type(torch.float32) + 1 * (1 - mask3.type(torch.float32))
        # Forward returns sign(x); gradients flow through the surrogate out3.
        out = out_forward.detach() - out3.detach() + out3
        return out
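
# Worked form of the surrogate above (illustrative summary, not part of the original file):
#   g(x) = -1          for x < -1
#   g(x) = x^2 + 2x    for -1 <= x < 0
#   g(x) = -x^2 + 2x   for 0 <= x < 1
#   g(x) = 1           for x >= 1
# so dg/dx = 2 + 2x on [-1, 0), 2 - 2x on [0, 1), and 0 elsewhere, while the forward
# output stays exactly sign(x).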


class HardBinaryConv(nn.Module):
    """Binary convolution: weights are binarized to {-alpha, +alpha} with per-output-channel alpha = mean(|w|)."""

    def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1):
        super(HardBinaryConv, self).__init__()
        self.stride = stride
        self.padding = padding
        self.number_of_weights = in_chn * out_chn * kernel_size * kernel_size
        self.shape = (out_chn, in_chn, kernel_size, kernel_size)
        # self.weight = nn.Parameter(torch.rand((self.number_of_weights, 1)) * 0.001, requires_grad=True)
        self.weight = nn.Parameter(torch.rand(self.shape) * 0.001, requires_grad=True)

    def forward(self, x):
        real_weights = self.weight
        # Per-output-channel scaling factor alpha = mean(|w|), detached so it is treated as a constant.
        scaling_factor = torch.mean(torch.mean(torch.mean(abs(real_weights), dim=3, keepdim=True), dim=2, keepdim=True), dim=1, keepdim=True)
        scaling_factor = scaling_factor.detach()
        binary_weights_no_grad = scaling_factor * torch.sign(real_weights)
        clipped_weights = torch.clamp(real_weights, -1.0, 1.0)
        # Straight-through estimator: forward uses the binary weights, backward flows through the clamp.
        binary_weights = binary_weights_no_grad.detach() - clipped_weights.detach() + clipped_weights
        y = F.conv2d(x, binary_weights, stride=self.stride, padding=self.padding)
        return y
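

# Minimal smoke-test sketch (assumption: not part of the original file); it chains the binary
# layers and checks output shapes.
if __name__ == "__main__":
    x = torch.randn(2, 16, 8, 8)

    move = LearnableBias(16)        # learnable channel-wise shift
    binact = BinaryActivation()     # sign forward, polynomial-surrogate backward
    bconv = HardBinaryConv(16, 32, kernel_size=3, stride=1, padding=1)
    print(bconv(binact(move(x))).shape)   # expected: torch.Size([2, 32, 8, 8])

    sconv = ScaledStdConv2d(16, 32, kernel_size=3)
    print(sconv(x).shape)                 # expected: torch.Size([2, 32, 8, 8])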