-
Notifications
You must be signed in to change notification settings - Fork 3
/
eval.py
43 lines (33 loc) · 1.22 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""
The main eval code
"""
import hydra
import torch
from evals.load_evaluators import load_evaluator
from models.build_models import build_model
@hydra.main(config_path="configs", config_name="test")
def main(cfg):
    """Build a model, run the configured evaluator on it, and write the results.

    The model is restored from ``cfg.model_ckpt`` when that key is set to a
    non-null value; otherwise it is built from scratch from ``cfg.model``
    (e.g. for external pretrained models). The evaluator named by
    ``cfg.testing.evaluator_name`` is then run over ``cfg.testing.benchmarks``
    and ``str(results)`` is written to ``cfg.output_path``.

    Args:
        cfg: Hydra-composed config (``configs/test`` by default). Keys read:
            ``model_ckpt`` (optional), ``model`` (only when no checkpoint is
            given), ``testing.benchmarks``, ``testing.evaluator_name`` and
            ``output_path``.
    """
    # Treat an absent key and an explicit null the same way ("no checkpoint")
    # so a null value never reaches to_absolute_path().
    ckpt_path = cfg.get("model_ckpt")
    if ckpt_path is not None:
        # Resolve relative to the original launch directory rather than
        # whatever working directory Hydra is using for this run.
        cfg["model_ckpt"] = hydra.utils.to_absolute_path(ckpt_path)
        # NOTE(review): no map_location is given, so tensors are restored onto
        # the device they were saved from; a GPU-saved checkpoint may fail to
        # load on a CPU-only host. Confirm whether build_model handles this.
        model = build_model(checkpoint=torch.load(cfg["model_ckpt"]))
    else:
        # No checkpoint: build the model from its config instead.
        model = build_model(model_cfg=cfg["model"])
    model.eval()

    # Benchmark names are coerced to str, presumably because the YAML config
    # can yield non-string entries (e.g. numeric names) — TODO confirm.
    benchmark_names = [str(name) for name in cfg["testing"]["benchmarks"]]
    evaluator = load_evaluator(
        evaluator_name=cfg["testing"]["evaluator_name"],
        model=model,
        benchmarks=benchmark_names,
    )

    results = evaluator.evaluate()

    # NOTE(review): unlike model_ckpt, output_path is not passed through
    # to_absolute_path, so a relative path resolves against the current working
    # directory at this point (under Hydra this may be the run directory).
    # Confirm that is intended.
    with open(cfg["output_path"], "w", encoding="utf-8") as f:
        f.write(str(results))
if __name__ == "__main__":
# pylint: disable=no-value-for-parameter
main()
# pylint: enable=no-value-for-parameter