-
Notifications
You must be signed in to change notification settings - Fork 44
/
multibox_loss.py
799 lines (627 loc) · 34.4 KB
/
multibox_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.autograd import Variable
from ..box_utils import match, log_sum_exp, decode, center_size, crop, elemwise_mask_iou, elemwise_box_iou
from data import cfg, mask_type, activation_func
def ciou(bboxes1, bboxes2):
    """Complete-IoU (CIoU) regression loss between row-aligned box encodings.

    Each row of the inputs is a 4-vector ``(x, y, w, h)`` in the same encoded
    space as the localisation tensors ``MultiBoxLoss.forward`` passes in.  Every
    component is first squashed with a sigmoid; the first two are then used
    directly as the box centre and the last two are turned into width/height
    with ``exp``.  Row ``i`` of ``bboxes1`` is compared with row ``i`` of
    ``bboxes2`` (normal broadcasting applies if one side has a single row).

    Args:
        bboxes1: tensor of shape [N, 4].
        bboxes2: tensor of shape [N, 4].

    Returns:
        0-dim tensor ``sum(1 - CIoU)``, each CIoU clamped to [-1, 1].  When
        either input has no rows a zero scalar on ``bboxes1``'s device/dtype is
        returned, so callers can keep scaling/dividing/summing it like any
        other loss.
    """
    if bboxes1.shape[0] == 0 or bboxes2.shape[0] == 0:
        return bboxes1.new_zeros(())

    eps = 1e-7

    boxes1 = torch.sigmoid(bboxes1)
    boxes2 = torch.sigmoid(bboxes2)

    cx1, cy1 = boxes1[:, 0], boxes1[:, 1]
    cx2, cy2 = boxes2[:, 0], boxes2[:, 1]
    w1, h1 = torch.exp(boxes1[:, 2]), torch.exp(boxes1[:, 3])
    w2, h2 = torch.exp(boxes2[:, 2]), torch.exp(boxes2[:, 3])

    # Box edges (left / right / top / bottom).
    l1, r1 = cx1 - w1 / 2, cx1 + w1 / 2
    t1, b1 = cy1 - h1 / 2, cy1 + h1 / 2
    l2, r2 = cx2 - w2 / 2, cx2 + w2 / 2
    t2, b2 = cy2 - h2 / 2, cy2 + h2 / 2

    # Plain IoU.
    inter_w = torch.clamp(torch.min(r1, r2) - torch.max(l1, l2), min=0)
    inter_h = torch.clamp(torch.min(b1, b2) - torch.max(t1, t2), min=0)
    inter_area = inter_w * inter_h
    union = w1 * h1 + w2 * h2 - inter_area
    iou = inter_area / union

    # Distance penalty: squared centre distance over the squared diagonal of
    # the smallest enclosing box.
    enclose_w = torch.clamp(torch.max(r1, r2) - torch.min(l1, l2), min=0)
    enclose_h = torch.clamp(torch.max(b1, b2) - torch.min(t1, t2), min=0)
    c_diag = enclose_w ** 2 + enclose_h ** 2
    center_dist = (cx2 - cx1) ** 2 + (cy2 - cy1) ** 2
    u = center_dist / c_diag

    # Aspect-ratio consistency term.
    v = (4 / (math.pi ** 2)) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
    with torch.no_grad():
        # alpha is a constant trade-off weight (no gradient), only active for
        # pairs that already overlap well (IoU > 0.5).  The denominator is
        # clamped: for identical boxes 1 - iou + v is exactly 0 and the
        # unclamped form gives 0/0 = NaN.
        overlap_gate = (iou > 0.5).float()
        alpha = overlap_gate * v / torch.clamp(1 - iou + v, min=eps)

    cious = torch.clamp(iou - u - alpha * v, min=-1.0, max=1.0)
    # CIoU is symmetric in its two arguments and we only return the sum, so no
    # transposition / argument reordering is needed.
    return torch.sum(1 - cious)
