-
Notifications
You must be signed in to change notification settings - Fork 2
/
test_specific_data.py
181 lines (156 loc) · 8.83 KB
/
test_specific_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import os

from batchgenerators.utilities.file_and_folder_operations import *

from nnunet.run.default_configuration import get_default_configuration_with_multiTask
from nnunet.paths import default_plans_identifier
from nnunet.training.cascade_stuff.predict_next_stage import predict_next_stage
from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
from nnunet.training.network_training.nnUNetTrainerCascadeFullRes import nnUNetTrainerCascadeFullRes
from nnunet.training.network_training.nnUNetTrainerV2_CascadeFullRes import nnUNetTrainerV2CascadeFullRes
from nnunet.utilities.task_name_id_conversion import convert_id_to_task_name
from nnunet.paths import *
# Task ids trained jointly when ``--tasks`` is not given on the command line.
DEFAULT_TASKS = ("100", "101", "102", "103", "104")


def parse_force_separate_z(value):
    """Convert the ``--force_separate_z`` CLI string into ``None``, ``False`` or ``True``.

    Args:
        value: one of the exact strings ``"None"``, ``"False"`` or ``"True"``.

    Returns:
        The matching Python value.

    Raises:
        ValueError: if ``value`` is any other string.
    """
    choices = {"None": None, "False": False, "True": True}
    if value not in choices:
        raise ValueError(
            "force_separate_z must be None, True or False. Given: %s" % value)
    return choices[value]


def collect_task_classes(tasks, preprocessing_dir):
    """Resolve task identifiers and read the non-zero label names of each task.

    Args:
        tasks: iterable of task identifiers. Entries starting with "Task" are
            used as-is. Any other entry is parsed as an integer id and
            resolved with ``convert_id_to_task_name``.
        preprocessing_dir: directory that contains one sub-folder per task
            name, each holding a ``dataset.json`` with a ``"labels"`` mapping.

    Returns:
        A tuple ``(task_names, classes_dict)``.
        - ``task_names`` lists the resolved task names in input order.
        - ``classes_dict`` maps each task name to the label names whose key is
          not 0 (label 0 is conventionally background). Names appear in
          ``dataset.json`` order.
    """
    task_names = []
    classes_dict = {}
    for task in tasks:
        if not task.startswith("Task"):
            task = convert_id_to_task_name(int(task))
        task_names.append(task)
        json_file = os.path.join(preprocessing_dir, task, "dataset.json")
        # Explicit encoding so reading does not depend on the platform's default locale.
        with open(json_file, encoding="utf-8") as jsn:
            labels = json.load(jsn)['labels']
        classes_dict[task] = [name for key, name in labels.items() if int(key) != 0]
    return task_names, classes_dict


