-
Notifications
You must be signed in to change notification settings - Fork 0
/
val_pvnet.py
103 lines (79 loc) · 3.68 KB
/
val_pvnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# -*- coding: utf-8 -*-
import config
import torch
import os.path as osp
from utils import meter
from torch import nn
from torch import optim
from models import PVNet2, PVNet2_v4, PVNet2_v5, PVNet2_v7, PVNet2_v9
from torch.utils.data import DataLoader
from datasets import *
def _mean_class_accuracy(right_counts, seen_counts):
    """Return the mean per-class accuracy over classes seen at least once.

    Classes with no samples yet are skipped rather than contributing a 0/0
    NaN that would poison the whole mean. Returns NaN only when no class has
    been seen at all.
    """
    import numpy as np  # explicit import; the module otherwise gets np only via a star import

    right = np.asarray(right_counts, dtype=np.float64)
    seen = np.asarray(seen_counts, dtype=np.float64)
    seen_mask = seen > 0
    if not seen_mask.any():
        return float('nan')
    return float(np.mean(right[seen_mask] / seen[seen_mask]))


def validate(val_loader, net, epoch, print_pr=False, num_classes=40):
    """
    Run one evaluation pass over the validation set and report metrics.

    Args:
        val_loader: iterable yielding ``(views, pcs, labels)`` batches.
        net: model called as ``net(pcs, views, get_fea=True)`` that returns
            ``(preds, fts)``; the predicted class is the argmax of ``preds``
            along dim 1 and ``fts`` is L2-normalised along dim 1 before being
            fed to the retrieval meter.
        epoch: epoch number, used only in log messages.
        print_pr: if True, also print the retrieval precision/recall output.
        num_classes: number of label classes tracked for per-class accuracy
            (default 40, the value previously hard-coded).

    Returns:
        ``(prec.value(1), mAP)``: top-1 instance accuracy from
        ``meter.ClassErrorMeter`` and the retrieval mAP from
        ``meter.RetrievalMAPMeter``.
    """
    batch_time = meter.TimeMeter(True)
    data_time = meter.TimeMeter(True)
    prec = meter.ClassErrorMeter(topk=[1], accuracy=True)
    retrieval_map = meter.RetrievalMAPMeter()
    # testing mode
    net.eval()
    total_seen_class = [0] * num_classes
    total_right_class = [0] * num_classes

    for i, (views, pcs, labels) in enumerate(val_loader):
        batch_time.reset()
        views = views.to(device=config.device)
        pcs = pcs.to(device=config.device)
        labels = labels.to(device=config.device)

        preds, fts = net(pcs, views, get_fea=True)
        prec.add(preds.data, labels.data)

        feats = fts.detach()
        retrieval_map.add(feats / torch.norm(feats, 2, 1, True), labels.detach())

        # Per-class counts: compute the argmax once per batch with torch (works
        # on any device) instead of once per sample through NumPy.
        # NOTE(review): assumes labels is a 1-D integer tensor of class indices.
        hits = (preds.detach().argmax(dim=1) == labels).tolist()
        for label, hit in zip(labels.tolist(), hits):
            total_seen_class[label] += 1
            total_right_class[label] += int(hit)

        if i % config.print_freq == 0:
            print(f'Epoch: [{epoch}][{i}/{len(val_loader)}]\t'
                  f'Batch Time {batch_time.value():.3f}\t'
                  f'Epoch Time {data_time.value():.3f}\t'
                  f'Prec@1 {prec.value(1):.3f}\t'
                  f'Mean Class accuracy {_mean_class_accuracy(total_right_class, total_seen_class):.3f}')

    mAP = retrieval_map.mAP()
    mean_class_acc = _mean_class_accuracy(total_right_class, total_seen_class)
    print(f' instance accuracy at epoch {epoch}: {prec.value(1)} ')
    print(f' mean class accuracy at epoch {epoch}: {mean_class_acc} ')
    print(f' map at epoch {epoch}: {mAP} ')
    if print_pr:
        print(f'pr: {retrieval_map.pr()}')
    return prec.value(1), mAP
def main():
    """
    Load the PVNet checkpoint at ``config.pv_net.ckpt_file`` and evaluate it
    once on the test split.

    Prints the accuracy measured now next to the ``best_prec1`` value that was
    stored in the checkpoint.
    """
    print('Training Process\nInitializing...\n')
    config.init_env()
    val_dataset = pc_view_data(config.pv_net.pc_root,
                               config.pv_net.view_root,
                               status=STATUS_TEST,
                               base_model_name=config.base_model_name,
                               view_idx=config.pv_net.view_idx,
                               pc_input_num=config.pv_net.pc_input_num)
    # Evaluation metrics do not depend on sample order; a fixed order keeps
    # runs and their periodic progress lines reproducible.
    val_loader = DataLoader(val_dataset, batch_size=config.pv_net.train.batch_sz,
                            num_workers=config.num_workers, shuffle=False)

    # create model
    net = PVNet2_v9(pc_input_num=config.pv_net.pc_input_num)
    net = torch.nn.DataParallel(net)
    net = net.to(device=config.device)

    print(f'loading pretrained model from {config.pv_net.ckpt_file}')
    # map_location lets a checkpoint saved on a different device load here.
    checkpoint = torch.load(config.pv_net.ckpt_file, map_location=config.device)
    # Weights are loaded into the unwrapped model (net.module), so the saved
    # keys are expected without DataParallel's 'module.' prefix.
    net.module.load_state_dict(checkpoint['model'])
    # Optimizer state is intentionally not restored: evaluation never steps it.
    best_prec1 = checkpoint['best_prec1']
    resume_epoch = checkpoint['epoch']

    with torch.no_grad():
        prec1, _ = validate(val_loader, net, resume_epoch, print_pr=False)

    print('curr accuracy: ', prec1)
    print('best accuracy: ', best_prec1)


if __name__ == '__main__':
    main()