-
Notifications
You must be signed in to change notification settings - Fork 0
/
ci_evaluate.py
62 lines (50 loc) · 1.83 KB
/
ci_evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import neptune
import torch
from dataset import BrainSegmentationDataset
from model_utils import DiceLoss, UNet
# (neptune) fetch project
project = neptune.init_project(project="common/project-images-segmentation")
# (neptune) find best run: read the runs table filtered to runs tagged "best"
# and take the id of the first row (if several runs carry the tag, the row
# order returned by the table decides which one is used).
best_run_df = project.fetch_runs_table(tag="best").to_pandas()
if best_run_df.empty:
    # Fail with an explicit message rather than an opaque IndexError from `.values[0]`.
    raise RuntimeError(
        'No run tagged "best" found in project common/project-images-segmentation'
    )
best_run_id = best_run_df["sys/id"].values[0]
# The project handle is not used after this point; close its connection.
project.stop()
# (neptune) re-open the chosen run (by id) so evaluation metadata is logged
# under the "evaluate" namespace of that same run.
base_namespace = "evaluate"
ref_run = neptune.init_run(
    project="common/project-images-segmentation",
    tags=["evaluation"],
    source_files=None,
    monitoring_namespace=f"{base_namespace}/monitoring",
    with_id=best_run_id,
)
# (neptune) record the evaluation data location as a tracked artifact,
# then download the tracked files into ./evaluation_data.
ref_run[f"{base_namespace}/validation_data_version"].track_files(
    "s3://neptune-examples/data/brain-mri-dataset/evaluation/TCGA_HT_7692_19960724/"
)
# Logging calls are queued and synced asynchronously; wait for the queue to
# drain so download() does not read the artifact field before the server has it.
ref_run.wait()
ref_run[f"{base_namespace}/validation_data_version"].download(destination="evaluation_data")
# Validation split built from the files downloaded into "evaluation_data".
# Sampling is non-random, and the seed is read from the reference run's
# data/preprocessing_params/seed field (presumably so preprocessing matches
# the training run — confirm against the dataset implementation).
valid = BrainSegmentationDataset(
    seed=ref_run["data/preprocessing_params/seed"].fetch(),
    subset="validation",
    images_dir="evaluation_data",
    random_sampling=False,
)
# Pick the compute device: CUDA when available, otherwise CPU.
# No command-line arguments are parsed in this script, so the device is chosen
# directly here rather than from an `args` object.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet = UNet(
    in_channels=BrainSegmentationDataset.in_channels,
    out_channels=BrainSegmentationDataset.out_channels,
)
unet.to(device)
# (neptune) Download the weights stored on the reference run under
# training/model/model_weight into a local file.
ref_run["training/model/model_weight"].download("evaluate_unet.pt")
ref_run.wait()
# Load the downloaded weights, mapping tensors onto the selected device.
state_dict = torch.load("evaluate_unet.pt", map_location=device)
unet.load_state_dict(state_dict)
# Evaluate the model on every validation sample and log the mean Dice loss
# to the reference run; the run is stopped afterwards even if evaluation fails.
loss_fn = DiceLoss()
loss = 0.0
n_samples = len(valid)
try:
    if n_samples == 0:
        # Avoid a ZeroDivisionError when computing the mean below.
        raise RuntimeError("Validation dataset is empty; nothing to evaluate.")
    # Switch layers such as dropout / batch-norm (if UNet has any) to inference mode.
    unet.eval()
    with torch.no_grad():
        for i in range(n_samples):
            x, y, _fname = valid[i]
            # Add a batch dimension and move inputs to the same device as the model.
            x = x.unsqueeze(0).to(device)
            y = y.unsqueeze(0).to(device)
            y_pred = unet(x)
            loss += loss_fn(y_pred, y).item()
    # (neptune) log evaluated loss.
    ref_run[f"{base_namespace}/mean_evaluation_loss"] = loss / n_samples
finally:
    # Flush queued metadata and close the connection to the run.
    ref_run.stop()