-
Notifications
You must be signed in to change notification settings - Fork 8
/
main_predict.py
115 lines (87 loc) · 3.15 KB
/
main_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import sys
import os
import requests
import torch
import numpy as np
import cv2
#import matplotlib.pyplot as plt
from PIL import Image
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
from einops import rearrange, reduce, repeat
import time
from dataloader.dataloader_evaluate import *
import datetime as dt
import time
import torch.nn.functional as F
import yaml
import argparse
import model.ViT_MAE as VIT_MAE_Origin
from torchvision.utils import save_image
# define the utils
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])
def get_args_parser():
    """Build the CLI argument parser for TC-MoA prediction.

    Returns:
        argparse.ArgumentParser with a single --config_path option
        (default: 'config/predict.yaml').
    """
    parser = argparse.ArgumentParser('TC-MoA', add_help=False)
    parser.add_argument(
        '--config_path',
        default='config/predict.yaml',
        type=str,
        help='config_path to load',
    )
    return parser
def test_one_iter(model, A, B, task_index, config):
    """Run one inference forward pass on a single image pair.

    Args:
        model: TC-MoA model (expected to be in eval mode).
        A, B: unbatched input tensors; a batch dimension is added here.
        task_index: selects the fusion task inside the model.
        config: parsed YAML config (unused here; kept for interface
            compatibility with existing callers).

    Returns:
        (loss_dict, pred, att_tuple) exactly as produced by the model.

    Note: reads the module-level global ``device`` set in the __main__ block.
    """
    A = A.to(device, non_blocking=True)
    B = B.to(device, non_blocking=True)
    # Model expects batched input: (1, C, H, W).
    A = A.unsqueeze(0)
    B = B.unsqueeze(0)
    # Mixed-precision forward pass on CUDA devices.
    with torch.cuda.amp.autocast():
        loss_dict, pred, att_tuple = model(A, B, task_index)
    # Fix: the original read loss_dict["loss"] into an unused local, which
    # only risked a KeyError at inference time; dropped.
    return loss_dict, pred, att_tuple
# Run inference over the whole evaluation dataset.
def main(output_dir, model, config):
    """Fuse every (A, B) pair in the evaluation set and save the results.

    Args:
        output_dir: directory where fused result images are written.
        model: prepared TC-MoA model, already on the target device.
        config: parsed YAML config; config['EvalDataSet'] selects the data.
    """
    evaluate_dataset = EvaluateDataSet(config['EvalDataSet'], config)
    model = model.eval()
    with torch.no_grad():
        time_start = time.time()
        # NOTE(review): dataset items are assumed to be (A, B, info-dict)
        # with info["task_index"] choosing the fusion task — confirm against
        # dataloader_evaluate.EvaluateDataSet.
        for i, (A, B, AB_info) in enumerate(evaluate_dataset):
            if i % 100 == 0:
                print("has done img num:", i)
            _loss_AB, pred, _ = test_one_iter(model, A, B, AB_info["task_index"], config)
            evaluate_dataset.save_img_NewLoader(pred.cpu(), output_dir,
                                                AB_info)
        time_sum = time.time() - time_start
        print("Time Used:", time_sum)
    print("Done!")
def prepare_model(model_select, chkpt_dir, config):
    """Build the model named in config["model_type"] and load its checkpoint.

    Args:
        model_select: model family name; only "Base" is supported.
        chkpt_dir: path to the checkpoint file to load.
        config: parsed YAML config; "model_type" names the constructor in
            model.ViT_MAE and "device" is the target device.

    Returns:
        The model with checkpoint weights loaded (non-strict).

    Raises:
        ValueError: if model_select is not a recognized model family.
    """
    arch = config["model_type"]
    if model_select == "Base":
        print("model_type: Base")
        models_mae = VIT_MAE_Origin
    else:
        # Fix: the original fell through with `models_mae`/`model` unbound,
        # producing a confusing NameError; fail fast with a clear message.
        raise ValueError("unknown model_select: %r" % (model_select,))
    model = getattr(models_mae, arch)(config).to(config["device"])
    # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
    # parameters were already placed on the target device by .to() above.
    checkpoint = torch.load(chkpt_dir, map_location="cpu")
    # strict=False tolerates missing/unexpected keys; print them so silent
    # weight mismatches are visible in the log.
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    print(msg)
    return model
if __name__ == '__main__':
    args = get_args_parser()
    args = args.parse_args()
    with open(args.config_path, 'r') as stream:
        config = yaml.safe_load(stream)
    # Optional delayed start, in hours (e.g. wait for a training run to
    # finish writing its checkpoint before evaluating it).
    waiting_time = config["waiting_time"]
    print(dt.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
    print("Waiting Hours: ", waiting_time)
    time.sleep(waiting_time * 3600)
    # NOTE: `device` must stay module-level — test_one_iter() reads it as a
    # global. (Dead variable `upsample = True` from the original removed.)
    device = config['device']
    # Evaluate every checkpoint listed in the config, each into its own
    # result directory named <ckpt>_<more_detail>.
    for ckpt_name, model_select in config["ckpt_dict"].items():
        model_mae = prepare_model(model_select, os.path.join(config["chkpt_dir"], ckpt_name), config).to(device)
        print('Model loaded.', device)
        MoreDetail = config["more_detail"]
        model_type = ckpt_name + "_" + MoreDetail
        print("model_type:", model_type)
        output_dir = os.path.join(config["result_path"], model_type)
        main(output_dir, model_mae, config)