-
Notifications
You must be signed in to change notification settings - Fork 36
/
eval.py
60 lines (43 loc) · 1.62 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""
FFG-benchmarks
Copyright (c) 2021-present NAVER Corp.
MIT license
"""
import os
import argparse
import json
from pathlib import Path
from sconf import Config
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from evaluator.test import run, load_checkpoint
from evaluator.dataset import EvalTestDataset
from train_evaluator import transform
# Enable cuDNN's benchmark (autotuner) mode process-wide as soon as this module
# is imported: cuDNN times candidate convolution algorithms and reuses the
# fastest one for each input configuration. Per PyTorch's reproducibility
# notes, this can make algorithm selection (and thus results) vary between runs.
cudnn.benchmark = True
def _load_json(path):
    """Load and return the JSON document stored at *path*.

    The file is decoded as UTF-8 (the JSON interchange encoding) rather than
    with the platform's locale encoding. It is closed deterministically
    instead of being left for the garbage collector.
    """
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def setup_dset(cfg):
    """Replace the test key/char list file paths in *cfg* with their contents.

    Reads the JSON files named by ``cfg.dset.test.keylist`` and
    ``cfg.dset.test.charlist``. The loaded values are stored back into those
    same fields, and their lengths are recorded as ``cfg.n_styles`` and
    ``cfg.n_chars``.

    Args:
        cfg: configuration object supporting attribute access; it is mutated
            in place.

    Returns:
        The same ``cfg`` object, for call chaining.
    """
    keys = _load_json(cfg.dset.test.keylist)
    cfg.dset.test.keylist = keys
    # One key per style, so the key count is the number of styles.
    cfg.n_styles = len(keys)

    chars = _load_json(cfg.dset.test.charlist)
    cfg.dset.test.charlist = chars
    cfg.n_chars = len(chars)
    return cfg
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as True, so an explicit converter is required.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if lowered in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Evaluate a checkpoint on the test dataset and write the results as JSON.

    Steps:
    - Build a ``Config`` from the given config path(s) and load the test
      key/char lists via ``setup_dset``.
    - Run ``run`` over a ``DataLoader`` of ``EvalTestDataset`` using the
      models returned by ``load_checkpoint``.
    - Dump the returned dict to ``<result_dir>/<result_name>.json``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("config_path", nargs="+", help="path/to/config.yaml")
    # Required: without it Path(None) below would fail with a TypeError.
    parser.add_argument("--result_dir", required=True, help="path/to/save/result")
    # NOTE: if omitted, the result file is named "None.json".
    parser.add_argument("--result_name", help="Filename of result file")
    # Accepts "--verbose", "--verbose true", or "--verbose false"; type=bool
    # would have turned "--verbose False" into True.
    parser.add_argument("--verbose", type=_str2bool, nargs="?", const=True, default=True,
                        help="print progress (true/false)")
    parser.add_argument("--batch_size", type=int, default=32,
                        help="test DataLoader batch size")
    parser.add_argument("--num_workers", type=int, default=4,
                        help="number of DataLoader worker processes")
    # Unrecognized arguments are collected but are NOT applied to the config.
    args, _unknown_argv = parser.parse_known_args()

    cfg = Config(*args.config_path)
    cfg = setup_dset(cfg)

    result_dir = Path(args.result_dir)
    result_dir.mkdir(exist_ok=True, parents=True)

    test_dataset = EvalTestDataset(**cfg.dset.test, transform=transform)
    # Evaluation order is kept deterministic (no shuffling).
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size,
                                 shuffle=False, num_workers=args.num_workers)

    model_style, model_content = load_checkpoint(cfg)
    res_dict = run(test_dataloader, model_style, model_content, args.verbose)

    file_path = result_dir / f"{args.result_name}.json"
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(res_dict, f)
# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()