-
Notifications
You must be signed in to change notification settings - Fork 18
/
model_ENET_SAD.py
689 lines (576 loc) · 26.4 KB
/
model_ENET_SAD.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
# ENet PyTorch code retrieved from https://github.com/davidtvs/PyTorch-ENet/blob/master/models/enet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from utils.utils import mIoULoss, to_one_hot
class InitialBlock(nn.Module):
    """Entry stage of ENet: a strided convolution and a max-pool run side by side.

    Both branches halve the spatial resolution (stride 2, 3x3 window,
    padding 1) and their outputs are concatenated along the channel axis,
    batch-normalised and passed through the activation.

    The convolution produces ``out_channels - 3`` feature maps. The pooling
    branch passes the input channels through unchanged, so the concatenated
    tensor has ``out_channels`` channels (as the batch norm expects) only
    when ``in_channels`` is 3.

    Keyword arguments:
    - in_channels (int): the number of input channels.
    - out_channels (int): the number of output channels after concatenation.
    - bias (bool, optional): adds a learnable bias to the convolution if
    ``True``. Default: False.
    - relu (bool, optional): when ``True`` ReLU is used as the activation
    function; otherwise, PReLU is used. Default: True.

    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=False,
                 relu=True):
        super().__init__()

        act_cls = nn.ReLU if relu else nn.PReLU

        # Convolutional branch: everything except the 3 channels that the
        # pooling branch is expected to contribute.
        self.main_branch = nn.Conv2d(
            in_channels,
            out_channels - 3,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=bias)

        # Pooling branch: same window/stride/padding as the convolution so
        # both outputs share one spatial size.
        self.ext_branch = nn.MaxPool2d(3, stride=2, padding=1)

        # Normalisation and activation applied to the concatenated result.
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.out_activation = act_cls()

    def forward(self, x):
        branches = (self.main_branch(x), self.ext_branch(x))

        # Stack the two branches channel-wise, then normalise and activate.
        merged = torch.cat(branches, 1)
        return self.out_activation(self.batch_norm(merged))
class RegularBottleneck(nn.Module):
    """Regular bottlenecks are the main building block of ENet.

    Main branch:
    1. Shortcut connection.

    Extension branch:
    1. 1x1 convolution which decreases the number of channels by
    ``internal_ratio``, also called a projection;
    2. regular, dilated or asymmetric convolution;
    3. 1x1 convolution which increases the number of channels back to
    ``channels``, also called an expansion;
    4. dropout as a regularizer.

    Keyword arguments:
    - channels (int): the number of input and output channels.
    - internal_ratio (int, optional): a scale factor applied to
    ``channels`` used to compute the number of channels after the
    projection. eg. given ``channels`` equal to 128 and internal_ratio
    equal to 2 the number of channels after the projection is 64.
    Must satisfy ``1 < internal_ratio <= channels``. Default: 4.
    - kernel_size (int, optional): the kernel size of the filters used in
    the convolution layer described above in item 2 of the extension
    branch. Default: 3.
    - padding (int, optional): zero-padding added to both sides of the
    input of the item-2 convolution. The shortcut is added to the
    extension output, so the padding must keep the spatial size
    unchanged. Default: 0.
    - dilation (int, optional): spacing between kernel elements for the
    convolution described in item 2 of the extension branch. Default: 1.
    - asymmetric (bool, optional): flags if the convolution described in
    item 2 of the extension branch is asymmetric or not. Default: False.
    - dropout_prob (float, optional): probability of an element to be
    zeroed. Default: 0 (no dropout).
    - bias (bool, optional): Adds a learnable bias to the output if
    ``True``. Default: False.
    - relu (bool, optional): When ``True`` ReLU is used as the activation
    function; otherwise, PReLU is used. Default: True.

    Raises:
    - RuntimeError: if ``internal_ratio`` is not in ``(1, channels]``.

    """

    def __init__(self,
                 channels,
                 internal_ratio=4,
                 kernel_size=3,
                 padding=0,
                 dilation=1,
                 asymmetric=False,
                 dropout_prob=0,
                 bias=False,
                 relu=True):
        super().__init__()

        # internal_ratio must lie in (1, channels]: a ratio of 1 is rejected
        # and a ratio above ``channels`` would leave zero internal channels.
        if internal_ratio <= 1 or internal_ratio > channels:
            raise RuntimeError("Value out of range. Expected value in the "
                               "interval (1, {0}], got internal_ratio={1}."
                               .format(channels, internal_ratio))

        internal_channels = channels // internal_ratio

        activation = nn.ReLU if relu else nn.PReLU

        # Main branch - shortcut connection

        # Extension branch - 1x1 convolution, followed by a regular, dilated
        # or asymmetric convolution, followed by another 1x1 convolution,
        # and, finally, a regularizer (spatial dropout). Number of channels
        # is constant.

        # 1x1 projection convolution
        self.ext_conv1 = nn.Sequential(
            nn.Conv2d(
                channels,
                internal_channels,
                kernel_size=1,
                stride=1,
                bias=bias), nn.BatchNorm2d(internal_channels), activation())

        # If the convolution is asymmetric we split the main convolution in
        # two. Eg. for a 5x5 asymmetric convolution we have two convolutions:
        # the first is 5x1 and the second is 1x5.
        if asymmetric:
            self.ext_conv2 = nn.Sequential(
                nn.Conv2d(
                    internal_channels,
                    internal_channels,
                    kernel_size=(kernel_size, 1),
                    stride=1,
                    padding=(padding, 0),
                    dilation=dilation,
                    bias=bias), nn.BatchNorm2d(internal_channels), activation(),
                nn.Conv2d(
                    internal_channels,
                    internal_channels,
                    kernel_size=(1, kernel_size),
                    stride=1,
                    padding=(0, padding),
                    dilation=dilation,
                    bias=bias), nn.BatchNorm2d(internal_channels), activation())
        else:
            self.ext_conv2 = nn.Sequential(
                nn.Conv2d(
                    internal_channels,
                    internal_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=padding,
                    dilation=dilation,
                    bias=bias), nn.BatchNorm2d(internal_channels), activation())

        # 1x1 expansion convolution
        self.ext_conv3 = nn.Sequential(
            nn.Conv2d(
                internal_channels,
                channels,
                kernel_size=1,
                stride=1,
                bias=bias), nn.BatchNorm2d(channels), activation())

        self.ext_regul = nn.Dropout2d(p=dropout_prob)

        # Activation (ReLU or PReLU) applied after adding the branches
        self.out_activation = activation()

    def forward(self, x):
        # Main branch shortcut
        main = x

        # Extension branch
        ext = self.ext_conv1(x)
        ext = self.ext_conv2(ext)
        ext = self.ext_conv3(ext)
        ext = self.ext_regul(ext)

        # Add main and extension branches
        out = main + ext

        return self.out_activation(out)
class DownsamplingBottleneck(nn.Module):
    """Downsampling bottlenecks further downsample the feature map size.

    Main branch:
    1. max pooling with stride 2; indices are saved to be used for
    unpooling later.

    Extension branch:
    1. 2x2 convolution with stride 2 that decreases the number of channels
    by ``internal_ratio``, also called a projection;
    2. regular 3x3 convolution;
    3. 1x1 convolution which increases the number of channels to
    ``out_channels``, also called an expansion;
    4. dropout as a regularizer.

    Keyword arguments:
    - in_channels (int): the number of input channels.
    - out_channels (int): the number of output channels.
    - internal_ratio (int, optional): a scale factor applied to
    ``in_channels`` used to compute the number of channels after the
    projection. eg. given ``in_channels`` equal to 128 and internal_ratio
    equal to 2 the number of channels after the projection is 64.
    Must satisfy ``1 < internal_ratio <= in_channels``. Default: 4.
    - return_indices (bool, optional): if ``True``, the max-pool indices
    are returned along with the outputs (useful when unpooling later);
    otherwise ``None`` is returned in their place. Default: False.
    - dropout_prob (float, optional): probability of an element to be
    zeroed. Default: 0 (no dropout).
    - bias (bool, optional): Adds a learnable bias to the output if
    ``True``. Default: False.
    - relu (bool, optional): When ``True`` ReLU is used as the activation
    function; otherwise, PReLU is used. Default: True.

    Raises:
    - RuntimeError: if ``internal_ratio`` is not in ``(1, in_channels]``.

    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 internal_ratio=4,
                 return_indices=False,
                 dropout_prob=0,
                 bias=False,
                 relu=True):
        super().__init__()

        # Store parameters that are needed later
        self.return_indices = return_indices

        # internal_ratio must lie in (1, in_channels]
        if internal_ratio <= 1 or internal_ratio > in_channels:
            raise RuntimeError("Value out of range. Expected value in the "
                               "interval (1, {0}], got internal_ratio={1}. "
                               .format(in_channels, internal_ratio))

        internal_channels = in_channels // internal_ratio

        activation = nn.ReLU if relu else nn.PReLU

        # Main branch - max pooling followed by feature map (channels) padding
        self.main_max1 = nn.MaxPool2d(
            2,
            stride=2,
            return_indices=return_indices)

        # Extension branch - 2x2 strided convolution, followed by a regular
        # 3x3 convolution, followed by a 1x1 convolution that expands to
        # ``out_channels``.

        # 2x2 projection convolution with stride 2
        self.ext_conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels,
                internal_channels,
                kernel_size=2,
                stride=2,
                bias=bias), nn.BatchNorm2d(internal_channels), activation())

        # Convolution
        self.ext_conv2 = nn.Sequential(
            nn.Conv2d(
                internal_channels,
                internal_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=bias), nn.BatchNorm2d(internal_channels), activation())

        # 1x1 expansion convolution
        self.ext_conv3 = nn.Sequential(
            nn.Conv2d(
                internal_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                bias=bias), nn.BatchNorm2d(out_channels), activation())

        self.ext_regul = nn.Dropout2d(p=dropout_prob)

        # Activation (ReLU or PReLU) applied after adding the branches
        self.out_activation = activation()

    def forward(self, x):
        """Return ``(output, max_indices)``.

        ``max_indices`` is the index tensor produced by the max-pool when the
        block was built with ``return_indices=True`` and ``None`` otherwise.
        """
        # Main branch shortcut
        if self.return_indices:
            main, max_indices = self.main_max1(x)
        else:
            main = self.main_max1(x)
            # Keep the return arity stable when no indices were requested.
            max_indices = None

        # Extension branch
        ext = self.ext_conv1(x)
        ext = self.ext_conv2(ext)
        ext = self.ext_conv3(ext)
        ext = self.ext_regul(ext)

        # Main branch channel padding: the pooled input has fewer channels
        # than the extension output, so fill the difference with zeros.
        # ``new_zeros`` inherits dtype and device from ``main``.
        n, ch_ext, h, w = ext.size()
        ch_main = main.size()[1]
        padding = main.new_zeros(n, ch_ext - ch_main, h, w)

        # Concatenate
        main = torch.cat((main, padding), 1)

        # Add main and extension branches
        out = main + ext

        return self.out_activation(out), max_indices
