-
Notifications
You must be signed in to change notification settings - Fork 4
/
loss.py
94 lines (74 loc) · 2.9 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import torch.nn as nn
class CharbonnierLoss(nn.Module):
    """Charbonnier (smooth L1) penalty: mean of sqrt((x - y)^2 + eps^2).

    A differentiable variant of L1 loss; `eps` keeps the gradient finite
    at zero residual.
    """

    def __init__(self, eps=1e-6):
        super(CharbonnierLoss, self).__init__()
        # Pre-square epsilon once so forward() does no redundant work.
        self.eps2 = eps ** 2

    def forward(self, inp, target):
        # (inp - target)^2 element-wise, offset by eps^2, rooted, averaged.
        diff = inp - target
        return (diff * diff + self.eps2).sqrt().mean()
class OutlierAwareLoss(nn.Module):
    """L1 loss reweighted by how far each residual deviates from its
    neighbourhood mean.

    The weight ``1 - exp(-|delta - avg| / var)`` approaches 1 for residuals
    far from the local mean (outliers) and 0 for ordinary residuals, so
    outlier pixels dominate the loss.

    Args:
        kernel_size: odd window size for local statistics; a falsy value
            (the default) switches to global per-channel statistics.
    """

    def __init__(self, kernel_size=False):
        super(OutlierAwareLoss, self).__init__()
        self.kernel = kernel_size
        # Bug fix: the original unconditionally built nn.Unfold(kernel_size),
        # yielding an invalid nn.Unfold(False) module in the default case.
        # Only construct it when a window is actually requested.
        self.unfold = torch.nn.Unfold(kernel_size) if kernel_size else None

    def forward(self, out, lab):
        b, c, h, w = out.shape
        delta = out - lab
        if self.kernel:
            # Same-size padding so every pixel gets a full k x k window.
            p = self.kernel // 2
            padded = torch.nn.functional.pad(delta, (p, p, p, p))
            # Unfold emits (b, c*k*k, h*w); recover per-window layout.
            # Detached: the statistics only shape the weights, not gradients.
            patch = self.unfold(padded).reshape(
                b, c, self.kernel, self.kernel, h, w).detach()
            # /sqrt(2) rescales the std estimate (matches the global branch).
            var = patch.std((2, 3)) / (2 ** .5)
            avg = patch.mean((2, 3))
        else:
            # Global per-channel statistics over the spatial dims.
            # NOTE: var can be 0 for a constant residual map, making the
            # weight NaN — callers are expected to avoid degenerate inputs.
            var = delta.std((2, 3), keepdim=True) / (2 ** .5)
            avg = delta.mean((2, 3), keepdim=True)
        # Weight in [0, 1): grows toward 1 as |delta - avg| outgrows var.
        weight = 1 - (-((delta - avg).abs() / var)).exp().detach()
        loss = (delta.abs() * weight).mean()
        return loss
class LossWarmup(nn.Module):
    """Warm-up objective: Charbonnier terms on both warm-up outputs plus a
    cosine-dissimilarity term on the first one."""

    def __init__(self):
        super(LossWarmup, self).__init__()
        # Tight eps so the Charbonnier term behaves almost like pure L1.
        self.loss_cb = CharbonnierLoss(1e-8)
        self.loss_cs = nn.CosineSimilarity()

    def forward(self, inp, gt, warmup1, warmup2):
        # Second branch should reconstruct the network input.
        recon_term = self.loss_cb(warmup2, inp)
        # First branch is pushed toward ground truth in both magnitude
        # (Charbonnier) and direction (1 - cosine similarity, clipped to [0,1]).
        charb_term = self.loss_cb(warmup1, gt)
        cosine_term = (1 - self.loss_cs(warmup1.clip(0, 1), gt)).mean()
        # Grouping matches the original summation order.
        return recon_term + (charb_term + cosine_term)
class LossISP(nn.Module):
    """ISP training objective: outlier-aware L1 plus cosine dissimilarity."""

    def __init__(self):
        super(LossISP, self).__init__()
        self.loss_cs = nn.CosineSimilarity()
        self.loss_oa = OutlierAwareLoss()

    def forward(self, out, gt):
        # Magnitude term (outlier-weighted L1) ...
        oa_term = self.loss_oa(out, gt)
        # ... plus direction term (1 - channel cosine similarity), with the
        # prediction clipped to the valid [0, 1] image range first.
        cos_term = (1 - self.loss_cs(out.clip(0, 1), gt)).mean()
        return oa_term + cos_term
class LossLLE(nn.Module):
    """Low-light-enhancement objective: outlier-aware L1 plus cosine
    dissimilarity (same composition as LossISP)."""

    def __init__(self):
        super(LossLLE, self).__init__()
        self.loss_cs = nn.CosineSimilarity()
        self.loss_oa = OutlierAwareLoss()

    def forward(self, out, gt):
        # Clip predictions to the displayable range before comparing angles.
        dissimilarity = (1 - self.loss_cs(out.clip(0, 1), gt)).mean()
        return self.loss_oa(out, gt) + dissimilarity
class LossSR(nn.Module):
    """Super-resolution objective: the outlier-aware L1 term alone."""

    def __init__(self):
        super(LossSR, self).__init__()
        self.loss_oa = OutlierAwareLoss()

    def forward(self, out, gt):
        # No cosine term here — SR training uses only the weighted L1.
        return self.loss_oa(out, gt)
def import_loss(training_task):
    """Instantiate the loss module for *training_task*.

    Args:
        training_task: one of 'isp', 'lle', 'sr', 'warmup'.

    Raises:
        ValueError: if the task name is not recognised.
    """
    factories = {
        'isp': LossISP,
        'lle': LossLLE,
        'sr': LossSR,
        'warmup': LossWarmup,
    }
    factory = factories.get(training_task)
    if factory is None:
        raise ValueError('unknown training task, please choose from [isp, lle, sr, warmup].')
    return factory()