def diou(bboxes1, bboxes2):
    """Distance-IoU (DIoU) regression loss between row-aligned box encodings.

    Same input convention as ``ciou``: each row is ``(x, y, w, h)``; every
    component is squashed with a sigmoid, the first two are used as the box
    centre and the last two are turned into width/height with ``exp``.  Row
    ``i`` of ``bboxes1`` is compared with row ``i`` of ``bboxes2``.

    Args:
        bboxes1: tensor of shape [N, 4].
        bboxes2: tensor of shape [N, 4].

    Returns:
        0-dim tensor ``sum(1 - DIoU)``, each DIoU clamped to [-1, 1].  When
        either input has no rows a zero scalar on ``bboxes1``'s device/dtype is
        returned instead of an empty CPU matrix.
    """
    if bboxes1.shape[0] == 0 or bboxes2.shape[0] == 0:
        return bboxes1.new_zeros(())

    boxes1 = torch.sigmoid(bboxes1)
    boxes2 = torch.sigmoid(bboxes2)

    cx1, cy1 = boxes1[:, 0], boxes1[:, 1]
    cx2, cy2 = boxes2[:, 0], boxes2[:, 1]
    w1, h1 = torch.exp(boxes1[:, 2]), torch.exp(boxes1[:, 3])
    w2, h2 = torch.exp(boxes2[:, 2]), torch.exp(boxes2[:, 3])

    # Box edges (left / right / top / bottom).
    l1, r1 = cx1 - w1 / 2, cx1 + w1 / 2
    t1, b1 = cy1 - h1 / 2, cy1 + h1 / 2
    l2, r2 = cx2 - w2 / 2, cx2 + w2 / 2
    t2, b2 = cy2 - h2 / 2, cy2 + h2 / 2

    # Plain IoU.
    inter_w = torch.clamp(torch.min(r1, r2) - torch.max(l1, l2), min=0)
    inter_h = torch.clamp(torch.min(b1, b2) - torch.max(t1, t2), min=0)
    inter_area = inter_w * inter_h
    union = w1 * h1 + w2 * h2 - inter_area
    iou = inter_area / union

    # Distance penalty: squared centre distance over the squared diagonal of
    # the smallest enclosing box.
    enclose_w = torch.clamp(torch.max(r1, r2) - torch.min(l1, l2), min=0)
    enclose_h = torch.clamp(torch.max(b1, b2) - torch.min(t1, t2), min=0)
    c_diag = enclose_w ** 2 + enclose_h ** 2
    center_dist = (cx2 - cx1) ** 2 + (cy2 - cy1) ** 2
    u = center_dist / c_diag

    dious = torch.clamp(iou - u, min=-1.0, max=1.0)
    # DIoU is symmetric and only the sum is returned, so no transposition is needed.
    return torch.sum(1 - dious)
