-
Notifications
You must be signed in to change notification settings - Fork 15
/
metrics.py
126 lines (93 loc) · 4.76 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import utils.spatial_color_alignment as sca_utils
from utils.warp import warp
class L2(nn.Module):
    """Mean squared error with optional border cropping and validity masking.

    Args:
        boundary_ignore: number of pixels to drop from each side of the last
            two dimensions before computing the error. ``None`` or ``0``
            disables cropping.
    """

    def __init__(self, boundary_ignore=None):
        super().__init__()
        self.boundary_ignore = boundary_ignore

    def _crop(self, x):
        """Remove ``boundary_ignore`` pixels from each side of the last two dims."""
        b = self.boundary_ignore
        # ``x[..., 0:-0]`` is an *empty* slice, so 0 must also mean "no crop";
        # otherwise the loss is computed over nothing and becomes NaN.
        if not b:
            return x
        return x[..., b:-b, b:-b]

    def forward(self, pred, gt, valid=None):
        """Return the (optionally masked) MSE between ``pred`` and ``gt``.

        Args:
            pred: prediction tensor; its last two dims are cropped.
            gt: ground-truth tensor, same shape as ``pred``.
            valid: optional mask, cropped like ``pred``. It is multiplied
                element-wise with the error map, so it must broadcast against
                it (e.g. presumably a single-channel mask for multi-channel
                images -- confirm with callers).

        Returns:
            A scalar tensor. Without ``valid`` this is the plain mean squared
            error. With ``valid`` it is the error summed over masked elements,
            divided by the number of masked elements after broadcasting.
        """
        pred = self._crop(pred)
        gt = self._crop(gt)

        if valid is None:
            return F.mse_loss(pred, gt)

        valid = self._crop(valid).float()
        mse = F.mse_loss(pred, gt, reduction='none')
        # ``valid`` may have fewer elements than the error map (broadcast over
        # e.g. channels); scale its sum so the denominator counts every
        # masked element of ``mse``.
        elem_ratio = mse.numel() / valid.numel()
        eps = 1e-12  # guards against an all-zero mask
        return (mse * valid).sum() / (valid.sum() * elem_ratio + eps)
class PSNR(nn.Module):
    """Peak signal-to-noise ratio (dB), scored per sample and averaged.

    Args:
        boundary_ignore: border width forwarded to the underlying ``L2``.
        max_value: peak signal value used in the PSNR formula.
    """

    def __init__(self, boundary_ignore=None, max_value=1.0):
        super().__init__()
        self.l2 = L2(boundary_ignore=boundary_ignore)
        self.max_value = max_value

    def psnr(self, pred, gt, valid=None):
        """Return ``20*log10(max_value) - 10*log10(MSE)`` for one batched pair."""
        err = self.l2(pred, gt, valid=valid)
        peak_db = 20 * math.log10(self.max_value)
        return peak_db - 10.0 * err.log10()

    def forward(self, pred, gt, valid=None):
        """Average the PSNR of each batch element.

        ``pred`` and ``gt`` must be 4-D with identical shapes. When ``valid``
        is given, its i-th entry is used as the mask for the i-th sample.
        """
        assert pred.dim() == 4 and pred.shape == gt.shape
        samples = zip(pred, gt) if valid is None else zip(pred, gt, valid)
        scores = []
        for sample in samples:
            # Restore a leading batch dimension of size 1 on every tensor.
            batched = [t.unsqueeze(0) for t in sample]
            scores.append(self.psnr(*batched))
        return sum(scores) / len(scores)
