-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_best_model.py
57 lines (44 loc) · 1.76 KB
/
test_best_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
import neptune.new as neptune
import torch
from load_data import load_data
from model import Net
def _find_best_run_id(project_name, owner):
    """Return the Neptune ``sys/id`` of the trial run with the lowest validation loss.

    Only runs in *project_name* owned by *owner* and tagged ``"trial_run"`` are
    considered.

    Raises:
        ValueError: if no candidate run has a recorded validation loss.
    """
    loss_column = "trial/metrics/valid/epoch/loss"
    # (neptune) fetch project
    project = neptune.get_project(name=project_name)
    runs_df = project.fetch_runs_table(owner=owner, tag="trial_run").to_pandas()
    if loss_column not in runs_df.columns:
        raise ValueError(
            f"no '{loss_column}' found for trial runs of owner {owner!r} in {project_name!r}"
        )
    # Runs without a recorded loss cannot be ranked; drop them instead of letting
    # a NaN row be picked when no run has a loss.
    candidates = runs_df.dropna(subset=[loss_column])
    if candidates.empty:
        raise ValueError(
            f"no trial run of owner {owner!r} in {project_name!r} has a recorded '{loss_column}'"
        )
    # (neptune) find best trial
    return candidates.sort_values(by=loss_column)["sys/id"].iloc[0]


def _download_checkpoint(project_name, run_id, checkpoint_path):
    """Download the ``trial/checkpoint`` field of run *run_id* to *checkpoint_path*."""
    # (neptune) resume this run
    best_run = neptune.init(
        project=project_name,
        run=run_id,
        mode="read-only",
    )
    try:
        # (neptune) download model
        best_run["trial/checkpoint"].download(checkpoint_path)
    finally:
        # (neptune) close run -- even if the download raised.
        best_run.stop()


def _compute_accuracy(model, data_loader, device):
    """Return the top-1 accuracy of *model* over every batch of *data_loader*.

    Raises:
        ValueError: if *data_loader* yields no samples.
    """
    # Put dropout / batch-norm style layers into inference behaviour for evaluation.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            # Predicted class = index of the largest output along dim 1.
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("test set is empty; cannot compute accuracy")
    return correct / total


def test_best_model(
    run,
    best_trial,
    *,
    project_name="common/project-hpo-with-ray-tune",
    owner="kamil",
    data_dir="/abs/path/to/data",
    checkpoint_path="./best_checkpoint",
):
    """Evaluate the best trial's checkpoint on the test set and log its accuracy.

    The network is built from ``best_trial.config["l1"]`` / ``["l2"]``. Its weights
    come from the ``trial/checkpoint`` of the lowest-validation-loss Neptune run.
    The resulting accuracy is written to ``run["best/test_accuracy"]``.

    Args:
        run: Neptune run that receives the ``best/test_accuracy`` value.
        best_trial: object exposing a ``config`` mapping with ``"l1"`` and ``"l2"``
            (presumably a Ray Tune trial -- TODO confirm against caller).
        project_name: Neptune project holding the trial runs.
        owner: Neptune owner used to filter the trial runs.
        data_dir: dataset location passed (made absolute) to ``load_data``.
        checkpoint_path: local file the checkpoint is downloaded to.

    Raises:
        ValueError: if no ranked trial run exists or the test set is empty.
    """
    best_trained_model = Net(best_trial.config["l1"], best_trial.config["l2"])
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    best_trained_model.to(device)

    # NOTE(review): the architecture comes from `best_trial`, while the weights come
    # from the lowest-loss Neptune run. These are assumed to be the same trial; if
    # they differ, load_state_dict may fail or load mismatched weights -- confirm.
    best_run_id = _find_best_run_id(project_name, owner)
    _download_checkpoint(project_name, best_run_id, checkpoint_path)

    # The checkpoint holds a (model_state, optimizer_state) pair; only the model
    # state is needed here. map_location lets a GPU-saved checkpoint load on CPU.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model_state, _optimizer_state = torch.load(checkpoint_path, map_location=device)
    best_trained_model.load_state_dict(model_state)

    _, test_set, _ = load_data(os.path.abspath(data_dir))
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=4, shuffle=False, num_workers=4
    )

    # (neptune) log test accuracy
    run["best/test_accuracy"] = _compute_accuracy(best_trained_model, test_loader, device)