-
Notifications
You must be signed in to change notification settings - Fork 126
/
test.py
106 lines (86 loc) · 4.37 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import argparse, time, os
import imageio
import options.options as option
from utils import util
from solvers import create_solver
from data import create_dataloader
from data import create_dataset
def _mean(values):
    """Return the arithmetic mean of ``values``, or NaN when ``values`` is empty.

    Used for the per-benchmark averages so that a test set which yields no
    batches prints ``nan`` instead of raising ZeroDivisionError.
    """
    return sum(values) / len(values) if values else float('nan')


def main():
    """Evaluate a super-resolution model on every configured test set.

    Reads the options file given by the required ``-opt`` command-line
    argument, builds one dataloader per entry of ``opt['datasets']`` (in
    sorted key order), creates the solver, and then for each benchmark:

    * feeds every batch to the solver, runs ``solver.test()`` and records
      the elapsed time of that call;
    * when the dataset class name contains ``'LRHR'``, computes PSNR/SSIM
      with ``util.calc_metrics`` (border cropped by ``scale`` pixels);
    * prints per-image and average statistics;
    * writes the SR images with ``imageio.imwrite`` under ``./results/SR/...``.
    """
    parser = argparse.ArgumentParser(description='Test Super Resolution Models')
    parser.add_argument('-opt', type=str, required=True, help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt)
    # Presumably makes missing option keys read as None (per the helper's
    # name) — confirm in options.options.
    opt = option.dict_to_nonedict(opt)

    # initial configure
    scale = opt['scale']
    degrad = opt['degradation']
    network_opt = opt['networks']
    model_name = network_opt['which_model'].upper()
    if opt['self_ensemble']:
        # Self-ensemble results are reported under a distinct "<MODEL>plus" name.
        model_name += 'plus'

    # create test dataloaders, one per configured benchmark
    bm_names = []
    test_loaders = []
    for _, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        test_loaders.append(test_loader)
        print('===> Test Dataset: [%s] Number of images: [%d]' % (test_set.name(), len(test_set)))
        bm_names.append(test_set.name())

    # create solver (and load model)
    solver = create_solver(opt)

    # Test phase
    print('===> Start Test')
    print("==================================================")
    print("Method: %s || Scale: %d || Degradation: %s" % (model_name, scale, degrad))

    for bm, test_loader in zip(bm_names, test_loaders):
        print("Test set : [%s]" % bm)

        sr_list = []
        path_list = []
        total_psnr = []
        total_ssim = []
        total_time = []

        # HR references (and hence metrics) are only used for datasets whose
        # class name contains 'LRHR'.
        need_HR = 'LRHR' in test_loader.dataset.__class__.__name__

        for idx, batch in enumerate(test_loader):
            solver.feed_data(batch, need_HR=need_HR)

            # calculate forward time. perf_counter is monotonic and
            # high-resolution, unlike time.time(), which follows the wall clock.
            # NOTE(review): if solver.test() launches asynchronous GPU work
            # without synchronizing, this measures dispatch time only — confirm.
            t0 = time.perf_counter()
            solver.test()
            elapsed = time.perf_counter() - t0
            total_time.append(elapsed)

            visuals = solver.get_current_visual(need_HR=need_HR)
            sr_list.append(visuals['SR'])
            lr_name = os.path.basename(batch['LR_path'][0])

            # calculate PSNR/SSIM metrics on Python
            if need_HR:
                psnr, ssim = util.calc_metrics(visuals['SR'], visuals['HR'], crop_border=scale)
                total_psnr.append(psnr)
                total_ssim.append(ssim)
                # Output file is named after the HR file, with every 'HR'
                # substring replaced by the model name.
                path_list.append(os.path.basename(batch['HR_path'][0]).replace('HR', model_name))
                print("[%d/%d] %s || PSNR(dB)/SSIM: %.2f/%.4f || Timer: %.4f sec ." % (idx + 1, len(test_loader),
                                                                                       lr_name,
                                                                                       psnr, ssim,
                                                                                       elapsed))
            else:
                path_list.append(lr_name)
                print("[%d/%d] %s || Timer: %.4f sec ." % (idx + 1, len(test_loader),
                                                           lr_name,
                                                           elapsed))

        # Averages fall back to NaN for an empty test set instead of crashing.
        if need_HR:
            print("---- Average PSNR(dB) /SSIM /Speed(s) for [%s] ----" % bm)
            print("PSNR: %.2f SSIM: %.4f Speed: %.4f" % (_mean(total_psnr),
                                                         _mean(total_ssim),
                                                         _mean(total_time)))
        else:
            print("---- Average Speed(s) for [%s] is %.4f sec ----" % (bm,
                                                                      _mean(total_time)))

        # save SR results for further evaluation on MATLAB
        if need_HR:
            save_img_path = os.path.join('./results/SR', degrad, model_name, bm, "x%d" % scale)
        else:
            save_img_path = os.path.join('./results/SR', bm, model_name, "x%d" % scale)

        print("===> Saving SR images of [%s]... Save Path: [%s]\n" % (bm, save_img_path))

        # exist_ok avoids the exists()/makedirs() race of a separate check.
        os.makedirs(save_img_path, exist_ok=True)
        for img, name in zip(sr_list, path_list):
            imageio.imwrite(os.path.join(save_img_path, name), img)

    print("==================================================")
    print("===> Finished !")
if __name__ == '__main__':
    # Run the evaluation only when this file is executed as a script,
    # not when it is imported as a module.
    main()