moved autotest changes to autopipeline and added a few CLI args
fishingguy456 committed Jun 17, 2022
1 parent 794f27d commit d8ff8d7
Showing 2 changed files with 236 additions and 24 deletions.
254 changes: 230 additions & 24 deletions imgtools/autopipeline.py
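
The new constructor arguments introduced below can also be exercised directly from Python. A minimal sketch, assuming placeholder paths and an roi_names.yaml in the input directory (none of these values come from the commit itself):

from imgtools.autopipeline import AutoPipeline

pipeline = AutoPipeline(
    input_directory="./data/input",    # must contain roi_names.yaml for the label mapping
    output_directory="./data/output",
    modalities="CT,RTSTRUCT",
    spacing=(1., 1., 0.),
    is_nnunet=True,                    # write nnU-Net style imagesTr/imagesTs folders plus dataset.json
    train_size=0.8,                    # fraction of patients assigned to the training split
    random_state=42,
    read_yaml_label_names=True,        # required when is_nnunet=True
    ignore_missing_regex=False,
    overwrite=True)
pipeline.run()
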
@@ -1,16 +1,28 @@
import os, pathlib, sys
from aifc import Error
from distutils.log import warn
import os, pathlib
import shutil
import glob
import pickle
import struct
import numpy as np
import sys
import warnings

from argparse import ArgumentParser
import yaml
import SimpleITK as sitk

from imgtools.ops import StructureSetToSegmentation, ImageAutoInput, ImageAutoOutput, Resample
from imgtools.pipeline import Pipeline
from imgtools.utils import parser
from imgtools.utils.nnunetutils import generate_dataset_json
from imgtools.utils.args import parser
from joblib import Parallel, delayed
from imgtools.modules import Segmentation
from torch import sparse_coo_tensor
from sklearn.model_selection import train_test_split

from imgtools.io.common import file_name_convention
###############################################################
# Example usage:
# python radcure_simple.py ./data/RADCURE/data ./RADCURE_output
@@ -32,19 +44,106 @@ def __init__(self,
visualize=False,
missing_strategy="drop",
show_progress=False,
warn_on_error=False):
warn_on_error=False,
overwrite=False,
is_nnunet=False,
train_size=1.0,
random_state=42,
read_yaml_label_names=False,
ignore_missing_regex=False):
"""Initialize the pipeline.
Parameters
----------
input_directory: str
Directory containing the input data
output_directory: str
Directory where the output data will be stored
modalities: str, default="CT"
Modalities to load. Can be a comma-separated list of modalities with no spaces
spacing: tuple of floats, default=(1., 1., 0.)
Spacing of the output image
n_jobs: int, default=-1
Number of jobs to run in parallel. If -1, use all cores
visualize: bool, default=False
Whether to visualize the results of the pipeline using pyvis. Outputs to an HTML file.
missing_strategy: str, default="drop"
How to handle missing modalities. Can be "drop" or "fill"
show_progress: bool, default=False
Whether to show progress bars
warn_on_error: bool, default=False
Whether to warn on errors
overwrite: bool, default=False
Whether to overwrite existing output files
is_nnunet: bool, default=False
Whether to format the output for nnunet
train_size: float, default=1.0
Proportion of the dataset to use for training, as a decimal
random_state: int, default=42
Random state for train_test_split
read_yaml_label_names: bool, default=False
Whether to read the dictionary mapping label names to regexes from a YAML file. For example, "GTV": "GTV.*" will combine all ROIs matching "GTV.*" into a single "GTV" label
ignore_missing_regex: bool, default=False
Whether to ignore missing regexes. If False, an error is raised when none of the regexes in label_names match any ROI for a patient; if True, that subject is skipped.
"""
super().__init__(
n_jobs=n_jobs,
missing_strategy=missing_strategy,
show_progress=show_progress,
warn_on_error=warn_on_error)

self.overwrite = overwrite
# pipeline configuration
self.input_directory = input_directory
self.output_directory = output_directory
self.input_directory = pathlib.Path(input_directory).as_posix()
self.output_directory = pathlib.Path(output_directory).as_posix()
self.spacing = spacing
self.existing = [None] #self.existing_patients()
self.is_nnunet = is_nnunet
if is_nnunet:
self.nnunet_info = {}
else:
self.nnunet_info = None
self.train_size = train_size
self.random_state = random_state
self.label_names = {}
self.ignore_missing_regex = ignore_missing_regex

with open(pathlib.Path(self.input_directory, "roi_names.yaml").as_posix(), "r") as f:
try:
self.label_names = yaml.safe_load(f)
except yaml.YAMLError as exc:
print(exc)

if not isinstance(self.label_names, dict):
raise ValueError("roi_names.yaml must parse as a dictionary")

for k, v in self.label_names.items():
if not isinstance(v, list) and not isinstance(v, str):
raise ValueError(f"Label values must be either a list of strings or a string. Got {v} for {k}")
elif isinstance(v, list):
for a in v:
if not isinstance(a, str):
raise ValueError(f"Label values must be either a list of strings or a string. Got {a} in list {v} for {k}")
elif not isinstance(k, str):
raise ValueError(f"Label names must be a string. Got {k} for {v}")

if self.train_size == 1.0:
warnings.warn("Train size is 1, all data will be used for training")

if self.train_size == 0.0:
warnings.warn("Train size is 0, all data will be used for testing")

if self.train_size != 1 and not self.is_nnunet:
warnings.warn("Cannot run train/test split without nnunet, ignoring train_size")

if self.train_size > 1 or self.train_size < 0 and self.is_nnunet:
raise ValueError("train_size must be between 0 and 1")

if is_nnunet and (not read_yaml_label_names or self.label_names == {}):
raise ValueError("YAML label names must be provided for nnunet")


if self.is_nnunet:
self.nnunet_info["modalities"] = {"CT": "0000"} #modality to 4-digit code

#input operations
self.input = ImageAutoInput(input_directory, modalities, n_jobs, visualize)
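
The roi_names.yaml validation earlier in this hunk accepts label values that are either a single regex string or a list of regex strings. A minimal sketch of a file that would pass those checks, with illustrative label names and regexes (not taken from this commit):

import yaml

example = yaml.safe_load("""
GTV: "GTV.*"     # every ROI whose name matches GTV.* is combined into the GTV label
Nodes:           # a list of regexes is also accepted
  - "LN.*"
  - "Node.*"
""")
assert all(isinstance(k, str) for k in example)
assert all(isinstance(v, (str, list)) for v in example.values())
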
@@ -57,14 +156,16 @@ def __init__(self,

# image processing ops
self.resample = Resample(spacing=self.spacing)
self.make_binary_mask = StructureSetToSegmentation(roi_names=[], continuous=False)
self.make_binary_mask = StructureSetToSegmentation(roi_names=self.label_names, continuous=False) # "GTV-.*"

# output ops
self.output = ImageAutoOutput(self.output_directory, self.output_streams)
self.output = ImageAutoOutput(self.output_directory, self.output_streams, self.nnunet_info)

#Make a directory
if not os.path.exists(pathlib.Path(self.output_directory,".temp").as_posix()):
os.mkdir(pathlib.Path(self.output_directory,".temp").as_posix())

self.existing_roi_names = {"background": 0}


def process_one_subject(self, subject_id):
@@ -74,6 +175,7 @@ def process_one_subject(self, subject_id):
multiple images, structures, etc.). During pipeline execution, this
method will receive one argument, subject_id, which can be used to
retrieve inputs and save outputs.
Parameters
----------
subject_id : str
@@ -92,9 +194,13 @@ def process_one_subject(self, subject_id):
print(subject_id, " start")

metadata = {}
subject_modalities = set() # all the modalities that this subject has
num_rtstructs = 0

for i, colname in enumerate(self.output_streams):
modality = colname.split("_")[0]

subject_modalities.add(modality) #set add

# Taking modality pairs if it exists till _{num}
output_stream = ("_").join([item for item in colname.split("_") if item.isnumeric()==False])

@@ -108,7 +214,7 @@ def process_one_subject(self, subject_id):
print("The subject id: {} has no {}".format(subject_id, colname))
pass
elif modality == "CT" or modality == 'MR':
image = read_results[i]
image = read_results[i].image
if len(image.GetSize()) == 4:
assert image.GetSize()[-1] == 1, f"There is more than one volume in this CT file for {subject_id}."
extractor = sitk.ExtractImageFilter()
@@ -118,9 +224,29 @@ def process_one_subject(self, subject_id):
image = extractor.Execute(image)
print(image.GetSize())
image = self.resample(image)
#Saving the output
self.output(subject_id, image, output_stream)

#update the metadata for this image
if hasattr(read_results[i], "metadata") and read_results[i].metadata is not None:
metadata.update(read_results[i].metadata)

#modality is MR and the user has selected to have nnunet output
if self.is_nnunet:
if modality == "MR": #MR images can have various modalities like FLAIR, T1, etc.
self.nnunet_info['current_modality'] = metadata["AcquisitionContrast"]
if not metadata["AcquisitionContrast"] in self.nnunet_info["modalities"].keys(): #if the modality is new
self.nnunet_info["modalities"][metadata["AcquisitionContrast"]] = str(len(self.nnunet_info["modalities"])).zfill(4) #fill to 4 digits
else:
self.nnunet_info['current_modality'] = modality #CT
if "_".join(subject_id.split("_")[1::]) in self.train:
self.output(subject_id, image, output_stream, nnunet_info=self.nnunet_info)
else:
self.output(subject_id, image, output_stream, nnunet_info=self.nnunet_info, train_or_test="Ts")
else:
self.output(subject_id, image, output_stream)

metadata[f"size_{output_stream}"] = str(image.GetSize())


print(subject_id, " SAVED IMAGE")
elif modality == "RTDOSE":
try: #For cases with no image present
@@ -136,25 +262,65 @@ def process_one_subject(self, subject_id):
self.output(f"{subject_id}_{num}", doses, output_stream)
metadata[f"size_{output_stream}"] = str(doses.GetSize())
metadata[f"metadata_{colname}"] = [read_results[i].get_metadata()]

if hasattr(doses, "metadata") and doses.metadata is not None:
metadata.update(doses.metadata)

print(subject_id, " SAVED DOSE")
elif modality == "RTSTRUCT":
num_rtstructs += 1
#For RTSTRUCT, you need image or PT
structure_set = read_results[i]
conn_to = output_stream.split("_")[-1]

# make_binary_mask relative to ct/pet
if conn_to == "CT" or conn_to == "MR":
mask = self.make_binary_mask(structure_set, image)
mask = self.make_binary_mask(structure_set, image, self.existing_roi_names, self.ignore_missing_regex)
elif conn_to == "PT":
mask = self.make_binary_mask(structure_set, pet)
mask = self.make_binary_mask(structure_set, pet, self.existing_roi_names, self.ignore_missing_regex)
else:
raise ValueError("You need to pass a reference CT or PT/PET image to map contours to.")

if mask is None: #ignored the missing regex
return

for name in mask.roi_names.keys():
if name not in self.existing_roi_names.keys():
self.existing_roi_names[name] = len(self.existing_roi_names)
mask.existing_roi_names = self.existing_roi_names

# save output
if not mult_conn:
self.output(subject_id, mask, output_stream)
print(mask.GetSize())
mask_arr = np.transpose(sitk.GetArrayFromImage(mask))

if self.is_nnunet:
sparse_mask = mask.generate_sparse_mask().mask_array
sparse_mask = sitk.GetImageFromArray(sparse_mask) #convert the nparray to sitk image
if "_".join(subject_id.split("_")[1::]) in self.train:
self.output(subject_id, sparse_mask, output_stream, nnunet_info=self.nnunet_info, label_or_image="labels") #rtstruct is label for nnunet
else:
self.output(subject_id, sparse_mask, output_stream, nnunet_info=self.nnunet_info, label_or_image="labels", train_or_test="Ts")
else:
self.output(f"{subject_id}_{num}", mask, output_stream)
# if there is only one ROI, sitk.GetArrayFromImage() will return a 3d array instead of a 4d array with one slice
if len(mask_arr.shape) == 3:
mask_arr = mask_arr.reshape(1, mask_arr.shape[0], mask_arr.shape[1], mask_arr.shape[2])

print(mask_arr.shape)
roi_names_list = list(mask.roi_names.keys())
for i in range(mask_arr.shape[0]):
new_mask = sitk.GetImageFromArray(np.transpose(mask_arr[i]))
new_mask.CopyInformation(mask)
new_mask = Segmentation(new_mask)
mask_to_process = new_mask
if not mult_conn:
# self.output(roi_names_list[i], mask_to_process, output_stream)
self.output(subject_id, mask_to_process, output_stream, True, roi_names_list[i])
else:
self.output(f"{subject_id}_{num}", mask_to_process, output_stream, True, roi_names_list[i])

if hasattr(structure_set, "metadata") and structure_set.metadata is not None:
metadata.update(structure_set.metadata)

metadata[f"metadata_{colname}"] = [structure_set.roi_names]

print(subject_id, "SAVED MASK ON", conn_to)
@@ -172,8 +338,17 @@ def process_one_subject(self, subject_id):
self.output(f"{subject_id}_{num}", pet, output_stream)
metadata[f"size_{output_stream}"] = str(pet.GetSize())
metadata[f"metadata_{colname}"] = [read_results[i].get_metadata()]

if hasattr(pet, "metadata") and pet.metadata is not None:
metadata.update(pet.metadata)

print(subject_id, " SAVED PET")

metadata[f"output_folder_{colname}"] = pathlib.Path(subject_id, colname).as_posix()
#Saving all the metadata in multiple text files
metadata["Modalities"] = str(list(subject_modalities))
metadata["numRTSTRUCTs"] = num_rtstructs
metadata["Train or Test"] = "train" if "_".join(subject_id.split("_")[1::]) in self.train else "test"
with open(pathlib.Path(self.output_directory,".temp",f'{subject_id}.pkl').as_posix(),'wb') as f:
pickle.dump(metadata,f)
return
@@ -186,8 +361,24 @@ def save_data(self):
with open(file,"rb") as f:
metadata = pickle.load(f)
self.output_df.loc[subject_id, list(metadata.keys())] = list(metadata.values())
folder_renames = {}
for col in self.output_df.columns:
if col.startswith("folder"):
self.output_df[col] = self.output_df[col].apply(lambda x: pathlib.Path(x).as_posix().split(self.input_directory)[1][1:]) # rel path, exclude the slash at the beginning
folder_renames[col] = f"input_{col}"
self.output_df.rename(columns=folder_renames, inplace=True) #append input_ to the column name
self.output_df.to_csv(self.output_df_path)
shutil.rmtree(pathlib.Path(self.output_directory, ".temp").as_posix())
if self.is_nnunet:
imagests_path = pathlib.Path(self.output_directory, "imagesTs").as_posix()
images_test_location = imagests_path if os.path.exists(imagests_path) else None
generate_dataset_json(pathlib.Path(self.output_directory, "dataset.json").as_posix(),
pathlib.Path(self.output_directory, "imagesTr").as_posix(),
images_test_location,
tuple(self.nnunet_info["modalities"].keys()),
{v:k for k, v in self.existing_roi_names.items()},
os.path.split(self.input_directory)[1])
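
The labels handed to generate_dataset_json are simply existing_roi_names inverted into an index-to-name mapping, accumulated while the masks were processed. A small worked example with illustrative ROI names:

existing_roi_names = {"background": 0, "GTV": 1, "Nodes": 2}
labels = {v: k for k, v in existing_roi_names.items()}
# labels == {0: "background", 1: "GTV", 2: "Nodes"}
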


def run(self):
"""Execute the pipeline, possibly in parallel.
@@ -196,27 +387,42 @@ def run(self):
verbose = 51 if self.show_progress else 0

subject_ids = self._get_loader_subject_ids()
patient_ids = []
for subject_id in subject_ids:
if subject_id.split("_")[1::] not in patient_ids:
patient_ids.append("_".join(subject_id.split("_")[1::]))
if self.is_nnunet:
self.train, self.test = train_test_split(patient_ids, train_size=self.train_size, random_state=self.random_state)
else:
self.train, self.test = [], []
# Note that returning any SimpleITK object in process_one_subject is
# not supported yet, since they cannot be pickled
if os.path.exists(self.output_df_path):
if os.path.exists(self.output_df_path) and not self.overwrite:
print("Dataset already processed...")
shutil.rmtree(pathlib.Path(self.output_directory, ".temp").as_posix())
else:
Parallel(n_jobs=self.n_jobs, verbose=verbose)(
delayed(self._process_wrapper)(subject_id) for subject_id in subject_ids)
# Parallel(n_jobs=self.n_jobs, verbose=verbose)(
# delayed(self._process_wrapper)(subject_id) for subject_id in subject_ids)
for subject_id in subject_ids:
self._process_wrapper(subject_id)
self.save_data()

def main():
args = parser()
if args.nnunet_study_name:
nnunet_info = {"study name": args.nnunet_study_name}
pipeline = AutoPipeline(args.input_directory,
args.output_directory,
modalities=args.modalities,
visualize=args.visualize,
spacing=args.spacing,
n_jobs=args.n_jobs,
visualize=args.visualize,
show_progress=args.show_progress,
nnunet_info=nnunet_info)
warn_on_error=args.warn_on_error,
overwrite=args.overwrite,
is_nnunet=args.nnunet,
train_size=args.train_size,
random_state=args.random_state,
read_yaml_label_names=args.read_yaml_label_names,
ignore_missing_regex=args.ignore_missing_regex)

print(f'starting Pipeline...')
pipeline.run()