def main():
    """Command-line entry point: train and/or validate a multi-task nnU-Net trainer.

    Steps:
    - Parse the CLI arguments.
    - Export the ``gpu`` argument as ``CUDA_VISIBLE_DEVICES``.
    - Read the non-background label names of every task from its
      preprocessed ``dataset.json``.
    - Build an ``nnUNetMultiTrainerV2`` trainer over all tasks.
    - Then do one of:
      - run the learning-rate finder (``--find_lr``), or
      - train, optionally resuming from the latest checkpoint, and validate, or
      - load the best/latest checkpoint and validate only (``-val``).

    Raises:
        RuntimeError: if the trainer class cannot be found.
        ValueError: if ``fold`` is neither 'all' nor an integer, or
            ``--force_separate_z`` is not None/True/False.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("network")
    parser.add_argument("fold", help='0, 1, ..., 5 or \'all\'')
    # The value is exported verbatim as CUDA_VISIBLE_DEVICES, so it must be a device list.
    parser.add_argument("gpu", help="GPU id(s) exported as CUDA_VISIBLE_DEVICES, e.g. 0 or 0,1")
    parser.add_argument("-val", "--validation_only", help="use this if you want to only run the validation",
                        action="store_true")
    parser.add_argument("-c", "--continue_training", help="use this if you want to continue a training",
                        action="store_true")
    parser.add_argument("-p", help="plans identifier. Only change this if you created a custom experiment planner",
                        default=default_plans_identifier, required=False)
    parser.add_argument("--use_compressed_data", default=False, action="store_true",
                        help="If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data "
                             "is much more CPU and RAM intensive and should only be used if you know what you are "
                             "doing", required=False)
    parser.add_argument("--deterministic",
                        help="Makes training deterministic, but reduces training speed substantially. I (Fabian) think "
                             "this is not necessary. Deterministic training will make you overfit to some random seed. "
                             "Don't use that.",
                        required=False, default=False, action="store_true")
    parser.add_argument("--npz", required=False, default=False, action="store_true",
                        help="if set then nnUNet will export npz files of predicted segmentations in the validation "
                             "as well. This is needed to run the ensembling step so unless you are developing nnUNet "
                             "you should enable this")
    parser.add_argument("--find_lr", required=False, default=False, action="store_true",
                        help="not used here, just for fun")
    parser.add_argument("--valbest", required=False, default=False, action="store_true",
                        help="hands off. This is not intended to be used")
    parser.add_argument("--fp32", required=False, default=False, action="store_true",
                        help="disable mixed precision training and run old school fp32")
    parser.add_argument("--val_folder", required=False, default="validation_raw",
                        help="name of the validation folder. No need to use this for most people")
    parser.add_argument("--interp_order", required=False, default=3, type=int,
                        help="order of interpolation for segmentations. Testing purpose only. Hands off")
    parser.add_argument("--interp_order_z", required=False, default=0, type=int,
                        help="order of interpolation along z if z is resampled separately. Testing purpose only. "
                             "Hands off")
    parser.add_argument("--force_separate_z", required=False, default="None", type=str,
                        help="force_separate_z resampling. Can be None, True or False. Testing purpose only. Hands off")
    # Optional, so existing invocations keep training on the original hard-coded task set.
    parser.add_argument("--tasks", nargs="+", default=DEFAULT_TASKS,
                        help="task ids (e.g. 100) or full task names starting with 'Task' to train jointly. "
                             "Default: " + " ".join(DEFAULT_TASKS))
    args = parser.parse_args()

    # Restrict the GPU(s) visible to this process before the trainer is built.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    network = args.network
    # The multi-task runner always uses this trainer; it is not a CLI argument here.
    network_trainer = "nnUNetMultiTrainerV2"
    fold = args.fold if args.fold == 'all' else int(args.fold)
    force_separate_z = parse_force_separate_z(args.force_separate_z)

    # Numeric ids are resolved to full task names.
    # The configuration helper and the trainer both receive the resolved names.
    tasks, classes_dict = collect_task_classes(args.tasks, preprocessing_output_dir)

    plans_file, output_folder_names, dataset_directorys, batch_dice, stage, \
        trainer_class = get_default_configuration_with_multiTask(
            network, tasks, network_trainer, args.p)
    if trainer_class is None:
        raise RuntimeError(
            "Could not find trainer class in nnunet.training.network_training")

    if network == "3d_cascade_fullres":
        assert issubclass(trainer_class, (nnUNetTrainerCascadeFullRes, nnUNetTrainerV2CascadeFullRes)), \
            "If running 3d_cascade_fullres then your trainer class must be derived from " \
            "nnUNetTrainerCascadeFullRes"
    else:
        assert issubclass(trainer_class, nnUNetTrainer), \
            "network_trainer was found but is not derived from nnUNetTrainer"

    trainer = trainer_class(plans_file, fold, tasks, tags=classes_dict,
                            output_folder_dict=output_folder_names,
                            dataset_directory_dict=dataset_directorys,
                            batch_dice=batch_dice, stage=stage,
                            unpack_data=not args.use_compressed_data,
                            deterministic=args.deterministic,
                            fp16=not args.fp32)
    # Receives False when only validating.
    # NOTE(review): presumably initialize()'s "training" flag; confirm in the trainer.
    trainer.initialize(not args.validation_only)

    if args.find_lr:
        trainer.find_lr()
        return

    if not args.validation_only:
        if args.continue_training:
            trainer.load_latest_checkpoint()
        trainer.run_training()
    elif args.valbest:
        trainer.load_best_checkpoint(train=False)
    else:
        trainer.load_latest_checkpoint(train=False)

    trainer.network.eval()
    # Predict the validation cases of the selected fold.
    trainer.validate(save_softmax=args.npz, validation_folder_name=args.val_folder,
                     force_separate_z=force_separate_z,
                     interpolation_order=args.interp_order,
                     interpolation_order_z=args.interp_order_z)
    # NOTE(review): the single-task runner's 3d_lowres -> cascade next-stage prediction
    # (predict_next_stage) is not run here. If re-enabled, it must use the per-task
    # entries of `dataset_directorys`; there is no single `dataset_directory`.
if __name__ == "__main__":
    # Run the command-line entry point only when this file is executed
    # directly, not when it is imported as a module.
    main()