evaluate_model.py
import io
import os

import hydra
import torch
from hydra.core.config_store import ConfigStore

from utils.config import *
from utils.data_utils import *
from utils.modeling import *
from utils.train_utils import *

# Devices used for running the model and for storing tensors.
train_device = torch.device("cuda")
store_device = torch.device("cuda")

def define_model(input_size=776):
    # Build the recursive LSTM cost model and place it on the training device.
    print("Defining the model")
    model = Model_Recursive_LSTM_v2(
        input_size=input_size,
        comp_embed_layer_sizes=[600, 350, 200, 180],
        drops=[0.050] * 5,
        train_device="cuda:0",
        loops_tensor_size=20,
    ).to(train_device)
    return model

def evaluate(model, dataset_path):
    # The saved batch is a tuple of (validation dataset, baseline values, indices).
    print("Loading the dataset...")
    batch = torch.load(dataset_path)
    val_ds, val_bl, val_indices = batch
    print("Evaluation...")
    val_df = get_results_df(val_ds, val_bl, val_indices, model, train_device="cpu")
    val_scores = get_scores(val_df)
    # Return the mean of each ranking/error metric over the evaluated programs.
    return dict(
        zip(
            ["nDCG", "nDCG@5", "nDCG@1", "Spearman_ranking_correlation", "MAPE"],
            [item for item in val_scores.describe().iloc[1, 1:6].to_numpy()],
        )
    )

@hydra.main(config_path="conf", config_name="config")
def main(conf):
    # Restore the trained weights from the checkpoint named in the config.
    model = define_model(input_size=776)
    model.load_state_dict(
        torch.load(
            os.path.join(
                conf.experiment.base_path,
                "weights/",
                conf.testing.checkpoint,
            ),
            map_location=train_device,
        )
    )
    # Evaluate on each requested split (only "valid" and "bench" are handled).
    for dataset in conf.testing.datasets:
        if dataset in ["valid", "bench"]:
            print(f"getting results for {dataset}")
            dataset_path = os.path.join(
                conf.experiment.base_path,
                f"dataset/{dataset}",
                f"{conf.data_generation.dataset_name}.pt",
            )
            scores = evaluate(model, dataset_path)
            print(scores)

if __name__ == "__main__":
    cs = ConfigStore.instance()
    cs.store(name="experiment_config", node=RecursiveLSTMConfig)
    main()
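
# Example invocation (a minimal sketch): the Hydra config under conf/ is assumed
# to expose the experiment.base_path, testing.checkpoint, testing.datasets, and
# data_generation.dataset_name keys referenced above; the checkpoint file name
# below is hypothetical.
#
#   python evaluate_model.py testing.checkpoint=best_model.pt testing.datasets='[valid,bench]'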