inference.py (forked from JiawangBian/sc_depth_pl)
import numpy as np
from tqdm import tqdm
import torch
from imageio import imread, imwrite
from path import Path
import os
from config import get_opts, get_training_size
from SC_Depth import SC_Depth
from SC_DepthV2 import SC_DepthV2
from SC_DepthV3 import SC_DepthV3
import datasets.custom_transforms as custom_transforms
from visualization import *

@torch.no_grad()
def main():
    hparams = get_opts()

    # build the system for the requested model version
    if hparams.model_version == 'v1':
        system = SC_Depth(hparams)
    elif hparams.model_version == 'v2':
        system = SC_DepthV2(hparams)
    elif hparams.model_version == 'v3':
        system = SC_DepthV3(hparams)

    # restore weights and keep only the depth network for inference
    system = system.load_from_checkpoint(hparams.ckpt_path, strict=False)
    model = system.depth_net
    model.cuda()
    model.eval()

    # training size
    training_size = get_training_size(hparams.dataset_name)

    # normalization
    inference_transform = custom_transforms.Compose([
        custom_transforms.RescaleTo(training_size),
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize()]
    )

    input_dir = Path(hparams.input_dir)
    output_dir = Path(hparams.output_dir) / \
        'model_{}'.format(hparams.model_version)
    output_dir.makedirs_p()

    if hparams.save_vis:
        (output_dir/'vis').makedirs_p()

    if hparams.save_depth:
        (output_dir/'depth').makedirs_p()

    # collect input images (jpg and png)
    image_files = sum([(input_dir).files('*.{}'.format(ext))
                       for ext in ['jpg', 'png']], [])
    image_files = sorted(image_files)

    print('{} images for inference'.format(len(image_files)))

    for i, img_file in enumerate(tqdm(image_files)):
        filename = os.path.splitext(os.path.basename(img_file))[0]

        # preprocess a single image and predict its depth
        img = imread(img_file).astype(np.float32)
        tensor_img = inference_transform([img])[0][0].unsqueeze(0).cuda()
        pred_depth = model(tensor_img)

        if hparams.save_vis:
            # save the depth visualization as a jpg image
            vis = visualize_depth(pred_depth[0, 0]).permute(
                1, 2, 0).numpy() * 255
            imwrite(output_dir/'vis/{}.jpg'.format(filename),
                    vis.astype(np.uint8))

        if hparams.save_depth:
            # save the raw predicted depth map as npy
            depth = pred_depth[0, 0].cpu().numpy()
            np.save(output_dir/'depth/{}.npy'.format(filename), depth)


if __name__ == '__main__':
    main()
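
Usage: a minimal example invocation is sketched below, assuming that get_opts() in config.py exposes the hyperparameters read above (model_version, ckpt_path, dataset_name, input_dir, output_dir, save_vis, save_depth) as same-named command-line flags; the exact flag names, defaults, and the example paths are assumptions, not confirmed by this file.

    # hypothetical flag names and paths, for illustration only
    python inference.py \
        --model_version v3 \
        --ckpt_path ckpts/last.ckpt \
        --dataset_name kitti \
        --input_dir demo/input \
        --output_dir demo/output \
        --save_vis --save_depth

With these options, raw depth maps are written to demo/output/model_v3/depth/<name>.npy and the corresponding visualizations to demo/output/model_v3/vis/<name>.jpg, following the output layout created by the script above.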