-
Notifications
You must be signed in to change notification settings - Fork 29
/
model.py
323 lines (253 loc) · 10.8 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
import torch
import torch.nn as nn
import torch.nn.functional as F
def conv3x3(in_planes, out_planes, stride=1, bias=False):
    """Build a 3x3 ``nn.Conv2d`` with a padding of 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
        bias: whether the layer gets a learnable bias (default False).

    Returns:
        The newly constructed ``nn.Conv2d`` layer.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=bias)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
class BasicBlock(nn.Module):
    """Residual block made of two biased 3x3 convolutions.

    ``bn1`` and ``bn2`` are constructed, so they are registered as
    sub-modules, but ``forward`` does not apply them.

    Args:
        in_planes: number of input channels.
        planes: number of output channels of both convolutions.
        stride: stride of the first convolution (default 1).
        downsample: optional module applied to the input to produce the
            shortcut branch; when ``None`` the input itself is the shortcut.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes, stride=stride, bias=True)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, stride=1, bias=True)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(conv2(relu(conv1(x))) + shortcut)``."""
        # Batch norm (bn1 / bn2) is not applied on this path.
        out = self.conv2(self.relu(self.conv1(x)))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)
class Bottleneck(nn.Module):
    """Residual 1x1 -> 3x3 -> 1x1 block producing ``expansion * planes`` channels.

    The stride is applied in the 3x3 convolution. When ``use_bn`` is true the
    convolutions have no bias and batch norm follows the first two of them;
    otherwise the convolutions carry a bias and no batch norm is applied.
    ``bn3`` is constructed but never applied in ``forward``.

    Args:
        in_planes: number of input channels.
        planes: channel width of the two inner convolutions.
        stride: stride of the 3x3 convolution (default 1).
        downsample: optional module applied to the input to produce the
            shortcut branch; when ``None`` the input itself is the shortcut.
        use_bn: toggle batch norm (and, inversely, the conv biases).
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1, downsample=None, use_bn=True):
        super(Bottleneck, self).__init__()
        self.use_bn = use_bn
        conv_bias = not use_bn
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=conv_bias)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=conv_bias)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=conv_bias)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run the three convolutions, add the shortcut and apply a final ReLU."""
        out = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = conv(out)
            if self.use_bn:
                out = norm(out)
            out = self.relu(out)
        out = self.conv3(out)

        shortcut = x if self.downsample is None else self.downsample(x)
        return F.relu(shortcut + out)
class BackBone(nn.Module):
    """Feature extractor: a two-conv stem, four residual stages and a top-down merge.

    Args:
        block: residual block class. It must expose ``expansion`` and accept
            ``(in_planes, planes, stride=..., downsample=...)``. The lateral
            layers are hard-coded for ``expansion == 4`` (384 = 96 * 4,
            256 = 64 * 4, 192 = 48 * 4).
        num_block: sequence of four ints, the number of blocks in stages 2-5.
        use_bn: apply batch norm after the stem convolutions and inside the
            down-sampling shortcut of each stage.
    """

    def __init__(self, block, num_block, use_bn=True):
        super(BackBone, self).__init__()
        self.use_bn = use_bn

        # Block 1 (stem): 36 input channels; conv3x3 defaults to bias=False.
        self.conv1 = conv3x3(36, 32)
        self.conv2 = conv3x3(32, 32)
        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        # Blocks 2-5: each stage starts with a stride-2 block.
        # _make_layer reads and updates self.in_planes as it goes.
        self.in_planes = 32
        self.block2 = self._make_layer(block, 24, num_blocks=num_block[0])
        self.block3 = self._make_layer(block, 48, num_blocks=num_block[1])
        self.block4 = self._make_layer(block, 64, num_blocks=num_block[2])
        self.block5 = self._make_layer(block, 96, num_blocks=num_block[3])

        # Lateral layers (1x1 convs applied to the outputs of stages 5, 4, 3)
        self.latlayer1 = nn.Conv2d(384, 196, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0)
        self.latlayer3 = nn.Conv2d(192, 96, kernel_size=1, stride=1, padding=0)

        # Top-down layers (stride-2 transposed convs).
        # deconv2's output_padding=(1, 0) yields an even height and an odd
        # width (e.g. 100x88 -> 200x175).
        self.deconv1 = nn.ConvTranspose2d(196, 128, kernel_size=3, stride=2,
                                          padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(128, 96, kernel_size=3, stride=2,
                                          padding=1, output_padding=(1, 0))

    def forward(self, x):
        """Return the merged 96-channel feature map at the resolution of stage 3.

        Args:
            x: input tensor in (N, 36, H, W) layout.
        """
        x = self.conv1(x)
        if self.use_bn:
            x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        if self.use_bn:
            x = self.bn2(x)
        c1 = self.relu(x)

        # bottom up layers
        c2 = self.block2(c1)
        c3 = self.block3(c2)
        c4 = self.block4(c3)
        c5 = self.block5(c4)

        # top-down path: upsample with the transposed convs and add laterals
        l5 = self.latlayer1(c5)
        l4 = self.latlayer2(c4)
        p5 = l4 + self.deconv1(l5)
        l3 = self.latlayer3(c3)
        p4 = l3 + self.deconv2(p5)

        return p4

    def _make_layer(self, block, planes, num_blocks):
        """Build one stage: a stride-2 block with a projection shortcut, then stride-1 blocks.

        Reads ``self.in_planes`` as the stage's input width and leaves it set
        to ``planes * block.expansion`` for the next stage.
        """
        out_planes = planes * block.expansion
        if self.use_bn:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_planes, out_planes,
                          kernel_size=1, stride=2, bias=False),
                nn.BatchNorm2d(out_planes)
            )
        else:
            downsample = nn.Conv2d(self.in_planes, out_planes,
                                   kernel_size=1, stride=2, bias=True)

        # NOTE(review): use_bn is not forwarded to ``block``, so the blocks are
        # always built with their own default. Changing that would alter the
        # parameter set of existing checkpoints - confirm before touching.
        layers = [block(self.in_planes, planes, stride=2, downsample=downsample)]
        self.in_planes = out_planes
        for _ in range(1, num_blocks):
            layers.append(block(self.in_planes, planes, stride=1))
        return nn.Sequential(*layers)

    def _upsample_add(self, x, y):
        '''Upsample and add two feature maps.

        Not called by ``forward``, which upsamples with the transposed
        convolutions instead.

        Args:
          x: (Tensor) top feature map to be upsampled.
          y: (Tensor) lateral feature map.

        Returns:
          (Tensor) added feature map.

        Note in PyTorch, when input size is odd, the upsampled feature map
        with `F.interpolate(..., scale_factor=2, mode='nearest')`
        may not be equal to the lateral feature map size.

        e.g.
        original input size: [N,_,15,15] ->
        conv2d feature map size: [N,_,8,8] ->
        upsampled feature map size: [N,_,16,16]

        So we choose bilinear upsample which supports arbitrary output sizes.
        '''
        _, _, H, W = y.size()
        # F.upsample is deprecated; F.interpolate with align_corners=False is
        # what it resolves to for bilinear mode.
        return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False) + y
class Header(nn.Module):
    """Detection head: four 96-channel 3x3 convs followed by two output convs.

    When ``use_bn`` is true each of the four convs is bias-free and followed
    by batch norm; otherwise the convs carry a bias and batch norm is skipped.
    No activation is applied between the four convs.

    Args:
        use_bn: toggle batch norm (and, inversely, the conv biases).
    """

    def __init__(self, use_bn=True):
        super(Header, self).__init__()
        self.use_bn = use_bn
        conv_bias = not use_bn

        self.conv1 = conv3x3(96, 96, bias=conv_bias)
        self.bn1 = nn.BatchNorm2d(96)
        self.conv2 = conv3x3(96, 96, bias=conv_bias)
        self.bn2 = nn.BatchNorm2d(96)
        self.conv3 = conv3x3(96, 96, bias=conv_bias)
        self.bn3 = nn.BatchNorm2d(96)
        self.conv4 = conv3x3(96, 96, bias=conv_bias)
        self.bn4 = nn.BatchNorm2d(96)

        # Output convs always carry a bias: 1 classification channel and
        # 6 regression channels.
        self.clshead = conv3x3(96, 1, bias=True)
        self.reghead = conv3x3(96, 6, bias=True)

    def forward(self, x):
        """Return ``(cls, reg)``: a sigmoid 1-channel map and a raw 6-channel map."""
        trunk = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, norm in trunk:
            x = conv(x)
            if self.use_bn:
                x = norm(x)

        cls = torch.sigmoid(self.clshead(x))
        reg = self.reghead(x)
        return cls, reg
class Decoder(nn.Module):
    """Convert the normalised 6-channel regression map into box-corner coordinates."""

    def __init__(self):
        super(Decoder, self).__init__()
        # [y_min, y_max, x_min, x_max]: forward() builds the y grid from the
        # first pair and the x grid from the second pair.
        self.geometry = [-40.0, 40.0, 0.0, 70.0]
        # Step of both grids.
        self.grid_size = 0.4
        # Per-channel statistics used to undo the target normalisation.
        self.target_mean = [0.008, 0.001, 0.202, 0.2, 0.43, 1.368]
        self.target_std_dev = [0.866, 0.5, 0.954, 0.668, 0.09, 0.111]

    def forward(self, x):
        '''
        :param x: Tensor 6-channel geometry, normalised with target_mean / target_std_dev.
                  The channels are consumed as [cos(yaw), sin(yaw), dx, dy, log(w), log(l)],
                  where dx, dy are offsets from the grid cell position.
                  Shape of x: (B, C=6, H=200, W=175); H and W must match the grid
                  defined by geometry and grid_size. x is not modified.
        :return: Concatenated Tensor of 8 channel geometry map of bounding box corners
                 8 channel are [rear_left_x, rear_left_y,
                                rear_right_x, rear_right_y,
                                front_right_x, front_right_y,
                                front_left_x, front_left_y]
                 Return tensor has a shape (B, C=8, H=200, W=175), and is located on the same device as x
        '''
        # Undo the normalisation out of place, so the caller's tensor is left
        # untouched; the statistics broadcast over (B, 6, H, W).
        mean = x.new_tensor(self.target_mean).view(1, -1, 1, 1)
        std = x.new_tensor(self.target_std_dev).view(1, -1, 1, 1)
        reg = x * std + mean

        cos_t, sin_t, dx, dy, log_w, log_l = torch.chunk(reg, 6, dim=1)

        # Going through atan2 turns the raw (cos, sin) pair into a unit-norm pair.
        theta = torch.atan2(sin_t, cos_t)
        cos_t = torch.cos(theta)
        sin_t = torch.sin(theta)

        # Cell coordinates: y varies along H, x along W. Broadcasting a column
        # and a row vector is equivalent to an 'ij'-indexed meshgrid.
        grid_x = torch.arange(self.geometry[2], self.geometry[3], self.grid_size,
                              dtype=torch.float32, device=x.device)
        grid_y = torch.arange(self.geometry[0], self.geometry[1], self.grid_size,
                              dtype=torch.float32, device=x.device)
        centre_y = grid_y.view(-1, 1) + dy
        centre_x = grid_x.view(1, -1) + dx

        half_l = log_l.exp() / 2
        half_w = log_w.exp() / 2

        rear_left_x = centre_x - half_l * cos_t - half_w * sin_t
        rear_left_y = centre_y - half_l * sin_t + half_w * cos_t
        rear_right_x = centre_x - half_l * cos_t + half_w * sin_t
        rear_right_y = centre_y - half_l * sin_t - half_w * cos_t
        front_right_x = centre_x + half_l * cos_t + half_w * sin_t
        front_right_y = centre_y + half_l * sin_t - half_w * cos_t
        front_left_x = centre_x + half_l * cos_t - half_w * sin_t
        front_left_y = centre_y + half_l * sin_t + half_w * cos_t

        decoded_reg = torch.cat([rear_left_x, rear_left_y, rear_right_x, rear_right_y,
                                 front_right_x, front_right_y, front_left_x, front_left_y], dim=1)

        return decoded_reg
class PIXOR(nn.Module):
    '''
    Full detector: BackBone -> Header -> (optional) corner Decoder.

    The input of the PIXOR nn module is a tensor of [batch_size, height, width, channel].
    The output is a tensor of [batch_size, height/4, width/4, channel] whose last
    axis holds the 1 classification channel followed by the regression channels
    (6 raw channels, or 8 corner channels when decoding is enabled).
    Internally the tensor is permuted to [N, C, H, W] for PyTorch's nn.Conv2d.

    use_bn is passed to both BackBone and Header; decode sets the initial
    state of the corner-decoding switch.
    '''

    def __init__(self, use_bn=True, decode=False):
        super(PIXOR, self).__init__()
        self.backbone = BackBone(Bottleneck, [3, 6, 6, 3], use_bn)
        self.header = Header(use_bn)
        self.corner_decoder = Decoder()
        self.use_decode = decode

    def set_decode(self, decode):
        '''Enable or disable corner decoding of the regression output.'''
        self.use_decode = decode

    def forward(self, x):
        '''Run the detector on a channels-last batch and return channels-last predictions.'''
        # Torch takes tensors of shape (batch_size, channels, height, width).
        channels_first = x.permute(0, 3, 1, 2)
        cls_map, reg_map = self.header(self.backbone(channels_first))

        if self.use_decode:
            reg_map = self.corner_decoder(reg_map)

        # Back to (batch_size, height, width, channels) before concatenating
        # along the channel axis.
        channels_last = [t.permute(0, 2, 3, 1) for t in (cls_map, reg_map)]
        return torch.cat(channels_last, dim=3)
def test():
    """Smoke test: run PIXOR without batch norm on random input and print the output size."""
    print("Testing PIXOR model")
    net = PIXOR(use_bn=False)
    # Plain tensors are accepted directly; the deprecated Variable wrapper is
    # not needed. no_grad avoids building an autograd graph for a
    # forward-only check.
    with torch.no_grad():
        preds = net(torch.randn(2, 800, 700, 36))

    print("Prediction output size", preds.size())
def test_decoder():
    """Smoke test: run PIXOR with corner decoding enabled and print the output size."""
    print("Testing PIXOR decoder")
    net = PIXOR(use_bn=False)
    net.set_decode(True)
    # Plain tensors are accepted directly; the deprecated Variable wrapper is
    # not needed. no_grad avoids building an autograd graph for a
    # forward-only check.
    with torch.no_grad():
        preds = net(torch.randn(2, 800, 700, 36))

    print("Predictions output size", preds.size())
if __name__ == "__main__":
    # Script entry point: run the decoder smoke test.
    test_decoder()