-
Notifications
You must be signed in to change notification settings - Fork 0
/
losses.py
68 lines (51 loc) · 2.17 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from inclusion import *
def dice_loss(input, target):
    """Return the smoothed soft Dice coefficient of ``sigmoid(input)`` vs ``target``.

    Despite its name, this returns the overlap score itself,
    ``(2 * intersection + smooth) / (sum(input) + sum(target) + smooth)``,
    rather than ``1 - dice``. Turning it into a quantity to minimise
    (for instance by taking ``-log``) is left to the caller.

    Args:
        input: Tensor of raw logits; a sigmoid is applied here.
        target: Tensor expected to hold the same number of elements as
            ``input`` (both are flattened before being multiplied).

    Returns:
        A zero-dimensional tensor with the smoothed Dice coefficient.
    """
    probs = torch.sigmoid(input)
    smooth = 1.0  # keeps the ratio finite when both tensors sum to zero

    # reshape() rather than view(): view(-1) raises on non-contiguous tensors
    # (e.g. after a transpose/permute), reshape copies only when it has to.
    iflat = probs.reshape(-1)
    tflat = target.reshape(-1)

    intersection = (iflat * tflat).sum()
    return (2.0 * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth)
class FocalLoss(nn.Module):
    """Focal loss evaluated directly on logits and averaged over all elements.

    Each element's binary cross-entropy (computed from the logit in a
    numerically stable form) is scaled by
    ``exp(gamma * logsigmoid(-input * (2 * target - 1)))`` before the mean
    is taken.

    Args:
        gamma: Exponent applied to the modulating factor; ``0`` leaves the
            plain cross-entropy unchanged.
    """

    def __init__(self, gamma):
        super().__init__()
        self.gamma = gamma

    def forward(self, input, target):
        """Return the mean focal loss for ``input`` logits against ``target``.

        Args:
            input: Tensor of raw logits.
            target: Tensor of the same size as ``input``
                (presumably 0/1 labels -- confirm against callers).

        Returns:
            A zero-dimensional tensor with the averaged loss.

        Raises:
            ValueError: If ``target`` and ``input`` differ in size.
        """
        if target.size() != input.size():
            raise ValueError("Target size ({}) must be the same as input size ({})"
                             .format(target.size(), input.size()))

        # Stable cross-entropy on logits: shifting by max(-x, 0) keeps both
        # exponentials at or below 1 so the log never sees an overflow.
        shift = (-input).clamp(min=0)
        log_term = ((-shift).exp() + (-input - shift).exp()).log()
        bce = input - input * target + shift + log_term

        # Modulating factor, built in log space and then exponentiated.
        signed_target = target * 2.0 - 1.0
        log_weight = F.logsigmoid(-input * signed_target)
        weighted = (log_weight * self.gamma).exp() * bce

        return weighted.mean()
class MixedLoss(nn.Module):
    """Weighted focal loss combined with the negative log of the Dice score.

    Computes ``alpha * FocalLoss(gamma)(input, target)
    - log(dice_loss(input, target))``.

    Args:
        alpha: Multiplier applied to the focal term only.
        gamma: Passed through to :class:`FocalLoss`.
    """

    def __init__(self, alpha, gamma):
        super().__init__()
        self.alpha = alpha
        self.focal = FocalLoss(gamma)

    def forward(self, input, target):
        """Return the combined loss for ``input`` logits against ``target``."""
        focal_term = self.alpha * self.focal(input, target)
        log_dice = torch.log(dice_loss(input, target))
        combined = focal_term - log_dice
        return combined.mean()
def dice(pred, targs):
    """Return a Dice-style overlap score for thresholded predictions.

    ``pred`` is binarised at zero (values ``> 0`` become 1.0), then the
    score is ``2 * sum(pred * targs) / (sum(pred + targs) + 1.0)``. The
    ``+ 1.0`` sits in the denominator only.

    Args:
        pred: Tensor of scores; only the sign test ``> 0`` is used.
        targs: Tensor multiplied and added element-wise with the
            binarised predictions.

    Returns:
        A zero-dimensional tensor with the score.
    """
    hard = (pred > 0).float()
    overlap = (hard * targs).sum()
    total = (hard + targs).sum()
    return 2.0 * overlap / (total + 1.0)
def IoU(pred, targs):
    """Return an intersection-over-union score for thresholded predictions.

    ``pred`` is binarised at zero (values ``> 0`` become 1.0). The score is
    ``intersection / (sum(pred + targs) - intersection + 1.0)``, where the
    ``+ 1.0`` sits in the denominator only.

    Args:
        pred: Tensor of scores; only the sign test ``> 0`` is used.
        targs: Tensor multiplied and added element-wise with the
            binarised predictions.

    Returns:
        A zero-dimensional tensor with the score.
    """
    hard = (pred > 0).float()
    overlap = (hard * targs).sum()
    union = (hard + targs).sum() - overlap
    return overlap / (union + 1.0)
class LossBinary:
    """Binary cross-entropy on logits with an optional soft-Jaccard term.

    When ``jaccard_weight`` is truthy, ``jaccard_weight * log(J)`` is
    subtracted from the cross-entropy, where ``J`` is a soft Jaccard ratio
    between ``sigmoid(outputs)`` and the positions where ``targets == 1.0``.

    Args:
        jaccard_weight: Weight of the log-Jaccard term; the default ``0``
            disables it.
    """

    def __init__(self, jaccard_weight=0):
        self.nll_loss = nn.BCEWithLogitsLoss()
        self.jaccard_weight = jaccard_weight

    def __call__(self, outputs, targets):
        """Return the loss for ``outputs`` logits against ``targets``."""
        bce = self.nll_loss(outputs, targets)
        if not self.jaccard_weight:
            return bce

        tiny = 1e-15  # guards the ratio and the log against a zero operand
        positives = (targets == 1.0).float()
        probs = torch.sigmoid(outputs)

        overlap = (probs * positives).sum()
        total = probs.sum() + positives.sum()
        soft_jaccard = (overlap + tiny) / (total - overlap + tiny)

        return bce - self.jaccard_weight * torch.log(soft_jaccard)