Commit
dataset.json for nnunet
Former-commit-id: 156553f
fishingguy456 committed Jun 16, 2022
1 parent d906330 commit a9f0e87
Showing 5 changed files with 111 additions and 4 deletions.
28 changes: 26 additions & 2 deletions examples/autotest.py
@@ -14,6 +14,7 @@

from imgtools.ops import StructureSetToSegmentation, ImageAutoInput, ImageAutoOutput, Resample
from imgtools.pipeline import Pipeline
+from imgtools.utils.nnunetutils import generate_dataset_json
from joblib import Parallel, delayed
from imgtools.modules import Segmentation
from torch import sparse_coo_tensor
@@ -109,6 +110,19 @@ def __init__(self,
                self.label_names = yaml.safe_load(f)
        except yaml.YAMLError as exc:
            print(exc)

+        if not isinstance(self.label_names, dict):
+            raise ValueError("roi_names.yaml must parse as a dictionary")
+
+        for k, v in self.label_names.items():
+            if not isinstance(k, str):
+                raise ValueError(f"Label names must be strings. Got {k} for {v}")
+            if isinstance(v, list):
+                for a in v:
+                    if not isinstance(a, str):
+                        raise ValueError(f"Label values must be either a list of strings or a string. Got {a} in list {v} for {k}")
+            elif not isinstance(v, str):
+                raise ValueError(f"Label values must be either a list of strings or a string. Got {v} for {k}")

        if self.train_size == 1.0:
            warnings.warn("Train size is 1, all data will be used for training")
@@ -149,7 +163,7 @@ def __init__(self,
        if not os.path.exists(pathlib.Path(self.output_directory, ".temp").as_posix()):
            os.mkdir(pathlib.Path(self.output_directory, ".temp").as_posix())

-        self.existing_roi_names = {}
+        self.existing_roi_names = {"background": 0}


    def process_one_subject(self, subject_id):
@@ -353,6 +367,16 @@ def save_data(self):
        self.output_df.rename(columns=folder_renames, inplace=True)  # prefix the column names with "input_"
        self.output_df.to_csv(self.output_df_path)
        shutil.rmtree(pathlib.Path(self.output_directory, ".temp").as_posix())
+        if self.is_nnunet:
+            imagests_path = pathlib.Path(self.output_directory, "imagesTs").as_posix()
+            images_test_location = imagests_path if os.path.exists(imagests_path) else None
+            generate_dataset_json(pathlib.Path(self.output_directory, "dataset.json").as_posix(),
+                                  pathlib.Path(self.output_directory, "imagesTr").as_posix(),
+                                  images_test_location,
+                                  tuple(self.nnunet_info["modalities"].keys()),
+                                  {v: k for k, v in self.existing_roi_names.items()},
+                                  os.path.split(self.input_directory)[1])
+

    def run(self):
        """Execute the pipeline, possibly in parallel.
@@ -406,7 +430,7 @@ def run(self):
modalities="CT,RTSTRUCT",
visualize=False,
overwrite=True,
# is_nnunet=True,
is_nnunet=True,
train_size=0.5,
# label_names={"GTV":"GTV.*", "Brainstem": "Brainstem.*"},
read_yaml_label_names=True, # "GTV.*",
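For reference, a minimal roi_names.yaml that would satisfy the validation added above: a mapping from label name (a string) to either a single regex or a list of regexes. The label names and patterns below are illustrative only, not part of the commit:

GTV: "GTV.*"
Brainstem:
  - "Brainstem.*"
  - "BRAIN_STEM.*"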
4 changes: 2 additions & 2 deletions imgtools/modules/structureset.py
@@ -137,8 +137,8 @@ def to_segmentation(self, reference_image: sitk.Image,
            if isinstance(pattern, str):
                matching_names = list(self._assign_labels([pattern], force_missing).keys())
                if matching_names:
-                    labels[name] = matching_names  # {"GTV": ["GTV1", "GTV2"]}
-            elif isinstance(pattern, list):
+                    labels[name] = matching_names  # e.g. {"GTV": ["GTV1", "GTV2"]}, the result of _assign_labels()
+            elif isinstance(pattern, list):  # the input has multiple patterns per label, e.g. {"GTV": ["GTV.*", "HTVI.*"]}
                labels[name] = []
                for pat in pattern:
                    matching_names = list(self._assign_labels([pat], force_missing).keys())
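To illustrate the new list branch: every pattern in the list is matched against the ROI names found in the structure set, and all matches accumulate under the single output label. A minimal standalone sketch — roi_names and the regex matching here are illustrative stand-ins for _assign_labels():

import re

roi_names = ["GTV1", "GTV2", "HTVI", "Cord"]  # hypothetical ROI names from a structure set
labels = {}
for name, patterns in {"GTV": ["GTV.*", "HTVI.*"]}.items():
    # collect every ROI matched by any of the label's patterns
    labels[name] = [r for pat in patterns for r in roi_names if re.fullmatch(pat, r)]
print(labels)  # {'GTV': ['GTV1', 'GTV2', 'HTVI']}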
1 change: 1 addition & 0 deletions imgtools/utils/__init__.py
@@ -3,3 +3,4 @@
from .crawl import *
from .dicomutils import *
from .args import *
+from .nnunetutils import *
3 changes: 3 additions & 0 deletions imgtools/utils/args.py
@@ -34,6 +34,9 @@ def parser():
parser.add_argument("--random_state", type=int, default=42,
help="The random state to be used for the train-test-split.")

parser.add_argument("--read_yaml_label_names", default=False, action="store_true",
help="Whether to read the label names from roi_names.yaml in the input directory.")

parser.add_argument("--ignore_missing_roi_regex", default=False, action="store_true",
help="Whether to ignore patients with no ROI regex's that match the given ones. Will throw an error on patients without matches if this is not set.")

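Assuming the pipeline is launched through this argument parser, the new flag would combine with the existing one roughly like so — the entry-point script and paths are hypothetical; only the flags come from args.py:

python examples/autotest.py /path/to/input /path/to/output --read_yaml_label_names --ignore_missing_roi_regex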
79 changes: 79 additions & 0 deletions imgtools/utils/nnunetutils.py
@@ -0,0 +1,79 @@
from typing import Tuple, List
import os
import json
import numpy as np

# This code is adapted from the nnUNet and batchgenerators repositories of the
# Division of Medical Image Computing, German Cancer Research Center (DKFZ).

def save_json(obj, file: str, indent: int = 4, sort_keys: bool = True) -> None:
    with open(file, 'w') as f:
        json.dump(obj, f, sort_keys=sort_keys, indent=indent)

def get_identifiers_from_splitted_files(folder: str):
    # strip the 12-character modality suffix ("_0000.nii.gz") from each file name to get the case identifier
    uniques = np.unique([i[:-12] for i in subfiles(folder, suffix='.nii.gz', join=False)])
    return uniques

def subfiles(folder: str, join: bool = True, prefix: str = None, suffix: str = None, sort: bool = True) -> List[str]:
    # optionally join the folder onto each returned file name
    join_fn = os.path.join if join else (lambda x, y: y)
    res = [join_fn(folder, i) for i in os.listdir(folder)
           if os.path.isfile(os.path.join(folder, i))
           and (prefix is None or i.startswith(prefix))
           and (suffix is None or i.endswith(suffix))]
    if sort:
        res.sort()
    return res

def generate_dataset_json(output_file: str, imagesTr_dir: str, imagesTs_dir: str, modalities: Tuple,
                          labels: dict, dataset_name: str, sort_keys=True, license: str = "hands off!",
                          dataset_description: str = "", dataset_reference="", dataset_release='0.0'):
"""
:param output_file: This needs to be the full path to the dataset.json you intend to write, so
output_file='DATASET_PATH/dataset.json' where the folder DATASET_PATH points to is the one with the
imagesTr and labelsTr subfolders
:param imagesTr_dir: path to the imagesTr folder of that dataset
:param imagesTs_dir: path to the imagesTs folder of that dataset. Can be None
:param modalities: tuple of strings with modality names. must be in the same order as the images (first entry
corresponds to _0000.nii.gz, etc). Example: ('T1', 'T2', 'FLAIR').
:param labels: dict with int->str (key->value) mapping the label IDs to label names. Note that 0 is always
supposed to be background! Example: {0: 'background', 1: 'edema', 2: 'enhancing tumor'}
:param dataset_name: The name of the dataset. Can be anything you want
:param sort_keys: In order to sort or not, the keys in dataset.json
:param license:
:param dataset_description:
:param dataset_reference: website of the dataset, if available
:param dataset_release:
:return:
"""
    train_identifiers = get_identifiers_from_splitted_files(imagesTr_dir)

    if imagesTs_dir is not None:
        test_identifiers = get_identifiers_from_splitted_files(imagesTs_dir)
    else:
        test_identifiers = []

    json_dict = {}
    json_dict['name'] = dataset_name
    json_dict['description'] = dataset_description
    json_dict['tensorImageSize'] = "4D"
    json_dict['reference'] = dataset_reference
    json_dict['licence'] = license  # note: "licence" is the key spelling nnUNet expects
    json_dict['release'] = dataset_release
    json_dict['modality'] = {str(i): modalities[i] for i in range(len(modalities))}
    json_dict['labels'] = {str(i): labels[i] for i in labels.keys()}

    json_dict['numTraining'] = len(train_identifiers)
    json_dict['numTest'] = len(test_identifiers)
    json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i, "label": "./labelsTr/%s.nii.gz" % i}
                             for i in train_identifiers]
    json_dict['test'] = ["./imagesTs/%s.nii.gz" % i for i in test_identifiers]

    if not output_file.endswith("dataset.json"):
        print("WARNING: output file name is not dataset.json! This may or may not be intentional. "
              "Proceeding anyway...")
    save_json(json_dict, output_file, sort_keys=sort_keys)
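
For reference, a usage sketch mirroring the save_data() call in examples/autotest.py above — the path, modality tuple, label mapping, and dataset name are placeholders:

from imgtools.utils.nnunetutils import generate_dataset_json

output_dir = "/path/to/nnunet_dataset"  # hypothetical; must contain imagesTr/ (and optionally imagesTs/)
generate_dataset_json(output_file=output_dir + "/dataset.json",
                      imagesTr_dir=output_dir + "/imagesTr",
                      imagesTs_dir=None,                    # no test split
                      modalities=("CT",),                   # index 0 corresponds to _0000.nii.gz
                      labels={0: "background", 1: "GTV"},   # label 0 must always be background
                      dataset_name="Head-Neck-Example")     # hypothetical name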