class AlignedL2(nn.Module):
    """Masked MSE after aligning a prediction to its ground truth, both spatially and in colour.

    The prediction is warped onto the ground truth with a flow from
    ``alignment_net``. Its colours are then matched to the ground truth using
    the first burst frame as a reference, and the MSE is taken over the valid
    pixels reported by the colour matcher.

    Args:
        alignment_net: module mapping ``(pred, gt)`` to a flow field.
        sr_factor: super-resolution factor. The burst is assumed to be at
            ``1 / (2 * sr_factor)`` of the ground-truth resolution (inferred
            from the down-sampling below -- confirm).
        boundary_ignore: pixels dropped from each side of the last two dims
            before the error is computed. ``None`` or ``0`` disables cropping.
    """

    def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None):
        super().__init__()
        self.sr_factor = sr_factor
        self.boundary_ignore = boundary_ignore
        self.alignment_net = alignment_net
        # NOTE(review): if gauss_kernel is a tensor, it is stored as a plain
        # attribute rather than a registered buffer, so ``.to(device)`` on
        # this module will not move it. Confirm match_colors copes with that.
        self.gauss_kernel, self.ksz = sca_utils.get_gaussian_kernel(sd=1.5)

    def _crop(self, x):
        """Remove ``boundary_ignore`` pixels from each side of the last two dims."""
        b = self.boundary_ignore
        # ``x[..., 0:-0]`` is an empty slice, so 0 must also mean "no crop".
        if not b:
            return x
        return x[..., b:-b, b:-b]

    def forward(self, pred, gt, burst_input):
        """Return the aligned, colour-matched, masked MSE as a scalar tensor."""
        # Estimate the flow between prediction and ground truth without
        # tracking gradients. Inputs are max-normalised; 1e-6 avoids a
        # division by zero for all-zero images.
        with torch.no_grad():
            flow = self.alignment_net(pred / (pred.max() + 1e-6), gt / (gt.max() + 1e-6))

        # Warp the prediction into ground-truth coordinates.
        pred_warped = warp(pred, flow)

        # Warp the base burst frame to the ground truth as well. It is used
        # to estimate the colour transform between input and ground truth.
        # The flow is resampled to the burst resolution, and its displacement
        # magnitudes are rescaled by the same factor.
        ds_factor = 1.0 / float(2.0 * self.sr_factor)
        flow_ds = F.interpolate(flow, scale_factor=ds_factor, mode='bilinear') * ds_factor
        # Take channels 0, 1 and 3 of frame 0. Presumably this selects R, G
        # and B from an RGGB-packed raw frame -- confirm.
        burst_0 = burst_input[:, 0, [0, 1, 3]].contiguous()
        burst_0_warped = warp(burst_0, flow_ds)
        frame_gt_ds = F.interpolate(gt, scale_factor=ds_factor, mode='bilinear')

        # Match the colour space of the warped prediction to the ground truth.
        pred_warped_m, valid = sca_utils.match_colors(frame_gt_ds, burst_0_warped, pred_warped, self.ksz,
                                                      self.gauss_kernel)

        # Ignore boundary pixels if requested.
        pred_warped_m = self._crop(pred_warped_m)
        gt = self._crop(gt)
        valid = self._crop(valid)

        # Masked MSE. elem_ratio rescales the mask sum when ``valid`` is
        # broadcast over extra dims (e.g. channels) of the error map.
        mse = F.mse_loss(pred_warped_m, gt, reduction='none')
        valid = valid.float()
        eps = 1e-12  # guards against an all-zero mask
        elem_ratio = mse.numel() / valid.numel()
        return (mse * valid).sum() / (valid.sum() * elem_ratio + eps)
class AlignedPSNR(nn.Module):
    """PSNR (dB) computed from ``AlignedL2``, scored per sample and averaged.

    Args:
        alignment_net: flow network forwarded to ``AlignedL2``.
        sr_factor: super-resolution factor forwarded to ``AlignedL2``.
        boundary_ignore: border width forwarded to ``AlignedL2``.
        max_value: peak signal value used in the PSNR formula.
    """

    def __init__(self, alignment_net, sr_factor=4, boundary_ignore=None, max_value=1.0):
        super().__init__()
        self.l2 = AlignedL2(alignment_net=alignment_net, sr_factor=sr_factor, boundary_ignore=boundary_ignore)
        self.max_value = max_value

    def psnr(self, pred, gt, burst_input):
        """Return ``20*log10(max_value) - 10*log10(aligned MSE)`` for one batched sample."""
        err = self.l2(pred, gt, burst_input)
        peak_db = 20 * math.log10(self.max_value)
        return peak_db - 10.0 * err.log10()

    def forward(self, pred, gt, burst_input):
        """Average the aligned PSNR over the batch dimension."""
        scores = []
        for one_pred, one_gt, one_burst in zip(pred, gt, burst_input):
            # ``x[None]`` restores a batch dimension of size 1.
            scores.append(self.psnr(one_pred[None], one_gt[None], one_burst[None]))
        return sum(scores) / len(scores)