evaluate.py
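"""Evaluate trained fairseq checkpoints on a chosen dataset split.

Summary (added; inferred from the code below): for each ``*.pt`` checkpoint in
``--save-dir``, the script rebuilds the task and model from the training
config, runs inference over ``--split``, and reports either ROC-AUC
(``--metric auc``) or MAE via the OGB ``PCQM4Mv2Evaluator`` (``--metric mae``).
For non-validation splits the MAE path writes an OGB test submission file next
to the checkpoint instead of evaluating directly.
"""
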
import logging
import os
import sys
from pathlib import Path

import numpy as np
import torch
from fairseq import options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from ogb.lsc import PCQM4Mv2Evaluator
from sklearn.metrics import roc_auc_score

# Make the repository root importable so custom tasks/models are registered.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


def eval(args, use_pretrained, checkpoint_path=None, logger=None):
    cfg = convert_namespace_to_omegaconf(args)
    np.random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

    # Initialize the task and build the model from the training config.
    task = tasks.setup_task(cfg.task)
    model = task.build_model(cfg.model)

    # Load checkpoint weights and move the model to the current CUDA device.
    model_state = torch.load(checkpoint_path)["model"]
    model.load_state_dict(model_state, strict=True, model_cfg=cfg.model)
    del model_state
    model.to(torch.cuda.current_device())

    # Load the requested dataset split and build a batch iterator over it.
    split = args.split
    task.load_dataset(split)
    batch_iterator = task.get_batch_iterator(
        dataset=task.dataset(split),
        max_tokens=cfg.dataset.max_tokens_valid,
        max_sentences=cfg.dataset.batch_size_valid,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            model.max_positions(),
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_workers=cfg.dataset.num_workers,
        epoch=0,
        data_buffer_size=cfg.dataset.data_buffer_size,
        disable_iterator_cache=False,
    )
    itr = batch_iterator.next_epoch_itr(shuffle=False, set_dataset_epoch=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # Run inference over the split, collecting predictions and targets.
    y_pred = []
    y_true = []
    with torch.no_grad():
        model.eval()
        for sample in progress:
            sample = utils.move_to_cuda(sample)
            # The graph-level prediction is read from the first output node.
            y = model(**sample["net_input"])[0][:, 0, :].reshape(-1)
            y_pred.extend(y.detach().cpu())
            # Targets can be padded; keep only as many as there are predictions.
            y_true.extend(sample["target"].detach().cpu().reshape(-1)[: y.shape[0]])
        torch.cuda.empty_cache()

    # Collect predictions and targets into tensors.
    y_pred = torch.Tensor(y_pred)
    y_true = torch.Tensor(y_true)

    # Compute the requested metric.
    if args.metric == "auc":
        auc = roc_auc_score(y_true, y_pred)
        logger.info(f"auc: {auc}")
        return auc
    elif args.metric == "mae":
        evaluator = PCQM4Mv2Evaluator()
        mae = np.nan
        if args.split == "valid":
            # Labels are available for the validation split: evaluate directly.
            input_dict = {"y_pred": y_pred, "y_true": y_true}
            result_dict = evaluator.eval(input_dict)
            logger.info(f"mae: {result_dict['mae']}")
            mae = result_dict["mae"]
        else:
            # Test labels are hidden: write an OGB submission file instead.
            input_dict = {"y_pred": y_pred}
            evaluator.save_test_submission(
                input_dict=input_dict,
                dir_path=os.path.dirname(checkpoint_path),
                mode=args.split,
            )
        # Dump raw predictions next to the checkpoint for later analysis.
        np.savez(
            str(checkpoint_path)[:-3] + f"_{args.split}_{mae:.5f}.npz",
            y_true=y_true.numpy(),
            y_pred=y_pred.numpy(),
        )
        return mae
    else:
        raise ValueError(f"Unsupported metric {args.metric}")


def main():
    parser = options.get_training_parser()
    parser.add_argument(
        "--split",
        type=str,
        help="dataset split to evaluate (e.g. valid)",
    )
    parser.add_argument(
        "--metric",
        type=str,
        help="evaluation metric: auc or mae",
    )
    args = options.parse_args_and_arch(parser, modify_parser=None)

    # Ensure logger output is visible when run as a standalone script.
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Evaluate every .pt checkpoint found in the save directory.
    for checkpoint_fname in os.listdir(args.save_dir):
        checkpoint_path = Path(args.save_dir) / checkpoint_fname
        if checkpoint_path.suffix == ".pt":
            logger.info(f"evaluating checkpoint file {checkpoint_path}")
            eval(args, False, checkpoint_path, logger)


if __name__ == "__main__":
    main()
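
# Example invocation (illustrative, not from the source): the task and
# architecture flags must match those used during training, since the config
# is rebuilt from the command line by fairseq's training parser.
#
#   python evaluate.py <training task/arch flags> \
#       --save-dir checkpoints/ --split valid --metric mae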