class UpsamplingBottleneck(nn.Module):
    """The upsampling bottlenecks upsample the feature map resolution using max
    pooling indices stored from the corresponding downsampling bottleneck.

    Main branch:
    1. 1x1 convolution with stride 1 that maps ``in_channels`` to
    ``out_channels``, followed by batch normalization;
    2. max unpool layer using the max pool indices from the corresponding
    downsampling max pool layer.

    Extension branch:
    1. 1x1 convolution with stride 1 that decreases the number of channels by
    ``internal_ratio``, also called a projection;
    2. 2x2 transposed convolution with stride 2;
    3. 1x1 convolution which increases the number of channels to
    ``out_channels``, also called an expansion;
    4. dropout as a regularizer.

    Keyword arguments:
    - in_channels (int): the number of input channels.
    - out_channels (int): the number of output channels.
    - internal_ratio (int, optional): a scale factor applied to ``in_channels``
     used to compute the number of channels after the projection. eg. given
     ``in_channels`` equal to 128 and ``internal_ratio`` equal to 2 the number
     of channels after the projection is 64. Must satisfy
     ``1 < internal_ratio <= in_channels``. Default: 4.
    - dropout_prob (float, optional): probability of an element to be zeroed.
    Default: 0 (no dropout).
    - bias (bool, optional): Adds a learnable bias to the output if ``True``.
    Default: False.
    - relu (bool, optional): When ``True`` ReLU is used as the activation
    function; otherwise, PReLU is used. Default: True.

    Raises:
    - RuntimeError: if ``internal_ratio`` is not in ``(1, in_channels]``.

    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 internal_ratio=4,
                 dropout_prob=0,
                 bias=False,
                 relu=True):
        super().__init__()

        # internal_ratio must lie in (1, in_channels]
        if internal_ratio <= 1 or internal_ratio > in_channels:
            raise RuntimeError("Value out of range. Expected value in the "
                               "interval (1, {0}], got internal_ratio={1}. "
                               .format(in_channels, internal_ratio))

        internal_channels = in_channels // internal_ratio

        activation = nn.ReLU if relu else nn.PReLU

        # Main branch - 1x1 convolution to ``out_channels`` followed by
        # max unpooling
        self.main_conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias),
            nn.BatchNorm2d(out_channels))

        # Remember that the stride is the same as the kernel_size, just like
        # the max pooling layers
        self.main_unpool1 = nn.MaxUnpool2d(kernel_size=2)

        # Extension branch - 1x1 projection, 2x2 transposed convolution that
        # doubles the spatial size, then a 1x1 expansion to ``out_channels``.

        # 1x1 projection convolution with stride 1
        self.ext_conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels, internal_channels, kernel_size=1, bias=bias),
            nn.BatchNorm2d(internal_channels), activation())

        # Transposed convolution. Kept outside a Sequential because its
        # forward call needs the extra ``output_size`` argument.
        self.ext_tconv1 = nn.ConvTranspose2d(
            internal_channels,
            internal_channels,
            kernel_size=2,
            stride=2,
            bias=bias)
        self.ext_tconv1_bnorm = nn.BatchNorm2d(internal_channels)
        self.ext_tconv1_activation = activation()

        # 1x1 expansion convolution
        self.ext_conv2 = nn.Sequential(
            nn.Conv2d(
                internal_channels, out_channels, kernel_size=1, bias=bias),
            nn.BatchNorm2d(out_channels), activation())

        self.ext_regul = nn.Dropout2d(p=dropout_prob)

        # Activation (ReLU or PReLU) applied after adding the branches
        self.out_activation = activation()

    def forward(self, x, max_indices, output_size):
        """Upsample ``x`` to ``output_size``.

        ``max_indices`` are the pooling indices handed to the max-unpool
        layer; ``output_size`` is forwarded to both the unpool and the
        transposed convolution so the two branches end up the same size.
        """
        # Main branch shortcut
        main = self.main_conv1(x)
        main = self.main_unpool1(
            main, max_indices, output_size=output_size)

        # Extension branch
        ext = self.ext_conv1(x)
        ext = self.ext_tconv1(ext, output_size=output_size)
        ext = self.ext_tconv1_bnorm(ext)
        ext = self.ext_tconv1_activation(ext)
        ext = self.ext_conv2(ext)
        ext = self.ext_regul(ext)

        # Add main and extension branches
        out = main + ext

        return self.out_activation(out)
class SpatialSoftmax(nn.Module):
    """Softmax over the flattened spatial positions of a 3-D feature map.

    Keyword arguments:
    - temperature (float, optional): divisor applied to the features before
    the softmax. A truthy value is stored as a learnable one-element
    ``Parameter``; a falsy value (``0``/``None``) uses the constant ``1.``.
    Default: 1.
    - device (str or torch.device, optional): device on which the
    temperature parameter is created. Default: 'cpu'.

    """

    def __init__(self, temperature=1, device='cpu'):
        super(SpatialSoftmax, self).__init__()

        if temperature:
            # Build the tensor on the target device *before* wrapping it.
            # ``Parameter(...).to(device)`` returns a plain tensor whenever a
            # copy is needed, which would leave the temperature unregistered.
            self.temperature = Parameter(
                torch.ones(1, device=device) * temperature)
        else:
            self.temperature = 1.

    def forward(self, feature):
        # ``feature`` is indexed as (dim0, dim1, dim2); dims 1 and 2 are
        # flattened together, giving (dim0, 1, dim1 * dim2) for 3-D input.
        # ``reshape`` (unlike ``view``) also accepts non-contiguous tensors.
        feature = feature.reshape(
            feature.shape[0], -1, feature.shape[1] * feature.shape[2])
        softmax_attention = F.softmax(feature / self.temperature, dim=-1)

        return softmax_attention
class ENet_SAD(nn.Module):
    """Generate the ENet model with an optional self-attention-distillation
    (SAD) loss and a lane-existence head.

    Keyword arguments:
    - input_size (tuple): ``(width, height)`` of the input images; used to
    size the first fully connected layer of the lane-existence head.
    - sad (bool, optional): When ``True``, the attention-distillation modules
    are created and the SAD loss formulation is used. Default: False.
    - encoder_relu (bool, optional): When ``True`` ReLU is used as the
    activation function in the encoder blocks/layers; otherwise, PReLU
    is used. Default: False.
    - decoder_relu (bool, optional): When ``True`` ReLU is used as the
    activation function in the decoder blocks/layers; otherwise, PReLU
    is used. Default: True.
    - weight_share (bool, optional): When ``True`` encoder stages E3 and E4
    reuse the very same module object as E2; otherwise each stage gets its
    own block. Default: True.
    - dataset (str, optional): ``'BDD100K'`` selects a single output channel
    trained with ``BCEWithLogitsLoss``; any other value selects 5 output
    channels trained with a class-weighted ``CrossEntropyLoss``.
    Default: 'CULane'.

    """

    def __init__(self, input_size, sad=False, encoder_relu=False, decoder_relu=True, weight_share=True, dataset='CULane'):
        super().__init__()

        # Init parameter
        input_w, input_h = input_size
        # The existence head sees 5 channels at 1/16 of the input resolution
        # (three stride-2 stages in the encoder plus a 2x2 average pool).
        self.fc_input_feature = 5 * int(input_w / 16) * int(input_h / 16)

        self.num_classes = 5 if dataset != 'BDD100K' else 1

        self.scale_background = 0.4

        # Loss scale factor for ENet w/o SAD
        self.scale_seg = 1.0
        self.scale_exist = 0.1

        # Loss scale factor for ENet w SAD
        self.scale_sad_seg = 1.0
        self.scale_sad_iou = 0.1
        self.scale_sad_exist = 0.1
        self.scale_sad_distill = 0.1

        # Loss function
        self.dataset = dataset
        if dataset != 'BDD100K':
            # Class 0 is down-weighted by ``scale_background``.
            self.ce_loss = nn.CrossEntropyLoss(weight=torch.tensor([self.scale_background, 1, 1, 1, 1]))
            self.bce_loss = nn.BCELoss()
            self.iou_loss = mIoULoss(n_classes=4)
        else:
            self.ce_loss = nn.BCEWithLogitsLoss()
            self.bce_loss = nn.BCELoss()
            self.iou_loss = mIoULoss(n_classes=1)

        # Encoder generator
        def get_encoder_block(n=2):
            # (module-name template, RegularBottleneck kwargs) in the order
            # the layers are applied.
            layer_specs = [
                ('regular%d_1', dict(padding=1)),
                ('dilated%d_2', dict(dilation=2, padding=2)),
                ('asymmetric%d_3', dict(kernel_size=5, padding=2, asymmetric=True)),
                ('dilated%d_4', dict(dilation=4, padding=4)),
                ('regular%d_5', dict(padding=1)),
                ('dilated%d_6', dict(dilation=8, padding=8)),
                ('asymmetric%d_7', dict(kernel_size=5, asymmetric=True, padding=2)),
                ('dilated%d_8', dict(dilation=16, padding=16)),
            ]
            seq = nn.Sequential()
            for name_template, layer_kwargs in layer_specs:
                seq.add_module(
                    name_template % n,
                    RegularBottleneck(128, dropout_prob=0.1, relu=encoder_relu, **layer_kwargs))
            return seq

        # Stage 0 - Initial block
        self.initial_block = InitialBlock(3, 16, relu=encoder_relu)
        self.sad = sad

        # Stage 1 - Encoder (E1)
        self.downsample1 = DownsamplingBottleneck(16, 64, return_indices=True, dropout_prob=0.01, relu=encoder_relu)
        self.encoder1 = nn.Sequential()
        self.encoder1.add_module('regular1_1', RegularBottleneck(64, padding=1, dropout_prob=0.01, relu=encoder_relu))
        self.encoder1.add_module('regular1_2', RegularBottleneck(64, padding=1, dropout_prob=0.01, relu=encoder_relu))
        self.encoder1.add_module('regular1_3', RegularBottleneck(64, padding=1, dropout_prob=0.01, relu=encoder_relu))
        self.encoder1.add_module('regular1_4', RegularBottleneck(64, padding=1, dropout_prob=0.01, relu=encoder_relu))

        # Shared Encoder (E2~E4)
        # Stage 2 - Encoder (E2)
        self.downsample2 = DownsamplingBottleneck(64, 128, return_indices=True, dropout_prob=0.1, relu=encoder_relu)
        self.encoder2 = get_encoder_block(n=2)

        # Stage 3 - Encoder (E3)
        self.encoder3 = self.encoder2 if weight_share else get_encoder_block(3)

        # Stage 4 - Encoder (E4)
        self.encoder4 = self.encoder2 if weight_share else get_encoder_block(4)

        # Stage 5 - Decoder (D1); its input is the 128+128 channel
        # concatenation of the E3 and E4 outputs.
        self.upsample4_0 = UpsamplingBottleneck(256, 64, dropout_prob=0.1, relu=decoder_relu)
        self.regular4_1 = RegularBottleneck(64, padding=1, dropout_prob=0.1, relu=decoder_relu)
        self.regular4_2 = RegularBottleneck(64, padding=1, dropout_prob=0.1, relu=decoder_relu)

        # Stage 6 - Decoder (D2)
        self.upsample5_0 = UpsamplingBottleneck(64, 16, dropout_prob=0.1, relu=decoder_relu)
        self.regular5_1 = RegularBottleneck(16, padding=1, dropout_prob=0.1, relu=decoder_relu)
        self.transposed_conv = nn.ConvTranspose2d(16, self.num_classes, kernel_size=3, stride=2, padding=1, bias=False)

        # AT_GEN
        if self.sad:
            self.at_gen_upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
            self.at_gen_l2_loss = nn.MSELoss(reduction='mean')

        # Lane exist (P1)
        self.exist = nn.Sequential(nn.Conv2d(128, 5, 1),
                                   nn.Softmax(dim=1),
                                   nn.AvgPool2d(2, 2),)
        self.fc = nn.Sequential(nn.Linear(self.fc_input_feature, 128),
                                nn.ReLU(),
                                nn.Linear(128, 4),
                                nn.Sigmoid(),)

    def at_gen(self, x1, x2):
        """Return the MSE between the spatial attention maps of two features.

        x1 - previous encoder step feature map
        x2 - current encoder step feature map

        Each map is reduced to a channel-wise sum of squares and normalized
        with a spatial softmax. When the two inputs differ in size, ``x2`` is
        first upsampled by a factor of 2. Only usable when the model was
        built with ``sad=True``, since the helper modules are created then.
        """
        # G^2_sum
        # A fresh SpatialSoftmax is created per call, so its temperature
        # always starts at 1 and is never part of this model's parameters.
        sps = SpatialSoftmax(device=x1.device)

        needs_upsample = x1.size() != x2.size()

        attention1 = sps(x1.pow(2).sum(dim=1))

        if needs_upsample:
            # keepdim so nn.Upsample sees a 4-D (N, 1, H, W) tensor
            x2 = x2.pow(2).sum(dim=1, keepdim=True)
            x2 = torch.squeeze(self.at_gen_upsample(x2), dim=1)
        else:
            x2 = x2.pow(2).sum(dim=1)
        attention2 = sps(x2)

        loss = self.at_gen_l2_loss(attention1, attention2)
        return loss

    def _segmentation_loss(self, seg_pred, seg_gt):
        """Apply ``self.ce_loss`` with the target layout the dataset needs."""
        if self.dataset != 'BDD100K':
            return self.ce_loss(seg_pred, seg_gt)
        # BCEWithLogitsLoss needs matching shapes and a float target, so the
        # single prediction channel is dropped and the target is cast.
        return self.ce_loss(seg_pred.squeeze(dim=1), seg_gt.float())

    def forward(self, img, seg_gt=None, exist_gt=None, sad_loss=False):
        """Run the network and, when targets are given, compute the losses.

        Arguments:
        - img: input image batch.
        - seg_gt (optional): segmentation target.
        - exist_gt (optional): lane-existence target.
        - sad_loss (bool, optional): when ``True`` and the model was built
        with ``sad=True``, the distillation loss is added to the total.

        Returns ``(seg_pred, exist_pred, loss_seg, loss_exist, loss)``. The
        three losses are zero tensors unless both ``seg_gt`` and ``exist_gt``
        are provided.
        """
        # Stage 0 - Initial block
        input_size = img.size()
        x_0 = self.initial_block(img)

        # AT-GEN after each E2, E3, E4
        # Stage 1 - Encoder (E1)
        stage1_input_size = x_0.size()
        x, max_indices1 = self.downsample1(x_0)
        x_1 = self.encoder1(x)

        # Stage 2 - Encoder (E2)
        stage2_input_size = x_1.size()
        x, max_indices2 = self.downsample2(x_1)
        x_2 = self.encoder2(x)
        if self.sad:
            loss_2 = self.at_gen(x_1, x_2)

        # Stage 3 - Encoder (E3)
        x_3 = self.encoder3(x_2)
        if self.sad:
            loss_3 = self.at_gen(x_2, x_3)

        # Stage 4 - Encoder (E4)
        x_4 = self.encoder4(x_3)
        if self.sad:
            loss_4 = self.at_gen(x_3, x_4)

        # Concatenate E3, E4
        x_34 = torch.cat((x_3, x_4), dim=1)

        # Stage 4 - Decoder (D1)
        x = self.upsample4_0(x_34, max_indices2, output_size=stage2_input_size)
        x = self.regular4_1(x)
        x = self.regular4_2(x)

        # Stage 5 - Decoder (D2)
        x = self.upsample5_0(x, max_indices1, output_size=stage1_input_size)
        x = self.regular5_1(x)
        seg_pred = self.transposed_conv(x, output_size=input_size)

        # lane exist
        y = self.exist(x_4)
        y = y.view(-1, self.fc_input_feature)
        exist_pred = self.fc(y)

        # loss calculation
        if seg_gt is not None and exist_gt is not None:
            # L = L_seg + a * L_iou + b * L_exist + c * L_distill
            loss_seg = self._segmentation_loss(seg_pred, seg_gt)
            loss_exist = self.bce_loss(exist_pred, exist_gt)
            if self.sad:
                if self.dataset != 'BDD100K':
                    # presumably a 5-class one-hot encoding of the target;
                    # see utils.utils.to_one_hot
                    seg_gt_onehot = to_one_hot(seg_gt, 5)
                else:
                    seg_gt_onehot = seg_gt.unsqueeze(dim=1)
                loss_iou = self.iou_loss(seg_pred, seg_gt_onehot)
                loss = loss_seg * self.scale_sad_seg + loss_iou * self.scale_sad_iou + loss_exist * self.scale_sad_exist

                # Add SAD loss after 40K episodes (the caller decides when
                # by passing ``sad_loss=True``)
                if sad_loss:
                    loss_distill = loss_2 + loss_3 + loss_4
                    loss = loss + loss_distill * self.scale_sad_distill
            else:
                loss = loss_seg * self.scale_seg + loss_exist * self.scale_exist
        else:
            loss_seg = torch.tensor(0, dtype=img.dtype, device=img.device)
            loss_exist = torch.tensor(0, dtype=img.dtype, device=img.device)
            loss = torch.tensor(0, dtype=img.dtype, device=img.device)

        return seg_pred, exist_pred, loss_seg, loss_exist, loss
if __name__ == '__main__':
    # Smoke test: one training-mode forward pass with SAD enabled on a
    # constant batch. Runs on the GPU when one is available and falls back
    # to the CPU otherwise.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tensor = torch.ones((8, 3, 288, 800), device=device)
    seg_gt = torch.zeros((8, 288, 800), dtype=torch.long, device=device)
    exist_gt = torch.ones((8, 4), device=device)

    # input_size is (width, height)
    enet_sad = ENet_SAD((800, 288), sad=True, dataset='BDD100K')
    enet_sad.to(device)
    enet_sad.train(mode=True)
    result = enet_sad(tensor, seg_gt, exist_gt, sad_loss=True)