-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
115 lines (98 loc) · 3.76 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import argparse
import os
import random
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
import pretrainedmodels
from deepml import datasets
from deepml.models import CNNs
from deepml.utils import libs, runner
# MODIFY HERE!!!!
# Dataset name -> local root directory of the corresponding image data.
DATA_PATHS = dict(
    Cub='./data/cub_200_2011',
    Stanford='./data/stanford',
    Car='./data/cars196',
)
def main(args):
    """Evaluate a trained metric-learning model on a dataset's test split.

    Loads the checkpoint named by ``args.checkpoint`` into a CNN built from
    ``args.arch``, builds a test DataLoader for ``args.data``, and hands both
    to ``runner.test``.

    Args:
        args: argparse.Namespace with at least ``seed``, ``outdim``, ``arch``,
            ``pretrained``, ``normalized``, ``checkpoint``, ``data``,
            ``img_size``, ``batch_size`` and ``workers`` (see the CLI parser).

    Raises:
        ValueError: if ``args.checkpoint`` is not an existing file.
    """
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        # Also seed every CUDA device; torch.manual_seed alone does not cover
        # all GPUs, and this is a no-op on CPU-only machines.
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    # Fail fast: validate the checkpoint path before paying for model
    # construction (which may download pretrained weights).
    if not os.path.isfile(args.checkpoint):
        raise ValueError(
            "=> No checkpoint found at '{}'".format(args.checkpoint))

    # setup GPUs or CPUs
    num_device = torch.cuda.device_count()
    gpu_id = torch.cuda.current_device() if num_device > 0 else -1
    device = torch.device(('cuda:%d' % gpu_id) if gpu_id >= 0 else 'cpu')

    # build model
    model = CNNs(
        out_dim=args.outdim,
        arch=args.arch,
        pretrained=args.pretrained,
        normalized=args.normalized
    )
    # Backbones with a BGR input space (e.g. bninception) need the channel
    # order of input images inverted by the dataset transform.
    inverted = (model.base.input_space == 'BGR')

    # map_location keeps tensors on CPU so a GPU-trained checkpoint loads
    # on any machine; the model is moved to `device` afterwards.
    checkpoint = torch.load(
        args.checkpoint,
        map_location=lambda storage, loc: storage)
    print("=> Loaded checkpoint '{}'".format(args.checkpoint))
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device)

    # setup data set
    data_path = os.path.abspath(DATA_PATHS[args.data])
    data = datasets.__dict__[args.data](data_path)
    test_loader = DataLoader(
        data.get_dataset(
            ttype='test',
            inverted=inverted,
            transform=libs.get_data_augmentation(
                img_size=args.img_size,
                mean=model.base.mean,
                std=model.base.std,
                ttype='test'
            )
        ),
        batch_size=args.batch_size,
        shuffle=False,  # evaluation order must be stable
        num_workers=args.workers,
        pin_memory=num_device > 0  # pinned memory only helps with a GPU
    )

    args.device = device
    args.print_freq = 10

    # test the model
    runner.test(test_loader, model, args)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Deep metric learning')
    # NOTE: --data is required, so no default is set (a default would be
    # dead code under required=True).
    parser.add_argument('--data', required=True,
                        help='name of the dataset')
    parser.add_argument('-a', '--arch', metavar='ARCH', default='bninception',
                        choices=pretrainedmodels.model_names,
                        help='model architecture: ' +
                        ' | '.join(pretrainedmodels.model_names) +
                        ' (default: bninception)')
    parser.add_argument('-p', '--pretrained', metavar='PRET',
                        default='imagenet',
                        help='use pre-trained model')
    parser.add_argument('-c', '--checkpoint', type=str,
                        default='./output/model_best.pth.tar', metavar='PATH')
    parser.add_argument('-img_size', default=224, type=int,
                        help='image shape (default: 224)')
    # Help text now matches the actual default (was wrongly "(default: 8)").
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--outdim', default=512, type=int, metavar='N',
                        help='number of features')
    parser.add_argument('-b', '--batch-size', default=256, type=int,
                        metavar='N',
                        help='mini-batch size (default: 256)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--normalized', dest='normalized',
                        action='store_true', help='normalize the last layer')
    args = parser.parse_args()
    main(args)
    print('Test was done!')