class MultiBoxLoss(nn.Module):
    """SSD Weighted Loss Function

    Compute Targets:
        1) Produce Confidence Target Indices by matching ground truth boxes
           with (default) 'priorboxes' that have jaccard index > threshold parameter
           (default threshold: 0.5).
        2) Produce localization target by 'encoding' variance into offsets of ground
           truth boxes and their matched 'priorboxes'.
        3) Hard negative mining to filter the excessive number of negative examples
           that comes with using a large number of default bounding boxes.
           (default negative:positive ratio 3:1)

    Objective Loss:
        L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
        (or the CIoU loss when cfg.reg_loss == 'ciou'),
        weighted by α which is set to 1 by cross val.

    Args:
        c: class confidences,
        l: predicted boxes,
        g: ground truth boxes
        N: number of matched default boxes

    See: https://arxiv.org/pdf/1512.02325.pdf for more details.
    """

    def __init__(self, num_classes, pos_threshold, neg_threshold, negpos_ratio):
        super(MultiBoxLoss, self).__init__()

        self.num_classes = num_classes
        self.pos_threshold = pos_threshold
        self.neg_threshold = neg_threshold
        self.negpos_ratio = negpos_ratio

        # If you output a proto mask with this area, your l1 loss will be l1_alpha
        # Note that the area is relative (so 1 would be the entire image)
        self.l1_expected_area = 20*20/70/70
        self.l1_alpha = 0.1

        if cfg.use_class_balanced_conf:
            # Running per-class target counts; allocated lazily in ohem_conf_loss.
            self.class_instances = None
            self.total_instances = 0

    def forward(self, net, predictions, targets, masks, num_crowds):
        """Compute all configured training losses for one batch.

        Args:
            net: the model; only ``net.maskiou_net`` is used (mask-IoU loss).
            predictions (dict): network outputs. Keys read here:
                'loc'    -- torch.size(batch_size, num_priors, 4)
                'conf'   -- torch.size(batch_size, num_priors, num_classes)
                'mask'   -- torch.size(batch_size, num_priors, mask_dim)
                'priors' -- torch.size(num_priors, 4)
                'proto'  -- torch.size(batch_size, mask_h, mask_w, mask_dim),
                            only when cfg.mask_type == mask_type.lincomb
                'score', 'inst', 'classes', 'segm' -- only when the matching
                            cfg option is enabled.
            targets (list<tensor>): Ground truth boxes and labels for a batch,
                shape: [batch_size][num_objs, 5] (last idx is the label).
            masks (list<tensor>): Ground truth masks for each object in each image,
                shape: [batch_size][num_objs, im_height, im_width].
                NOTE: for images with crowds, ``masks[idx]`` is replaced in
                place by its non-crowd slice.
            num_crowds (list<int>): Number of crowd annotations per image. The crowd
                annotations should be the last num_crowds elements of targets and masks.

        Returns:
            dict of loss tensors keyed by:
                B box localization, C class confidence, M mask, P prototype,
                D coefficient diversity, I mask IoU, E class existence,
                S semantic segmentation.
            P, E and S are divided by the batch size; every other loss is
            divided by the total number of positive priors (at least 1).
        """
        loc_data = predictions['loc']
        conf_data = predictions['conf']
        mask_data = predictions['mask']
        priors = predictions['priors']

        if cfg.mask_type == mask_type.lincomb:
            proto_data = predictions['proto']

        score_data = predictions['score'] if cfg.use_mask_scoring else None
        inst_data = predictions['inst'] if cfg.use_instance_coeff else None

        labels = [None] * len(targets)  # Used in sem segm loss

        batch_size = loc_data.size(0)
        num_priors = priors.size(0)
        num_classes = self.num_classes

        # Match priors (default boxes) and ground truth boxes.
        # These tensors are created uninitialised on loc_data's device; `match`
        # fills row `idx` of loc_t / conf_t / idx_t for each image.
        loc_t = loc_data.new(batch_size, num_priors, 4)
        gt_box_t = loc_data.new(batch_size, num_priors, 4)
        conf_t = loc_data.new(batch_size, num_priors).long()
        idx_t = loc_data.new(batch_size, num_priors).long()

        if cfg.use_class_existence_loss:
            class_existence_t = loc_data.new(batch_size, num_classes - 1)

        for idx in range(batch_size):
            truths = targets[idx][:, :-1].data
            labels[idx] = targets[idx][:, -1].data.long()

            if cfg.use_class_existence_loss:
                # One-hot each object and collapse into an existence vector with max.
                # It's fine to include the crowd annotations here.
                # `.device` (not `.get_device()`, which is -1 for CPU tensors).
                class_existence_t[idx, :] = torch.eye(num_classes - 1, device=conf_t.device)[labels[idx]].max(dim=0)[0]

            # Split off the crowd annotations, which come bundled at the end.
            cur_crowds = num_crowds[idx]
            if cur_crowds > 0:
                crowd_boxes, truths = truths[-cur_crowds:], truths[:-cur_crowds]
                # Crowd labels and masks are not used.
                labels[idx] = labels[idx][:-cur_crowds]
                masks[idx] = masks[idx][:-cur_crowds]
            else:
                crowd_boxes = None

            match(self.pos_threshold, self.neg_threshold,
                  truths, priors.data, labels[idx], crowd_boxes,
                  loc_t, conf_t, idx_t, idx, loc_data[idx])

            gt_box_t[idx, :, :] = truths[idx_t[idx]]

        pos = conf_t > 0
        num_pos = pos.sum(dim=1, keepdim=True)

        # Shape: [batch, num_priors, 4]
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)

        # Regression predictions / targets of the positive priors, flattened to
        # [total_pos, 4].  Computed unconditionally because conf_objectness_loss
        # needs them even when cfg.train_boxes is off.
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)

        losses = {}
        # Only filled in by the lincomb mask branch when cfg.use_maskiou is on.
        maskiou_targets = None

        # Localization Loss
        if cfg.train_boxes:
            if cfg.reg_loss == 'ciou':
                losses['B'] = ciou(loc_p, loc_t) * cfg.bbox_alpha * 5
            elif cfg.reg_loss == 'sl1':
                losses['B'] = F.smooth_l1_loss(loc_p, loc_t, reduction='sum') * cfg.bbox_alpha
            else:
                raise AssertionError("Currently, bbox regression surports 'ciou' or 'sl1'.")

        if cfg.train_masks:
            if cfg.mask_type == mask_type.direct:
                if cfg.use_gt_bboxes:
                    masks_t = torch.cat([masks[i][idx_t[i, pos[i]]] for i in range(batch_size)], 0)
                    masks_p = mask_data[pos, :].view(-1, cfg.mask_dim)
                    losses['M'] = F.binary_cross_entropy(torch.clamp(masks_p, 0, 1), masks_t, reduction='sum') * cfg.mask_alpha
                else:
                    losses['M'] = self.direct_mask_loss(pos_idx, idx_t, loc_data, mask_data, priors, masks)
            elif cfg.mask_type == mask_type.lincomb:
                ret = self.lincomb_mask_loss(pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels)
                if cfg.use_maskiou:
                    loss, maskiou_targets = ret
                else:
                    loss = ret
                losses.update(loss)

                if cfg.mask_proto_loss is not None:
                    if cfg.mask_proto_loss == 'l1':
                        losses['P'] = torch.mean(torch.abs(proto_data)) / self.l1_expected_area * self.l1_alpha
                    elif cfg.mask_proto_loss == 'disj':
                        losses['P'] = -torch.mean(torch.max(F.log_softmax(proto_data, dim=-1), dim=-1)[0])

        # Confidence loss
        if cfg.use_focal_loss:
            if cfg.use_sigmoid_focal_loss:
                losses['C'] = self.focal_conf_sigmoid_loss(conf_data, conf_t)
            elif cfg.use_objectness_score:
                losses['C'] = self.focal_conf_objectness_loss(conf_data, conf_t)
            else:
                losses['C'] = self.focal_conf_loss(conf_data, conf_t)
        elif cfg.use_objectness_score:
            losses['C'] = self.conf_objectness_loss(conf_data, conf_t, batch_size, loc_p, loc_t, priors)
        else:
            losses['C'] = self.ohem_conf_loss(conf_data, conf_t, pos, batch_size)

        # Mask IoU Loss
        if cfg.use_maskiou and maskiou_targets is not None:
            losses['I'] = self.mask_iou_loss(net, maskiou_targets)

        # These losses also don't depend on anchors
        if cfg.use_class_existence_loss:
            losses['E'] = self.class_existence_loss(predictions['classes'], class_existence_t)

        if cfg.use_semantic_segmentation_loss:
            losses['S'] = self.semantic_segmentation_loss(predictions['segm'], masks, labels)

        # Divide anchor-dependent losses by the number of positives (at least 1,
        # so a batch without any positive prior gives finite losses, not 0/0).
        # P, E and S don't depend on the anchors and are averaged over the batch.
        total_num_pos = num_pos.sum().float().clamp(min=1)
        for k in losses:
            if k in ('P', 'E', 'S'):
                losses[k] /= batch_size
            else:
                losses[k] /= total_num_pos

        # Loss Key:
        #  - B: Box Localization Loss
        #  - C: Class Confidence Loss
        #  - M: Mask Loss
        #  - P: Prototype Loss
        #  - D: Coefficient Diversity Loss
        #  - I: Mask IoU Loss
        #  - E: Class Existence Loss
        #  - S: Semantic Segmentation Loss
        return losses

    def class_existence_loss(self, class_data, class_existence_t):
        """Summed BCE-with-logits between per-image class logits and existence targets."""
        return cfg.class_existence_alpha * F.binary_cross_entropy_with_logits(class_data, class_existence_t, reduction='sum')

    def semantic_segmentation_loss(self, segment_data, mask_t, class_t, interpolation_mode='bilinear'):
        """Per-pixel BCE between the segmentation head and gt masks merged by class.

        Args:
            segment_data: [batch, num_classes, mask_h, mask_w] logits, where
                num_classes excludes background (i.e. cfg.num_classes - 1).
            mask_t: list of per-image gt masks, each [num_objs, H, W].
            class_t: list of per-image label tensors, each [num_objs].

        Returns:
            Summed BCE normalised by mask_h * mask_w and scaled by
            cfg.semantic_segmentation_alpha.
        """
        # Note num_classes here is without the background class so cfg.num_classes-1
        batch_size, num_classes, mask_h, mask_w = segment_data.size()
        loss_s = 0

        for idx in range(batch_size):
            cur_segment = segment_data[idx]
            cur_class_t = class_t[idx]

            with torch.no_grad():
                downsampled_masks = F.interpolate(mask_t[idx].unsqueeze(0), (mask_h, mask_w),
                                                  mode=interpolation_mode, align_corners=False).squeeze(0)
                downsampled_masks = downsampled_masks.gt(0.5).float()

                # Construct Semantic Segmentation: per class, the union (max) of its objects' masks.
                segment_t = torch.zeros_like(cur_segment, requires_grad=False)
                for obj_idx in range(downsampled_masks.size(0)):
                    segment_t[cur_class_t[obj_idx]] = torch.max(segment_t[cur_class_t[obj_idx]], downsampled_masks[obj_idx])

            loss_s += F.binary_cross_entropy_with_logits(cur_segment, segment_t, reduction='sum')

        return loss_s / mask_h / mask_w * cfg.semantic_segmentation_alpha

    def ohem_conf_loss(self, conf_data, conf_t, pos, num):
        """Softmax cross-entropy over positives plus hard-mined negatives.

        Negatives are ranked by a mining score and the top
        ``negpos_ratio * num_pos`` per image are kept (capped at num_priors - 1).
        Neutral priors (conf_t < 0) are never used.

        Args:
            conf_data: [num, num_priors, num_classes] class logits.
            conf_t: [num, num_priors] target classes (0 = background, -1 = neutral).
            pos: [num, num_priors] bool mask of positive priors.
            num: batch size.
        """
        batch_conf = conf_data.view(-1, self.num_classes)

        # Hard Negative Mining.  Only used to choose indices, so no autograd graph is needed.
        with torch.no_grad():
            if cfg.ohem_use_most_confident:
                # i.e. max(softmax) along classes > 0
                mining_score, _ = F.softmax(batch_conf, dim=1)[:, 1:].max(dim=1)
            else:
                # i.e. -log(softmax(class 0 confidence))
                mining_score = log_sum_exp(batch_conf) - batch_conf[:, 0]

            mining_score = mining_score.view(num, -1)
            mining_score[pos] = 0           # filter out pos boxes
            mining_score[conf_t < 0] = 0    # filter out neutrals (conf_t = -1)
            _, loss_idx = mining_score.sort(1, descending=True)
            _, idx_rank = loss_idx.sort(1)
            num_pos = pos.long().sum(1, keepdim=True)
            num_neg = torch.clamp(self.negpos_ratio * num_pos, max=pos.size(1) - 1)
            neg = idx_rank < num_neg.expand_as(idx_rank)

            # Just in case there aren't enough negatives, don't start using positives as negatives
            neg[pos] = False
            neg[conf_t < 0] = False  # Filter out neutrals

        # Confidence Loss Including Positive and Negative Examples
        selected = pos | neg
        conf_p = conf_data[selected].view(-1, self.num_classes)
        targets_weighted = conf_t[selected]
        loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='none')

        if cfg.use_class_balanced_conf:
            # Lazy initialization
            if self.class_instances is None:
                self.class_instances = torch.zeros(self.num_classes, device=targets_weighted.device)

            classes, counts = targets_weighted.unique(return_counts=True)
            # `classes` are unique, so this in-place indexed add counts each class once.
            self.class_instances[classes] += counts.to(self.class_instances.dtype)
            self.total_instances += targets_weighted.size(0)

            weighting = 1 - (self.class_instances[targets_weighted] / self.total_instances)
            weighting = torch.clamp(weighting, min=1/self.num_classes)

            # If you do the math, the average weight of self.class_instances is this
            avg_weight = (self.num_classes - 1) / self.num_classes

            loss_c = (loss_c * weighting).sum() / avg_weight
        else:
            loss_c = loss_c.sum()

        return cfg.conf_alpha * loss_c

    def focal_conf_loss(self, conf_data, conf_t):
        """
        Focal loss as described in https://arxiv.org/pdf/1708.02002.pdf
        Adapted from https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py
        Note that this uses softmax and not the original sigmoid from the paper.
        """
        # Cloned so that zeroing the neutrals below does not modify the caller's conf_t.
        conf_t = conf_t.view(-1).clone()  # [batch_size*num_priors]
        conf_data = conf_data.view(-1, conf_data.size(-1))  # [batch_size*num_priors, num_classes]

        # Ignore neutral samples (class < 0)
        keep = (conf_t >= 0).float()
        conf_t[conf_t < 0] = 0  # so that gather doesn't drum up a fuss

        logpt = F.log_softmax(conf_data, dim=-1)
        logpt = logpt.gather(1, conf_t.unsqueeze(-1))
        logpt = logpt.view(-1)
        pt = logpt.exp()

        # I adapted the alpha_t calculation here from
        # https://github.com/pytorch/pytorch/blob/master/modules/detectron/softmax_focal_loss_op.cu
        # You'd think you want all the alphas to sum to one, but in the original implementation they
        # just give background an alpha of 1-alpha and each forground an alpha of alpha.
        background = (conf_t == 0).float()
        at = (1 - cfg.focal_loss_alpha) * background + cfg.focal_loss_alpha * (1 - background)

        loss = -at * (1 - pt) ** cfg.focal_loss_gamma * logpt

        # Neutral samples are zeroed out by `keep`.
        return cfg.conf_alpha * (loss * keep).sum()

    def focal_conf_sigmoid_loss(self, conf_data, conf_t):
        """
        Focal loss but using sigmoid like the original paper.
        Note: To make things mesh easier, the network still predicts 81 class confidences in this mode.
              Because retinanet originally only predicts 80, we simply just don't use conf_data[..., 0]
        """
        num_classes = conf_data.size(-1)

        # Cloned so that zeroing the neutrals below does not modify the caller's conf_t.
        conf_t = conf_t.view(-1).clone()  # [batch_size*num_priors]
        conf_data = conf_data.view(-1, num_classes)  # [batch_size*num_priors, num_classes]

        # Ignore neutral samples (class < 0)
        keep = (conf_t >= 0).float()
        conf_t[conf_t < 0] = 0  # can't mask with -1, so filter that out

        # Compute a one-hot embedding of conf_t
        # From https://github.com/kuangliu/pytorch-retinanet/blob/master/utils.py
        # `.device` (not `.get_device()`, which is -1 for CPU tensors).
        conf_one_t = torch.eye(num_classes, device=conf_t.device)[conf_t]
        conf_pm_t = conf_one_t * 2 - 1  # -1 if background, +1 if forground for specific class

        logpt = F.logsigmoid(conf_data * conf_pm_t)  # note: 1 - sigmoid(x) = sigmoid(-x)
        pt = logpt.exp()

        at = cfg.focal_loss_alpha * conf_one_t + (1 - cfg.focal_loss_alpha) * (1 - conf_one_t)
        at[..., 0] = 0  # Set alpha for the background class to 0 because sigmoid focal loss doesn't use it

        loss = -at * (1 - pt) ** cfg.focal_loss_gamma * logpt
        loss = keep * loss.sum(dim=-1)

        return cfg.conf_alpha * loss.sum()

    def focal_conf_objectness_loss(self, conf_data, conf_t):
        """
        Instead of using softmax, use class[0] to be the objectness score and do sigmoid focal loss on that.
        Then for the rest of the classes, softmax them and apply CE for only the positive examples.

        If class[0] = 1 implies forground and class[0] = 0 implies background then you achieve something
        similar during test-time to softmax by setting class[1:] = softmax(class[1:]) * class[0] and invert class[0].
        """
        # Cloned so that zeroing the neutrals below does not modify the caller's conf_t.
        conf_t = conf_t.view(-1).clone()  # [batch_size*num_priors]
        conf_data = conf_data.view(-1, conf_data.size(-1))  # [batch_size*num_priors, num_classes]

        # Ignore neutral samples (class < 0)
        keep = (conf_t >= 0).float()
        conf_t[conf_t < 0] = 0  # so that gather doesn't drum up a fuss

        background = (conf_t == 0).float()
        at = (1 - cfg.focal_loss_alpha) * background + cfg.focal_loss_alpha * (1 - background)

        logpt = F.logsigmoid(conf_data[:, 0]) * (1 - background) + F.logsigmoid(-conf_data[:, 0]) * background
        pt = logpt.exp()

        obj_loss = -at * (1 - pt) ** cfg.focal_loss_gamma * logpt

        # All that was the objectiveness loss--now time for the class confidence loss
        pos_mask = conf_t > 0
        conf_data_pos = (conf_data[:, 1:])[pos_mask]  # Now this has just 80 classes
        conf_t_pos = conf_t[pos_mask] - 1  # So subtract 1 here

        class_loss = F.cross_entropy(conf_data_pos, conf_t_pos, reduction='sum')

        return cfg.conf_alpha * (class_loss + (obj_loss * keep).sum())

    def conf_objectness_loss(self, conf_data, conf_t, batch_size, loc_p, loc_t, priors):
        """
        Instead of using softmax, use class[0] to be p(obj) * p(IoU) as in YOLO.
        Then for the rest of the classes, softmax them and apply CE for only the positive examples.

        loc_p / loc_t are the positive priors' regression predictions and
        targets, flattened to [total_pos, 4] in the same order as the
        positives of ``conf_t.view(-1)``.
        """
        conf_t = conf_t.view(-1)  # [batch_size*num_priors]
        conf_data = conf_data.view(-1, conf_data.size(-1))  # [batch_size*num_priors, num_classes]

        pos_mask = (conf_t > 0)
        neg_mask = (conf_t == 0)

        obj_data = conf_data[:, 0]
        obj_data_pos = obj_data[pos_mask]
        obj_data_neg = obj_data[neg_mask]

        # Don't be confused, this is just binary cross entropy similified
        obj_neg_loss = - F.logsigmoid(-obj_data_neg).sum()

        with torch.no_grad():
            pos_priors = priors.unsqueeze(0).expand(batch_size, -1, -1).reshape(-1, 4)[pos_mask, :]

            boxes_pred = decode(loc_p, pos_priors, cfg.use_yolo_regressors)
            boxes_targ = decode(loc_t, pos_priors, cfg.use_yolo_regressors)

            iou_targets = elemwise_box_iou(boxes_pred, boxes_targ)

        # Soft BCE target: the IoU between the decoded predicted and target boxes.
        obj_pos_loss = - iou_targets * F.logsigmoid(obj_data_pos) - (1 - iou_targets) * F.logsigmoid(-obj_data_pos)
        obj_pos_loss = obj_pos_loss.sum()

        # All that was the objectiveness loss--now time for the class confidence loss
        conf_data_pos = (conf_data[:, 1:])[pos_mask]  # Now this has just 80 classes
        conf_t_pos = conf_t[pos_mask] - 1  # So subtract 1 here

        class_loss = F.cross_entropy(conf_data_pos, conf_t_pos, reduction='sum')

        return cfg.conf_alpha * (class_loss + obj_pos_loss + obj_neg_loss)

    def direct_mask_loss(self, pos_idx, idx_t, loc_data, mask_data, priors, masks):
        """ Crops the gt masks using the predicted bboxes, scales them down, and outputs the BCE loss. """
        # Not brought in by the module-level box_utils import line.
        # NOTE(review): assumes ..box_utils defines sanitize_coordinates (as upstream YOLACT's does).
        from ..box_utils import sanitize_coordinates

        loss_m = 0
        for idx in range(mask_data.size(0)):
            with torch.no_grad():
                cur_pos_idx = pos_idx[idx, :, :]
                cur_pos_idx_squeezed = cur_pos_idx[:, 1]

                if not cur_pos_idx_squeezed.any():
                    # No positive priors in this image: nothing to crop, and
                    # torch.cat([]) below would raise.
                    continue

                # Shape: [num_priors, 4], decoded predicted bboxes
                pos_bboxes = decode(loc_data[idx, :, :], priors.data, cfg.use_yolo_regressors)
                pos_bboxes = pos_bboxes[cur_pos_idx].view(-1, 4).clamp(0, 1)
                pos_lookup = idx_t[idx, cur_pos_idx_squeezed]

                cur_masks = masks[idx]
                pos_masks = cur_masks[pos_lookup, :, :]

                # Convert bboxes to absolute coordinates
                num_pos, img_height, img_width = pos_masks.size()

                # Take care of all the bad behavior that can be caused by out of bounds coordinates
                x1, x2 = sanitize_coordinates(pos_bboxes[:, 0], pos_bboxes[:, 2], img_width)
                y1, y2 = sanitize_coordinates(pos_bboxes[:, 1], pos_bboxes[:, 3], img_height)

                # Crop each gt mask with the predicted bbox and rescale to the predicted mask size
                # Note that each bounding box crop is a different size so I don't think we can vectorize this
                scaled_masks = []
                for jdx in range(num_pos):
                    tmp_mask = pos_masks[jdx, y1[jdx]:y2[jdx], x1[jdx]:x2[jdx]]

                    # Restore any dimensions we've left out because our bbox was 1px wide
                    while tmp_mask.dim() < 2:
                        tmp_mask = tmp_mask.unsqueeze(0)

                    new_mask = F.adaptive_avg_pool2d(tmp_mask.unsqueeze(0), cfg.mask_size)
                    scaled_masks.append(new_mask.view(1, -1))

                mask_t = torch.cat(scaled_masks, 0).gt(0.5).float()  # Threshold downsampled mask

            pos_mask_data = mask_data[idx, cur_pos_idx_squeezed, :]
            loss_m += F.binary_cross_entropy(torch.clamp(pos_mask_data, 0, 1), mask_t, reduction='sum') * cfg.mask_alpha

        return loss_m

    def coeff_diversity_loss(self, coeffs, instance_t):
        """
        coeffs     should be size [num_pos, num_coeffs]
        instance_t should be size [num_pos] and be values from 0 to num_instances-1
        """
        num_pos = coeffs.size(0)
        instance_t = instance_t.view(-1)  # juuuust to make sure

        coeffs_norm = F.normalize(coeffs, dim=1)
        cos_sim = coeffs_norm @ coeffs_norm.t()

        inst_eq = (instance_t[:, None].expand_as(cos_sim) == instance_t[None, :].expand_as(cos_sim)).float()

        # Rescale to be between 0 and 1
        cos_sim = (cos_sim + 1) / 2

        # If they're the same instance, use cosine distance, else use cosine similarity
        loss = (1 - cos_sim) * inst_eq + cos_sim * (1 - inst_eq)

        # Only divide by num_pos once because we're summing over a num_pos x num_pos tensor
        # and all the losses will be divided by num_pos at the end, so just one extra time.
        return cfg.mask_proto_coeff_diversity_alpha * loss.sum() / num_pos

    def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels, interpolation_mode='bilinear'):
        """Mask loss for lincomb masks: pred = activation(proto @ coeffs^T).

        Args:
            pos: [batch, num_priors] bool mask of positive priors (cloned
                before editing when cfg.mask_proto_remove_empty_masks).
            idx_t: [batch, num_priors] index of the gt object matched to each prior.
            loc_data, priors: used to decode predicted boxes when
                cfg.mask_proto_crop_with_pred_box.
            mask_data: [batch, num_priors, mask_dim] mask coefficients.
            proto_data: [batch, mask_h, mask_w, mask_dim] prototypes.
            masks: list of per-image gt masks, each [num_objs, H, W].
            gt_box_t: [batch, num_priors, 4] matched gt boxes (point form).
            score_data: accepted for interface compatibility; not used here.
            inst_data: optional coefficients for the diversity loss.
            labels: list of per-image label tensors.

        Returns:
            ``{'M': ..., 'D': ... (optional)}``.  When cfg.use_maskiou, a tuple
            ``(losses, maskiou_targets)`` where maskiou_targets is None (no
            sample survived) or ``[maskiou_net_input, maskiou_t, label_t]``.
        """
        mask_h = proto_data.size(1)
        mask_w = proto_data.size(2)

        process_gt_bboxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

        if cfg.mask_proto_remove_empty_masks:
            # Make sure to store a copy of this because we edit it to get rid of all-zero masks
            pos = pos.clone()

        loss_m = 0
        loss_d = 0  # Coefficient diversity loss

        maskiou_t_list = []
        maskiou_net_input_list = []
        label_t_list = []

        for idx in range(mask_data.size(0)):
            with torch.no_grad():
                downsampled_masks = F.interpolate(masks[idx].unsqueeze(0), (mask_h, mask_w),
                                                  mode=interpolation_mode, align_corners=False).squeeze(0)
                # -> [mask_h, mask_w, num_objs]
                downsampled_masks = downsampled_masks.permute(1, 2, 0).contiguous()

                if cfg.mask_proto_binarize_downsampled_gt:
                    downsampled_masks = downsampled_masks.gt(0.5).float()

                if cfg.mask_proto_remove_empty_masks:
                    # Get rid of gt masks that are so small they get downsampled away
                    very_small_masks = (downsampled_masks.sum(dim=(0, 1)) <= 0.0001)
                    for i in range(very_small_masks.size(0)):
                        if very_small_masks[i]:
                            pos[idx, idx_t[idx] == i] = 0

                if cfg.mask_proto_reweight_mask_loss:
                    # Ensure that the gt is binary
                    if not cfg.mask_proto_binarize_downsampled_gt:
                        bin_gt = downsampled_masks.gt(0.5).float()
                    else:
                        bin_gt = downsampled_masks

                    gt_foreground_norm = bin_gt / (torch.sum(bin_gt, dim=(0, 1), keepdim=True) + 0.0001)
                    gt_background_norm = (1-bin_gt) / (torch.sum(1-bin_gt, dim=(0, 1), keepdim=True) + 0.0001)

                    mask_reweighting = gt_foreground_norm * cfg.mask_proto_reweight_coeff + gt_background_norm
                    mask_reweighting *= mask_h * mask_w

            cur_pos = pos[idx]
            pos_idx_t = idx_t[idx, cur_pos]

            if process_gt_bboxes:
                # Note: this is in point-form
                if cfg.mask_proto_crop_with_pred_box:
                    pos_gt_box_t = decode(loc_data[idx, :, :], priors.data, cfg.use_yolo_regressors)[cur_pos]
                else:
                    pos_gt_box_t = gt_box_t[idx, cur_pos]

            if pos_idx_t.size(0) == 0:
                continue

            proto_masks = proto_data[idx]
            proto_coef = mask_data[idx, cur_pos, :]

            if cfg.mask_proto_coeff_diversity_loss:
                if inst_data is not None:
                    div_coeffs = inst_data[idx, cur_pos, :]
                else:
                    div_coeffs = proto_coef

                loss_d += self.coeff_diversity_loss(div_coeffs, pos_idx_t)

            # If we have over the allowed number of masks, select a random sample
            old_num_pos = proto_coef.size(0)
            if old_num_pos > cfg.masks_to_train:
                perm = torch.randperm(proto_coef.size(0))
                select = perm[:cfg.masks_to_train]

                proto_coef = proto_coef[select, :]
                pos_idx_t = pos_idx_t[select]

                if process_gt_bboxes:
                    pos_gt_box_t = pos_gt_box_t[select, :]

            num_pos = proto_coef.size(0)
            mask_t = downsampled_masks[:, :, pos_idx_t]
            label_t = labels[idx][pos_idx_t]

            # Size: [mask_h, mask_w, num_pos]
            pred_masks = proto_masks @ proto_coef.t()
            pred_masks = cfg.mask_proto_mask_activation(pred_masks)

            if cfg.mask_proto_double_loss:
                if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                    pre_loss = F.binary_cross_entropy(torch.clamp(pred_masks, 0, 1), mask_t, reduction='sum')
                else:
                    pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='sum')

                loss_m += cfg.mask_proto_double_loss_alpha * pre_loss

            if cfg.mask_proto_crop:
                pred_masks = crop(pred_masks, pos_gt_box_t)

            if cfg.mask_proto_mask_activation == activation_func.sigmoid:
                pre_loss = F.binary_cross_entropy(torch.clamp(pred_masks, 0, 1), mask_t, reduction='none')
            else:
                pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='none')

            if cfg.mask_proto_normalize_mask_loss_by_sqrt_area:
                gt_area = torch.sum(mask_t, dim=(0, 1), keepdim=True)
                pre_loss = pre_loss / (torch.sqrt(gt_area) + 0.0001)

            if cfg.mask_proto_reweight_mask_loss:
                pre_loss = pre_loss * mask_reweighting[:, :, pos_idx_t]

            if cfg.mask_proto_normalize_emulate_roi_pooling:
                weight = mask_h * mask_w if cfg.mask_proto_crop else 1
                pos_gt_csize = center_size(pos_gt_box_t)
                gt_box_width = pos_gt_csize[:, 2] * mask_w
                gt_box_height = pos_gt_csize[:, 3] * mask_h
                pre_loss = pre_loss.sum(dim=(0, 1)) / gt_box_width / gt_box_height * weight

            # If the number of masks were limited scale the loss accordingly
            if old_num_pos > num_pos:
                pre_loss *= old_num_pos / num_pos

            loss_m += torch.sum(pre_loss)

            if cfg.use_maskiou:
                if cfg.discard_mask_area > 0:
                    gt_mask_area = torch.sum(mask_t, dim=(0, 1))
                    select = gt_mask_area > cfg.discard_mask_area

                    if torch.sum(select) < 1:
                        continue

                    pred_masks = pred_masks[:, :, select]
                    mask_t = mask_t[:, :, select]
                    label_t = label_t[select]

                maskiou_net_input = pred_masks.permute(2, 0, 1).contiguous().unsqueeze(1)
                pred_masks = pred_masks.gt(0.5).float()
                maskiou_t = self._mask_iou(pred_masks, mask_t)

                maskiou_net_input_list.append(maskiou_net_input)
                maskiou_t_list.append(maskiou_t)
                label_t_list.append(label_t)

        losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}

        if cfg.mask_proto_coeff_diversity_loss:
            losses['D'] = loss_d

        if cfg.use_maskiou:
            # discard_mask_area discarded every mask in the batch, so nothing to do here
            if len(maskiou_t_list) == 0:
                return losses, None

            maskiou_t = torch.cat(maskiou_t_list)
            label_t = torch.cat(label_t_list)
            maskiou_net_input = torch.cat(maskiou_net_input_list)

            num_samples = maskiou_t.size(0)
            if cfg.maskious_to_train > 0 and num_samples > cfg.maskious_to_train:
                # Subsample to the mask-IoU budget (previously sliced by masks_to_train by mistake).
                perm = torch.randperm(num_samples)
                select = perm[:cfg.maskious_to_train]
                maskiou_t = maskiou_t[select]
                label_t = label_t[select]
                maskiou_net_input = maskiou_net_input[select]

            return losses, [maskiou_net_input, maskiou_t, label_t]

        return losses

    def _mask_iou(self, mask1, mask2):
        """Per-instance IoU of two [h, w, n] mask stacks; returns a tensor of shape [n]."""
        intersection = torch.sum(mask1 * mask2, dim=(0, 1))
        area1 = torch.sum(mask1, dim=(0, 1))
        area2 = torch.sum(mask2, dim=(0, 1))
        union = (area1 + area2) - intersection
        return intersection / union

    def mask_iou_loss(self, net, maskiou_targets):
        """Smooth-L1 between net.maskiou_net's prediction for the gt class and the target mask IoU."""
        maskiou_net_input, maskiou_t, label_t = maskiou_targets

        maskiou_p = net.maskiou_net(maskiou_net_input)

        # Pick the prediction for each sample's gt class.
        label_t = label_t[:, None]
        maskiou_p = torch.gather(maskiou_p, dim=1, index=label_t).view(-1)

        loss_i = F.smooth_l1_loss(maskiou_p, maskiou_t, reduction='sum')

        return loss_i * cfg.maskiou_alpha