-
Notifications
You must be signed in to change notification settings - Fork 14
/
demo.py
74 lines (54 loc) · 2.33 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
from torchvision import transforms
from modules.unet import UNet, UNetReshade
import PIL
from PIL import Image
import argparse
import os.path
from pathlib import Path
import glob
import sys
import pdb
# Supported tasks; the order matches the model builders below.
target_tasks = ['normal', 'depth', 'reshading']

# Command-line interface. Every option is required: without a value,
# args.img_path / args.output_path would be None and the script would fail
# later (or create an output directory literally named "None").
# `choices` makes argparse reject an unknown task with a usage error.
parser = argparse.ArgumentParser(description='Visualize output for a single Task')
parser.add_argument('--task', dest='task', required=True, choices=target_tasks,
                    help="normal, depth or reshading")
parser.add_argument('--img_path', dest='img_path', required=True,
                    help="path to rgb image")
parser.add_argument('--output_path', dest='output_path', required=True,
                    help="path to where output image should be stored")
args = parser.parse_args()

# Directory holding the rgb2<task>_<variant>.pth checkpoints.
root_dir = './models/'
# Input preprocessing: Resize(256) then CenterCrop(256), then convert to a tensor.
trans_totensor = transforms.Compose([transforms.Resize(256, interpolation=PIL.Image.BILINEAR),
                                     transforms.CenterCrop(256),
                                     transforms.ToTensor()])
trans_topil = transforms.ToPILImage()

# Create the output directory without going through a shell, so paths with
# spaces or shell metacharacters are handled safely and cannot inject commands.
os.makedirs(args.output_path, exist_ok=True)

# get target task and model -- argparse `choices` guarantees args.task is in
# target_tasks, so this lookup cannot fail.
task_index = target_tasks.index(args.task)
# Build only the network for the requested task (indexed like target_tasks).
model_builders = [
    lambda: UNet(),
    lambda: UNet(downsample=6, out_channels=1),
    lambda: UNetReshade(downsample=5),
]
model = model_builders[task_index]()
# Inference only: evaluation mode disables training-time behaviour of any
# dropout / batch-norm layers the network may contain.
model.eval()
# The model is never moved off the CPU in this script, so checkpoints are
# loaded straight into CPU memory. Mapping them onto the GPU first only made
# load_state_dict copy them back, wasting device memory.
map_location = torch.device('cpu')
# Checkpoint cache (path -> state dict): when a whole directory is processed,
# each weights file is read from disk once instead of once per image.
_checkpoint_cache = {}


def _load_checkpoint(path):
    """Return the state dict stored at *path*, reading the file only on first use."""
    if path not in _checkpoint_cache:
        _checkpoint_cache[path] = torch.load(path, map_location=map_location)
    return _checkpoint_cache[path]


def save_outputs(img_path, output_file_name):
    """Run the selected task model on one image and save both predictions.

    For each checkpoint variant ('baseline' and 'consistency'):
    1. Load the weights ``<root_dir>/rgb2<task>_<variant>.pth`` into the
       module-level ``model``.
    2. Run the preprocessed image through the model.
    3. Clamp the prediction to [0, 1].
    4. Write it to ``<output_path>/<output_file_name>_<task>_<variant>.png``.

    Args:
        img_path: path of the input image file.
        output_file_name: stem used to build the output file names.
    """
    # convert('RGB') guarantees a 3-channel input for grayscale, palette,
    # LA or CMYK images. For RGBA it simply drops alpha, like the former [:3]
    # slice. The `with` block closes the underlying file handle.
    with Image.open(img_path) as img:
        img_tensor = trans_totensor(img.convert('RGB')).unsqueeze(0)

    # compute baseline and consistency output
    for variant in ('baseline', 'consistency'):
        weights_path = os.path.join(root_dir, f"rgb2{args.task}_{variant}.pth")
        model.load_state_dict(_load_checkpoint(weights_path))
        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            prediction = model(img_tensor).clamp(min=0, max=1)
        out_file = os.path.join(args.output_path,
                                f"{output_file_name}_{args.task}_{variant}.png")
        trans_topil(prediction[0]).save(out_file)
# Entry point: process a single image, or every regular file in a directory.
img_path = Path(args.img_path)
if img_path.is_file():
    save_outputs(args.img_path, os.path.splitext(os.path.basename(args.img_path))[0])
elif img_path.is_dir():
    # glob.escape: match the directory name literally even if it contains
    # glob metacharacters such as '[' or '*'.
    # sorted(): deterministic processing order.
    # isfile(): skip subdirectories, which save_outputs cannot open as images.
    pattern = os.path.join(glob.escape(args.img_path), '*')
    for f in sorted(glob.glob(pattern)):
        if os.path.isfile(f):
            save_outputs(f, os.path.splitext(os.path.basename(f))[0])
else:
    # Report the failure on stderr with a non-zero exit status.
    print("invalid file path!", file=sys.stderr)
    sys.exit(1)