-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_kernelshap.py
121 lines (100 loc) · 4.03 KB
/
generate_kernelshap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
from argparse import ArgumentParser
from pathlib import Path
import numpy as np
from captum.attr import KernelShap
from kernelshap_utils import (
B50_PATHS,
DEVICE,
Wrapper,
create_cubical_mask,
load_model,
output_segmentation,
read_path,
set_seeds,
)
from tqdm import tqdm
def _parse_args():
    """Parse and return the command-line arguments of this script."""
    parser = ArgumentParser(
        description="Compute KernelSHAP attributions for each scan of a dataset."
    )
    parser.add_argument("--dataset", type=str, choices=["b50"], default="b50")
    parser.add_argument(
        "--partition",
        type=int,
        help="Partition index to process (falls back to $SLURM_ARRAY_TASK_ID).",
    )
    parser.add_argument(
        "--num_partitions",
        type=int,
        help="Total number of partitions; required whenever a partition is used.",
    )
    parser.add_argument("--save_folder", type=Path, required=True)
    parser.add_argument("--predictions_save_folder", type=Path, default=None)
    parser.add_argument("--skip_already_processed", action="store_true")
    parser.add_argument("--num_components", type=int, default=512)
    parser.add_argument("--targets", nargs="+", type=int, required=True)
    # Sampling budgets that used to be hard-coded; the defaults reproduce the
    # previous behaviour exactly.
    parser.add_argument("--n_samples_cubes", type=int, default=1000)
    parser.add_argument("--n_samples_segmentations", type=int, default=200)
    parser.add_argument("--perturbations_per_eval", type=int, default=1)
    return parser.parse_args()


def _resolve_partition(partition, num_partitions):
    """Return the validated partition index, or None when not partitioning.

    The index comes from ``--partition`` or, if that is absent, from the
    ``SLURM_ARRAY_TASK_ID`` environment variable.

    Raises:
        ValueError: a partition is set without a positive ``num_partitions``,
            or the partition index is outside ``[0, num_partitions)``.
    """
    if partition is None:
        env_partition = os.environ.get("SLURM_ARRAY_TASK_ID")
        if env_partition is not None:
            partition = int(env_partition)
    if partition is None:
        return None
    if num_partitions is None:
        raise ValueError("Need to specify num_partitions when specifying partition")
    if num_partitions <= 0:
        raise ValueError(f"num_partitions must be positive, got {num_partitions}")
    if not 0 <= partition < num_partitions:
        # Previously an out-of-range index silently produced an empty run.
        raise ValueError(
            f"partition {partition} is out of range for "
            f"num_partitions={num_partitions}"
        )
    return partition


def _select_scans(dataset, partition, num_partitions):
    """Return the scan paths of ``dataset`` assigned to ``partition``.

    All scans are returned when ``partition`` is None.

    Raises:
        ValueError: ``dataset`` is not supported.
    """
    if dataset == "b50":
        scans = B50_PATHS
    else:
        raise ValueError(f"dataset {dataset} not supported")
    if partition is None:
        return scans
    # Round-robin split: partition p gets scans p, p + N, p + 2N, ...
    # With num_partitions == len(scans) this is exactly one scan per partition,
    # i.e. the same result as the former `scans[partition : partition + 1]`,
    # but smaller num_partitions no longer leave scans unprocessed.
    return scans[partition::num_partitions]


def _pending_outputs(save_path, targets, skip_already_processed):
    """List ``(target, kind, out_file)`` attribution outputs still to compute.

    ``kind`` is ``"cubes"`` or ``"segmentations"``, ordered per target exactly
    as they are computed. Without ``skip_already_processed`` every output is
    pending, so existing files are overwritten.
    """
    pending = []
    for target in targets:
        for kind in ("cubes", "segmentations"):
            out_file = save_path / f"kernelshap_{kind}_{target}.npy"
            if not skip_already_processed or not out_file.exists():
                pending.append((target, kind, out_file))
    return pending


def _attribute_and_save(
    ks_wrapper, image, feature_mask, target, n_samples, perturbations_per_eval, out_file
):
    """Run KernelSHAP for one target / feature-mask pair and save it as .npy."""
    attributions = ks_wrapper.attribute(
        image,
        feature_mask=feature_mask,
        perturbations_per_eval=perturbations_per_eval,
        target=target,
        n_samples=n_samples,
        show_progress=True,
    )
    np.save(out_file, attributions.cpu().numpy())


def main():
    """Compute and save KernelSHAP attributions for every selected scan.

    For each scan and each requested target, two attribution arrays are saved
    under ``<save_folder>/<last three '/'-separated components of the scan path>/``:

    * ``kernelshap_cubes_<target>.npy``: the feature mask comes from
      ``create_cubical_mask``.
    * ``kernelshap_segmentations_<target>.npy``: the feature mask is the
      model's output segmentation.

    If ``--predictions_save_folder`` is given, that output segmentation is also
    saved there as ``output_segmentation.npy`` (once; never overwritten).

    Raises:
        ValueError: unsupported dataset, a partition without a valid
            ``--num_partitions``, or an out-of-range partition index.
    """
    args = _parse_args()
    partition = _resolve_partition(args.partition, args.num_partitions)
    dataset_scans = _select_scans(args.dataset, partition, args.num_partitions)

    print(f"Device: {DEVICE}")
    set_seeds()
    model, test_transform = load_model()

    for path in tqdm(dataset_scans):
        # NOTE(review): assumes scan paths are '/'-separated strings.
        patient_path = "/".join(path.split("/")[-3:])
        save_path = args.save_folder / patient_path
        predictions_file = None
        if args.predictions_save_folder is not None:
            predictions_file = (
                args.predictions_save_folder / patient_path / "output_segmentation.npy"
            )

        pending = _pending_outputs(save_path, args.targets, args.skip_already_processed)
        if not pending and (predictions_file is None or predictions_file.exists()):
            # Every output for this scan already exists: skip the expensive
            # image loading, masking and segmentation entirely.
            continue

        save_path.mkdir(parents=True, exist_ok=True)
        image_path = read_path(path)
        try:
            loaded_img = test_transform(image_path).to(DEVICE).unsqueeze(0)
        finally:
            # The file returned by read_path is deleted after loading, so it is
            # presumably a temporary local copy (TODO confirm). `finally` makes
            # sure it is removed even if the transform fails.
            os.remove(image_path)

        cube_mask = create_cubical_mask(
            loaded_img, num_components=args.num_components
        ).to(DEVICE)
        out_max = output_segmentation(model, loaded_img).to(DEVICE)
        ks_wrapper = KernelShap(Wrapper(model, out_max).wrapper_classes_intersection)

        if predictions_file is not None and not predictions_file.exists():
            predictions_file.parent.mkdir(parents=True, exist_ok=True)
            np.save(predictions_file, out_max.cpu().numpy())

        # Feature mask and sampling budget used for each kind of attribution.
        feature_masks = {
            "cubes": (cube_mask, args.n_samples_cubes),
            "segmentations": (out_max, args.n_samples_segmentations),
        }
        for target, kind, out_file in pending:
            # Re-check so a duplicated --targets entry is not recomputed when
            # skipping already-processed outputs.
            if args.skip_already_processed and out_file.exists():
                continue
            feature_mask, n_samples = feature_masks[kind]
            _attribute_and_save(
                ks_wrapper,
                loaded_img,
                feature_mask,
                target,
                n_samples,
                args.perturbations_per_eval,
                out_file,
            )
if __name__ == "__main__":
    # Start the computation only when run as a script; importing this module
    # just defines the functions.
    main()