-
Notifications
You must be signed in to change notification settings - Fork 4
/
demo_deploy.py
104 lines (76 loc) · 3.25 KB
/
demo_deploy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from option import args
import model
import utils
import data.common as common
import torch
import numpy as np
import os
import glob
import cv2
# Pick the compute device once at import time; deploy() moves its inputs here.
device = torch.device('cuda' if not args.cpu else 'cpu')
def _tile_starts(length, tile):
    """Return start offsets of ``tile``-wide windows covering ``[0, length)``.

    Windows are placed every ``tile`` positions; the last one is shifted back
    so that it ends exactly at ``length`` (it may overlap its predecessor).
    """
    return [min(start, length - tile) for start in range(0, length, tile)]


def _forward_blocks(sr_model, lr, patch_size, factor, cubic_input):
    """Super-resolve ``lr`` tile by tile and stitch the tiles into one tensor.

    Args:
        sr_model: model applied to each LR tile.
        lr: 4-D input tensor, unpacked as (b, c, h, w).
        patch_size: tile size taken from ``args.patch_size``.
        factor: upscaling factor (``args.scale[0]``).
        cubic_input: True when ``lr`` was already bicubic-upscaled. In that case
            the output canvas has the same spatial size as the input.

    Returns:
        A CPU float tensor of shape (b, c, h * f, w * f). ``f`` is ``factor``,
        or 1 when ``cubic_input`` is set.

    Raises:
        ValueError: if the input is smaller than one input tile.
    """
    b, c, h, w = lr.shape
    # With cubic input the canvas is h x w, so tile offsets must not be scaled.
    out_factor = 1 if cubic_input else factor
    ip = patch_size if cubic_input else patch_size // factor  # input tile side
    # Output tile side; assumes the model scales each tile by exactly
    # out_factor (patch_size may not be a multiple of factor).
    op = ip * out_factor
    if h < ip or w < ip:
        raise ValueError('LR input must be larger than the training inputs')
    sr = torch.zeros((b, c, h * out_factor, w * out_factor))
    for iy in _tile_starts(h, ip):
        ty = iy * out_factor
        for ix in _tile_starts(w, ip):
            tx = ix * out_factor
            lr_p = lr[:, :, iy:iy + ip, ix:ix + ip].to(device)
            # The canvas lives on the CPU; bring each result back explicitly.
            sr[:, :, ty:ty + op, tx:tx + op] = sr_model(lr_p).cpu()
    return sr


def _tensor_to_uint8_image(sr, rgb_range):
    """Convert the first image of a 4-D model output to an H x W x C uint8 array.

    Values are rescaled from ``[0, rgb_range]`` to ``[0, 255]`` and clipped.
    They are then rounded to the nearest integer, because a plain ``astype``
    would truncate. For rgb_range 1 or 255 the scaling matches the previous
    behaviour.
    """
    img = sr[0].detach().cpu().numpy().transpose(1, 2, 0)
    img = np.clip(img * (255.0 / rgb_range), 0, 255)
    # Hand OpenCV a C-contiguous buffer rather than a transposed view.
    return np.ascontiguousarray(np.round(img).astype(np.uint8))


def deploy(args, sr_model):
    """Super-resolve every ``.tif`` image in ``args.dir_data`` into ``args.dir_out``.

    For each image, in sorted filename order:
      1. Read it with OpenCV and convert BGR -> RGB.
      2. Optionally bicubic-upscale it by ``args.scale[0]`` (``args.cubic_input``).
      3. Convert it with ``common.np2Tensor`` and add a batch dimension.
      4. Run ``sr_model``, either on the whole image or tile by tile
         (``args.test_block``).
      5. Write the result as 8-bit BGR under the input's base name.

    Unreadable files are reported and skipped.

    Args:
        args: parsed options. Reads ``dir_data``, ``dir_out``, ``cubic_input``,
            ``scale``, ``rgb_range``, ``test_block`` and ``patch_size``.
        sr_model: callable mapping an input tensor batch to an output batch.
            Presumably already placed on ``device`` by the caller; verify.
    """
    img_ext = '.tif'
    img_lists = sorted(glob.glob(os.path.join(args.dir_data, '*' + img_ext)))
    if not img_lists:
        print("Error: there are no images in given folder!")
        return
    os.makedirs(args.dir_out, exist_ok=True)
    factor = args.scale[0]
    with torch.no_grad():
        for i, img_path in enumerate(img_lists, start=1):
            print("[%d/%d] %s" % (i, len(img_lists), img_path))
            lr_np = cv2.imread(img_path, cv2.IMREAD_COLOR)
            if lr_np is None:
                print("Error: failed to read image %s, skipped." % img_path)
                continue
            lr_np = cv2.cvtColor(lr_np, cv2.COLOR_BGR2RGB)
            if args.cubic_input:
                h, w = lr_np.shape[:2]
                # cv2.resize takes dsize as (width, height), not (rows, cols).
                lr_np = cv2.resize(lr_np, (w * factor, h * factor),
                                   interpolation=cv2.INTER_CUBIC)
            lr = common.np2Tensor([lr_np], args.rgb_range)[0].unsqueeze(0)
            if args.test_block:
                sr = _forward_blocks(sr_model, lr, args.patch_size, factor,
                                     args.cubic_input)
            else:
                sr = sr_model(lr.to(device))
            final_sr = _tensor_to_uint8_image(sr, args.rgb_range)
            final_sr = cv2.cvtColor(final_sr, cv2.COLOR_RGB2BGR)
            cv2.imwrite(os.path.join(args.dir_out, os.path.basename(img_path)),
                        final_sr)
if __name__ == '__main__':
    # Demo configuration for the x4 UCMerced experiment. These assignments
    # overwrite whatever values option.py produced for the same three fields.
    demo_settings = (
        ('pre_train', '../experiment/HSENETx4_UCMerced/model/model_best.pt'),
        ('dir_data', 'F:/research/dataset/SR for remote sensing/UCMerced_LandUse/test/LR_x4'),
        ('dir_out', '../experiment/results/HSENETx4_UCMerced'),
    )
    for field_name, field_value in demo_settings:
        setattr(args, field_name, field_value)

    # Build the checkpoint helper and the model, then run inference.
    checkpoint = utils.checkpoint(args)
    sr_model = model.Model(args, checkpoint)
    sr_model.eval()
    deploy(args, sr_model)