forked from mynlp/cst_captioning
-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
executable file
·77 lines (58 loc) · 1.99 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import argparse
import torch
import numpy as np
import os
import sys
import time
import math
import json
import logging
from datetime import datetime
from dataloader import DataLoader
from model import CaptionModel, CrossEntropyCriterion
from train import test
import utils
import opts
logger = logging.getLogger(__name__)


def apply_checkpoint_opts(opt, checkpoint_opt):
    """Copy the model-defining options stored in a checkpoint onto *opt*.

    ``CaptionModel(opt)`` is built from *opt*, so these settings must match
    the ones the saved weights were trained with.

    Args:
        opt: Options namespace for this run; modified in place.
        checkpoint_opt: Options object stored under ``checkpoint['opt']``.
    """
    for name in ('model_type', 'vocab', 'vocab_size', 'seq_length',
                 'feat_dims'):
        setattr(opt, name, getattr(checkpoint_opt, name))


def check_loader_compat(opt, loader):
    """Verify that *loader*'s data matches the model options on *opt*.

    Compares vocabulary size, sequence length and feature dimensions, in that
    order, stopping at the first mismatch. Uses an explicit exception instead
    of ``assert`` so the check still runs under ``python -O``.

    Args:
        opt: Options carrying ``vocab_size``, ``seq_length`` and ``feat_dims``.
        loader: Object exposing ``get_vocab_size()``, ``get_seq_length()`` and
            ``get_feat_dims()``.

    Raises:
        ValueError: if any of the compared values differ.
    """
    checks = (
        ('vocab_size', opt.vocab_size, loader.get_vocab_size),
        ('seq_length', opt.seq_length, loader.get_seq_length),
        ('feat_dims', opt.feat_dims, loader.get_feat_dims),
    )
    for name, expected, getter in checks:
        actual = getter()  # queried lazily, as the original asserts were
        if expected != actual:
            raise ValueError(
                'Checkpoint/test data mismatch for %s: model has %r, '
                'data has %r' % (name, expected, actual))


def main():
    """Evaluate a saved caption model on the test split given on the CLI."""
    opt = opts.parse_opts()
    logging.basicConfig(level=getattr(logging, opt.loglevel.upper()),
                        format='%(asctime)s:%(levelname)s: %(message)s')
    logger.info(
        'Input arguments: %s',
        json.dumps(
            vars(opt),
            sort_keys=True,
            indent=4))

    start = datetime.now()

    test_opt = {'label_h5': opt.test_label_h5,
                'batch_size': opt.test_batch_size,
                'feat_h5': opt.test_feat_h5,
                'cocofmt_file': opt.test_cocofmt_file,
                'seq_per_img': opt.test_seq_per_img,
                'num_chunks': opt.num_chunks,
                'mode': 'test'
                }
    test_loader = DataLoader(test_opt)
    try:
        logger.info('Loading model: %s', opt.model_file)
        # NOTE(review): torch.load unpickles arbitrary Python objects (the
        # checkpoint stores an options object under 'opt'); only load
        # checkpoint files from trusted sources.
        # map_location='cpu': load_state_dict copies the tensors into the
        # model's own parameters and the model is moved to the GPU afterwards,
        # so the device the checkpoint was saved on need not exist here.
        checkpoint = torch.load(opt.model_file, map_location='cpu')
        apply_checkpoint_opts(opt, checkpoint['opt'])
        check_loader_compat(opt, test_loader)

        logger.info('Building model...')
        model = CaptionModel(opt)
        logger.info('Loading state from the checkpoint...')
        model.load_state_dict(checkpoint['model'])
        xe_criterion = CrossEntropyCriterion()
        model.cuda()
        xe_criterion.cuda()

        logger.info('Start testing...')
        test(model, xe_criterion, test_loader, opt)
        logger.info('Time: %s', datetime.now() - start)
    finally:
        # Always close the loader, even if loading, validation or testing
        # raised an exception.
        test_loader.close()


if __name__ == '__main__':
    main()