-
Notifications
You must be signed in to change notification settings - Fork 1
/
OHEM.py
68 lines (52 loc) · 2.49 KB
/
OHEM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
from torch import nn
from torch.nn import functional as F
class OHEM(nn.Module):
    """Online hard example mining (OHEM) pixel sampler.

    Produces a binary per-pixel weight map that selects "hard" pixels: valid
    pixels whose predicted probability for their ground-truth class is low.
    A valid pixel gets weight 1 when that probability is strictly below
    ``max(ratio, p_k)``. Here ``p_k`` is the ``k``-th smallest (0-based)
    ground-truth probability among all valid pixels of the batch, with
    ``k = min(min_kept * batch_size, num_valid - 1)``. Ties aside, this
    guarantees that at least the ``k`` hardest valid pixels are selected even
    when few of them fall under ``ratio``.

    ``forward`` applies ``softmax`` over dim 1 itself, so it accepts raw
    logits. Log-softmax outputs give the same result, because softmax is
    invariant to a per-pixel additive constant.

    Args:
        ratio (float): probability threshold below which a valid pixel is
            considered hard. Despite the name, it is used as a threshold,
            not as a fraction of pixels. Must be > 0.
        ignore_index (int): label value marking pixels to exclude; such
            pixels always receive weight 0.
        min_kept (int): minimum number of pixels to keep per image in the
            batch. Must be > 1.
    """

    def __init__(self, ratio, ignore_index, min_kept=100000):
        super(OHEM, self).__init__()
        assert ratio > 0
        self.ratio = ratio
        self.ignore_index = ignore_index
        assert min_kept > 1
        self.min_kept = min_kept

    def forward(self, seg_logit, seg_label):
        """Sample pixels that have high loss, i.e. low prediction confidence.

        Args:
            seg_logit (torch.Tensor): segmentation logits (or log-probs),
                shape (N, C, H, W).
            seg_label (torch.Tensor): segmentation label, shape (N, 1, H, W).

        Returns:
            torch.Tensor: 0/1 segmentation weight with the dtype of
            ``seg_logit``, shape (N, H, W). Ignored pixels are always 0.
        """
        with torch.no_grad():
            assert seg_logit.shape[2:] == seg_label.shape[2:]
            assert seg_label.shape[1] == 1
            seg_label = seg_label.squeeze(1).long()
            batch_kept = self.min_kept * seg_label.size(0)
            valid_mask = seg_label != self.ignore_index
            seg_weight = seg_logit.new_zeros(size=seg_label.size())

            seg_prob = F.softmax(seg_logit, dim=1)
            # ignore_index may lie outside the class range, which would make
            # gather fail. Point ignored pixels at class 0 instead; valid_mask
            # drops them afterwards.
            safe_label = seg_label.masked_fill(~valid_mask, 0).unsqueeze(1)
            seg_prob = seg_prob.gather(1, safe_label).squeeze(1)

            valid_prob = seg_prob[valid_mask]
            sort_prob, _ = valid_prob.sort()
            if sort_prob.numel() > 0:
                # Clamp so small batches index the largest probability
                # instead of overrunning the sorted tensor.
                min_threshold = sort_prob[min(batch_kept,
                                              sort_prob.numel() - 1)]
            else:
                min_threshold = 0.0
            # Raising the threshold to at least p_k keeps ~batch_kept pixels
            # even if fewer than that fall below `ratio`.
            threshold = max(min_threshold, self.ratio)
            seg_weight[valid_mask] = (valid_prob < threshold).to(
                seg_weight.dtype)
        return seg_weight
class CrossEntropyLoss2dPixelWiseWeighted(nn.Module):
    """Cross-entropy loss scaled by a per-pixel weight map.

    Wraps ``nn.CrossEntropyLoss``. With the default ``reduction='none'`` the
    wrapped loss yields one value per target element. Those values are
    multiplied element-wise by ``pixelWiseWeight`` and then averaged over
    *all* elements. Pixels labelled ``ignore_index`` therefore contribute a
    loss of 0 but still count in the denominator of the mean.

    Args:
        weight: optional per-class rescaling tensor forwarded to
            ``nn.CrossEntropyLoss``.
        ignore_index (int): target value excluded from the loss.
        reduction (str): forwarded to ``nn.CrossEntropyLoss``. Presumably it
            should stay ``'none'`` for per-pixel weighting to make sense.
    """

    def __init__(self, weight=None, ignore_index=250, reduction='none'):
        super().__init__()
        ce_kwargs = dict(
            weight=weight,
            ignore_index=ignore_index,
            reduction=reduction,
        )
        self.CE = nn.CrossEntropyLoss(**ce_kwargs)

    def forward(self, output, target, pixelWiseWeight):
        """Return the mean of the per-pixel cross-entropy times the weights.

        Args:
            output: class scores, e.g. (N, C, H, W).
            target: class indices, e.g. (N, H, W).
            pixelWiseWeight: weights that must broadcast against the
                unreduced loss; presumably shaped like ``target``.
        """
        per_pixel = self.CE(output, target)
        weighted = per_pixel * pixelWiseWeight
        return weighted.mean()