-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
126 lines (103 loc) · 3.8 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import re
import torch
import numpy as np
def sec_to_hms(t):
    """
    Split a duration given in seconds into hours, minutes and seconds.

    Each component is truncated to an int, so any fractional part of the
    seconds is dropped.

    :param t: duration in seconds
    :return: tuple of three ints ``(hours, minutes, seconds)``
    """
    total_minutes, seconds = divmod(t, 60)
    hours, minutes = divmod(total_minutes, 60)
    return tuple(int(part) for part in (hours, minutes, seconds))
def sec_to_hms_str(t):
    """
    Format a duration given in seconds as a compact hours/minutes/seconds string.

    :param t: duration in seconds
    :return: string of the form ``#h#m#s``, e.g. ``5h10m24s``
    """
    # sec_to_hms returns a 3-tuple of ints, which feeds the three %d fields directly.
    return "%dh%dm%ds" % sec_to_hms(t)
def compute_disp_error(pred_disp, gt_disp):
    """
    Compute disparity error metrics over every valid pixel of a batch.

    A pixel counts as valid when its ground-truth disparity is strictly
    positive. Invalid pixels are zeroed in the error map, so they contribute
    to neither metric. There is no guard for the case of zero valid pixels:
    both metrics are then the result of a division by zero.

    :param pred_disp: predicted disparity tensor
    :param gt_disp: ground truth disparity tensor (values <= 0 are treated as invalid)
    :return: tuple ``(epe, bad3, error_map)`` where ``epe`` is the mean absolute error over
        valid pixels, ``bad3`` is the percentage of valid pixels whose absolute error
        exceeds 3.0, and ``error_map`` is the absolute error with invalid pixels set to 0
    """
    mask = gt_disp > 0
    num_valid = mask.sum()

    # Absolute error, zeroed wherever the ground truth is invalid.
    abs_err = mask * (pred_disp - gt_disp).abs()

    mean_epe = abs_err.sum() / num_valid
    bad3_pct = (abs_err > 3.0).sum() / num_valid * 100.0
    return mean_epe, bad3_pct, abs_err
def readPFM(file):
    """
    Read a PFM (Portable Float Map) file into a numpy array. Adapted from the SceneFlow website
    (https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlow/assets/code/python_pfm.py)

    The header consists of three lines: the identifier ('PF' for colour, 'Pf' for
    greyscale), the dimensions as "width height", and a scale factor whose sign
    encodes the byte order (negative means little-endian).

    :param file: path to the PFM file
    :return: tuple ``(data, scale)``; ``data`` has shape (height, width, 3) for a 'PF' file
        or (height, width) for a 'Pf' file and is flipped vertically relative to the order
        stored in the file; ``scale`` is the absolute value of the header's scale factor
    :raises Exception: if the identifier line is not 'PF'/'Pf' or the dimension line is malformed
    """
    # The context manager closes the handle on every path, including the header errors below.
    with open(file, 'rb') as pfm_file:
        # Compare raw bytes so that a non-ASCII first line (i.e. some other binary format)
        # is reported as "Not a PFM file." rather than surfacing as a UnicodeDecodeError.
        header = pfm_file.readline().rstrip()
        if header == b'PF':
            color = True
        elif header == b'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        dim_match = re.match(r'^(\d+)\s(\d+)\s$', pfm_file.readline().decode("ascii"))
        if not dim_match:
            raise Exception('Malformed PFM header.')
        width, height = map(int, dim_match.groups())

        scale = float(pfm_file.readline().decode("ascii").rstrip())
        if scale < 0:  # little-endian
            endian = '<'
            scale = -scale
        else:
            endian = '>'  # big-endian

        # Everything after the three header lines is the raw float payload.
        data = np.fromfile(pfm_file, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    # Flip the row order of the payload, as the reference reader does.
    data = np.flipud(data)
    return data, scale
def post_process(pred_disp, occlusion, thres, window=10):
    """
    Post process the predicted disparity maps by filling masked-out pixels from the left.

    The soft mask is binarised with ``thres``; entries below the threshold are treated as
    invalid. Columns are then swept left to right, and every invalid pixel is replaced by
    the mean of up to ``window`` columns immediately to its left in the same row. Because
    the sweep reads from the already-processed result, filled values propagate rightwards.
    Column 0 has no left neighbour and is never modified. The inputs are not modified;
    the work is done on detached clones.

    :param pred_disp: predicted disparity map, a 4-D tensor whose last dimension is width
    :param occlusion: soft occlusion mask, indexed the same way as ``pred_disp``
        (NOTE(review): values >= thres are the ones that are KEPT, so this presumably
        encodes validity/confidence rather than occlusion probability - confirm with caller)
    :param thres: threshold to filter the occlusion mask
    :param window: maximum number of preceding columns averaged when filling a pixel
        (default 10, the previously hard-coded value)
    :return: disparity map after post processing, same shape as ``pred_disp``
    :raises ValueError: if ``window`` is smaller than 1
    """
    if window < 1:
        # An empty averaging window would yield NaN fill values.
        raise ValueError("window must be >= 1")

    _, _, _, width = pred_disp.size()

    # Binarise the soft mask: below the threshold -> 0, remaining positive entries -> 1.
    validity = torch.clone(occlusion).detach()
    validity[validity < thres] = 0
    validity[validity > 0] = 1

    final_disp = torch.clone(pred_disp).detach()
    for i in range(1, width):
        # Near the left border fewer than `window` columns are available.
        local_window = min(i, window)
        prev_cols = final_disp[:, :, :, i - local_window:i]
        avg = torch.mean(prev_cols, dim=3)
        col_valid = validity[:, :, :, i]
        # Keep the prediction where valid, otherwise substitute the left-neighbour mean.
        final_disp[:, :, :, i] = col_valid * final_disp[:, :, :, i] + (1 - col_valid) * avg
    return final_disp
def unpad_imgs(inputs, outputs):
    """
    Undo top/left padding on images, in place.

    Every entry of ``inputs`` except the bookkeeping keys ``frame_id``, ``left_pad`` and
    ``top_pad`` is cropped by dropping the first ``top_pad`` elements along dimension 2 and
    the first ``left_pad`` elements along dimension 3. The same crop is applied to
    ``outputs['refined_disp0']`` and, when present, ``outputs['occ0']``.

    :param inputs: inputs to the model; must contain ``top_pad`` and ``left_pad``
        (assumed to be ints usable as slice bounds, and the remaining values to be
        4-D indexable, e.g. tensors - confirm with caller)
    :param outputs: prediction from the model; must contain ``refined_disp0``
    :return: None; both dicts are modified in place
    """
    top = inputs['top_pad']
    left = inputs['left_pad']
    bookkeeping = ("frame_id", "left_pad", "top_pad")

    # Only values of existing keys are reassigned, so iterating the dict directly is safe.
    for key in inputs:
        if key in bookkeeping:
            continue
        inputs[key] = inputs[key][:, :, top:, left:]

    outputs['refined_disp0'] = outputs['refined_disp0'][:, :, top:, left:]
    if "occ0" in outputs:
        outputs['occ0'] = outputs['occ0'][:, :, top:, left:]