-
Notifications
You must be signed in to change notification settings - Fork 5
/
test.py
112 lines (92 loc) · 3.8 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import sys
import argparse
import logging
import random
import torch
import gorilla
# Directory containing this script. The project-local modules imported right
# after this (DPDN, solver, dataset, evaluation_utils) are presumably resolved
# through the sub-directories appended to the module search path here; they
# are appended, in this order, after the existing sys.path entries.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.extend(
    os.path.join(BASE_DIR, *parts)
    for parts in (
        ('provider',),
        ('model',),
        ('model', 'pointnet2'),
        ('utils',),
    )
)
from DPDN import Net
from solver import test_func, get_logger
from dataset import TestDataset
from evaluation_utils import evaluate
def get_parser(argv=None):
    """Parse the command-line options of the pose-estimation test script.

    Despite its name, this returns the parsed arguments, not the parser.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as the
            script always did; an explicit list lets the function be called
            from other code or tests.

    Returns:
        argparse.Namespace with ``gpus`` (str), ``config`` (str),
        ``test_epoch`` (int), ``mask_label`` (bool) and ``only_eval`` (bool).
    """
    parser = argparse.ArgumentParser(
        description="Pose Estimation")
    # GPU ids as a string; the caller later splits it on "," so several ids
    # such as "0,1" can be given.
    parser.add_argument("--gpus",
                        type=str,
                        default="0",
                        help="gpu num")
    parser.add_argument("--config",
                        type=str,
                        default="config/supervised.yaml",
                        help="path to config file")
    # Selects which saved checkpoint / result folder of the experiment is used.
    parser.add_argument("--test_epoch",
                        type=int,
                        default=30,
                        help="test epoch")
    parser.add_argument('--mask_label', action='store_true', default=False,
                        help='whether having mask labels of real data')
    parser.add_argument('--only_eval', action='store_true', default=False,
                        help='whether directly evaluating the results')
    return parser.parse_args(argv)
def init():
    """Set up the configuration and logger for one test run.

    Parses the command line, loads the config file named by ``--config``,
    copies the command-line values onto it, restricts the visible CUDA devices
    to the requested ids and creates the logger through ``get_logger``.

    Returns:
        A ``(logger, cfg)`` tuple. ``cfg`` additionally carries ``exp_name``,
        ``log_dir``, ``gpus``, ``test_epoch``, ``mask_label`` and ``only_eval``.
    """
    args = get_parser()

    # Experiment name = last path component of the config path, cut at its
    # first dot: "config/supervised.yaml" -> "supervised".
    config_file = args.config.rsplit("/", 1)[-1]
    exp_name = config_file.split(".", 1)[0]
    log_dir = os.path.join("log", exp_name)

    cfg = gorilla.Config.fromfile(args.config)
    cfg.exp_name = exp_name
    cfg.log_dir = log_dir
    # Assigned after loading, so these command-line values override any
    # same-named keys from the config file.
    for key in ("gpus", "test_epoch", "mask_label", "only_eval"):
        setattr(cfg, key, getattr(args, key))

    gorilla.utils.set_cuda_visible_devices(gpu_ids=cfg.gpus)

    log_file = f"{log_dir}/test_epoch{cfg.test_epoch}_logger.log"
    logger = get_logger(
        level_print=logging.INFO,
        level_save=logging.WARNING,
        path_file=log_file,
    )
    return logger, cfg
if __name__ == "__main__":
logger, cfg = init()
logger.warning("************************ Start Logging ************************")
logger.info(cfg)
logger.info("using gpu: {}".format(cfg.gpus))
random.seed(cfg.rd_seed)
torch.manual_seed(cfg.rd_seed)
if cfg.setting == 'supervised':
save_path = os.path.join(cfg.log_dir, 'eval_epoch' + str(cfg.test_epoch))
setting = 'supervised'
else:
if cfg.mask_label:
save_path = os.path.join(cfg.log_dir, 'eval_withMaskLabel_epoch' + str(cfg.test_epoch))
setting = 'unsupervised_withMask'
else:
save_path = os.path.join(cfg.log_dir, 'eval_woMaskLabel_epoch' + str(cfg.test_epoch))
setting = 'unsupervised'
if not cfg.only_eval:
if not os.path.isdir(save_path):
os.mkdir(save_path)
# model
logger.info("=> creating model ...")
model = Net(cfg.num_category, cfg.num_prior)
if len(cfg.gpus)>1:
model = torch.nn.DataParallel(model, range(len(cfg.gpus.split(","))))
model = model.cuda()
checkpoint = os.path.join(cfg.log_dir, 'epoch_' + str(cfg.test_epoch) + '.pth')
logger.info("=> loading checkpoint from path: {} ...".format(checkpoint))
gorilla.solver.load_checkpoint(model=model, filename=checkpoint)
# data loader
dataset = TestDataset(cfg.test, BASE_DIR, setting)
dataloder = torch.utils.data.DataLoader(
dataset,
batch_size=1,
num_workers=8,
shuffle=False,
drop_last=False
)
test_func(model, dataloder, save_path)
evaluate(save_path, logger)