Skip to content

Commit

Permalink
train test split
Browse files Browse the repository at this point in the history
Former-commit-id: 15bab7b
  • Loading branch information
fishingguy456 committed Jun 13, 2022
1 parent 3ea104f commit 7bc5404
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 60 deletions.
82 changes: 55 additions & 27 deletions examples/autotest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from aifc import Error
import os, pathlib
import shutil
import glob
Expand All @@ -14,6 +15,7 @@
from joblib import Parallel, delayed
from imgtools.modules import Segmentation
from torch import sparse_coo_tensor
from sklearn.model_selection import train_test_split

from imgtools.io.common import file_name_convention
###############################################################
Expand All @@ -39,7 +41,9 @@ def __init__(self,
show_progress=False,
warn_on_error=False,
overwrite=False,
nnUnet_info=None):
nnUnet_info=None,
train_size=1.0,
random_state=42):

super().__init__(
n_jobs=n_jobs,
Expand All @@ -53,9 +57,21 @@ def __init__(self,
self.spacing = spacing
self.existing = [None] #self.existing_patients()
self.nnUnet_info = nnUnet_info
self.train_size = train_size
self.random_state = random_state

if self.train_size == 1 or self.train_size == 0:
raise Warning("No train/test split")

if self.train_size != 1 and not self.nnUnet_info:
raise Warning("Cannot run train/test split without nnUnet, ignoring train_size")

if self.train_size > 1 or self.train_size < 0 and self.nnUnet_info:
raise ValueError("train_size must be between 0 and 1")

if nnUnet_info:
self.nnUnet_info["modalities"] = {"CT": "0000"}
self.nnUnet_info["index"] = 0
self.nnUnet_info["modalities"] = {"CT": "0000"} #modality to 4-digit code
self.nnUnet_info["index"] = 0 #number of patients

#input operations
self.input = ImageAutoInput(input_directory, modalities, n_jobs, visualize)
Expand Down Expand Up @@ -103,13 +119,15 @@ def process_one_subject(self, subject_id):
print(subject_id, " start")

metadata = {}
subject_modalities = set()
subject_modalities = set() # all the modalities that this subject has
num_rtstructs = 0

if self.nnUnet_info:
self.nnUnet_info["index"] += 1
self.nnUnet_info["index"] += 1 #increment the number of patients
for i, colname in enumerate(self.output_streams):
modality = colname.split("_")[0]
subject_modalities.add(modality)
subject_modalities.add(modality) #set add

# Taking modality pairs if it exists till _{num}
output_stream = ("_").join([item for item in colname.split("_") if item.isnumeric()==False])

Expand All @@ -134,18 +152,22 @@ def process_one_subject(self, subject_id):
print(image.GetSize())
image = self.resample(image)

#update the metadata for this image
if hasattr(read_results[i], "metadata") and read_results[i].metadata is not None:
metadata.update(read_results[i].metadata)

#modality is MR and the user has selected to have nnUnet output
if self.nnUnet_info:
if modality == "MR":
if modality == "MR": #MR images can have various modalities like FLAIR, T1, etc.
self.nnUnet_info['current_modality'] = metadata["AcquisitionContrast"]
if not metadata["AcquisitionContrast"] in self.nnUnet_info["modalities"].keys():
self.nnUnet_info["modalities"][metadata["AcquisitionContrast"]] = str(len(self.nnUnet_info["modalities"])).zfill(4)
if not metadata["AcquisitionContrast"] in self.nnUnet_info["modalities"].keys(): #if the modality is new
self.nnUnet_info["modalities"][metadata["AcquisitionContrast"]] = str(len(self.nnUnet_info["modalities"])).zfill(4) #fill to 4 digits
else:
self.nnUnet_info['current_modality'] = modality #CT
self.output(subject_id, image, output_stream, nnUnet_info=self.nnUnet_info)
if subject_id in self.train:
self.output(subject_id, image, output_stream, nnUnet_info=self.nnUnet_info)
else:
self.output(subject_id, image, output_stream, nnUnet_info=self.nnUnet_info, train_or_test="Ts")
else:
self.output(subject_id, image, output_stream)

Expand Down Expand Up @@ -192,12 +214,11 @@ def process_one_subject(self, subject_id):

if self.nnUnet_info:
sparse_mask = mask.generate_sparse_mask().mask_array
sparse_mask = sitk.GetImageFromArray(sparse_mask)
# save_path = pathlib.Path(self.output_directory, subject_id, "sparse_mask", "sparse_mask.nii.gz").as_posix()
self.output(subject_id, sparse_mask, output_stream, nnUnet_info=self.nnUnet_info, label_or_image="labels")
# sparse_mask_nifti = nib.Nifti1Image(sparse_mask.mask_array, affine=np.eye(4))
# nib.save(sparse_mask_nifti, save_path)
# self.output("sparse_mask", sparse_mask, output_stream, "sparse_mask")
sparse_mask = sitk.GetImageFromArray(sparse_mask) #convert the nparray to sitk image
if subject_id in self.train:
self.output(subject_id, sparse_mask, output_stream, nnUnet_info=self.nnUnet_info, label_or_image="labels") #rtstruct is label for nnunet
else:
self.output(subject_id, sparse_mask, output_stream, nnUnet_info=self.nnUnet_info, label_or_image="labels", train_or_test="Ts")
else:
# if there is only one ROI, sitk.GetArrayFromImage() will return a 3d array instead of a 4d array with one slice
if len(mask_arr.shape) == 3:
Expand Down Expand Up @@ -246,7 +267,7 @@ def process_one_subject(self, subject_id):
#Saving all the metadata in multiple text files
metadata["Modalities"] = str(list(subject_modalities))
metadata["numRTSTRUCTs"] = num_rtstructs

metadata["Train or Test"] = "train" if subject_id in self.train else "test"
with open(pathlib.Path(self.output_directory,".temp",f'{subject_id}.pkl').as_posix(),'wb') as f:
pickle.dump(metadata,f)
return
Expand All @@ -264,7 +285,7 @@ def save_data(self):
if col.startswith("folder"):
self.output_df[col] = self.output_df[col].apply(lambda x: pathlib.Path(x).as_posix().split(self.input_directory)[1][1:]) # rel path, exclude the slash at the beginning
folder_renames[col] = f"input_{col}"
self.output_df.rename(columns=folder_renames, inplace=True)
self.output_df.rename(columns=folder_renames, inplace=True) #append input_ to the column name
self.output_df.to_csv(self.output_df_path)
shutil.rmtree(pathlib.Path(self.output_directory, ".temp").as_posix())

Expand All @@ -275,6 +296,9 @@ def run(self):
verbose = 51 if self.show_progress else 0

subject_ids = self._get_loader_subject_ids()
if self.nnUnet_info:
self.num_subjects = len(subject_ids)
self.train, self.test = train_test_split(subject_ids, train_size=self.train_size, random_state=self.random_state)
# Note that returning any SimpleITK object in process_one_subject is
# not supported yet, since they cannot be pickled
if os.path.exists(self.output_df_path) and not self.overwrite:
Expand Down Expand Up @@ -307,17 +331,21 @@ def run(self):
# overwrite=True,
# nnUnet_info={"study name": "NSCLC-Radiomics-Interobserver1"})

# pipeline = AutoPipeline(input_directory="C:/Users/qukev/BHKLAB/hnscc_testing/HNSCC",
# output_directory="C:/Users/qukev/BHKLAB/hnscc_testing_output",
# modalities="CT,RTSTRUCT",
# visualize=False,
# overwrite=True)

pipeline = AutoPipeline(input_directory="C:/Users/qukev/BHKLAB/dataset/manifest-1598890146597/NSCLC-Radiomics-Interobserver1",
output_directory="C:/Users/qukev/BHKLAB/autopipelineoutput",
pipeline = AutoPipeline(input_directory="C:/Users/qukev/BHKLAB/hnscc_testing/HNSCC",
output_directory="C:/Users/qukev/BHKLAB/hnscc_testing_output",
modalities="CT,RTSTRUCT",
visualize=False,
overwrite=True)
overwrite=True,
nnUnet_info={"study name": "TCIA-HNSCC"},
train_size=0.5)

# pipeline = AutoPipeline(input_directory="C:/Users/qukev/BHKLAB/dataset/manifest-1598890146597/NSCLC-Radiomics-Interobserver1",
# output_directory="C:/Users/qukev/BHKLAB/autopipelineoutput",
# modalities="CT,RTSTRUCT",
# visualize=False,
# overwrite=True,
# nnUnet_info={"study name": "NSCLC-Radiomics-Interobserver1"},
# train_size=0.5)

# pipeline = AutoPipeline(input_directory="C:/Users/qukev/BHKLAB/hnscc_pet/PET",
# output_directory="C:/Users/qukev/BHKLAB/hnscc_pet_output",
Expand Down
5 changes: 4 additions & 1 deletion imgtools/autopipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,16 @@ def run(self):

def main():
args = parser()
if args.nnunet_study_name:
nnUnet_info = {"study name": args.nnunet_study_name}
pipeline = AutoPipeline(args.input_directory,
args.output_directory,
modalities=args.modalities,
spacing=args.spacing,
n_jobs=args.n_jobs,
visualize=args.visualize,
show_progress=args.show_progress)
show_progress=args.show_progress,
nnUnet_info=nnUnet_info)

print(f'starting Pipeline...')
pipeline.run()
Expand Down
25 changes: 8 additions & 17 deletions imgtools/io/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,36 +54,27 @@ def __init__(self, root_directory, filename_format="{subject_id}.nii.gz", create
if os.path.exists(self.root_directory):
if os.path.basename(os.path.dirname(self.root_directory)) == "{subject_id}":
shutil.rmtree(os.path.dirname(self.root_directory))
elif "{label_or_image}" in os.path.basename(self.root_directory):
elif "{label_or_image}{train_or_test}" in os.path.basename(self.root_directory):
shutil.rmtree(self.root_directory)
#delete the folder called {subject_id} that was made in the original BaseWriter
#delete the folder called {subject_id} that was made in the original BaseWriter / the one named {label_or_image}


def put(self, subject_id, image, is_mask=False, nnUnet_info=None, label_or_image: str = "images", mask_label="", **kwargs):
def put(self, subject_id, image, is_mask=False, nnUnet_info=None, label_or_image: str = "images", mask_label="", train_or_test: str = "Tr", **kwargs):
if is_mask:
self.filename_format = mask_label+".nii.gz"
self.filename_format = mask_label+".nii.gz" #save the mask labels as their rtstruct names
if nnUnet_info:
if label_or_image == "labels":
filename = f"{nnUnet_info['study name']}_{nnUnet_info['index']}.nii.gz"
filename = f"{nnUnet_info['study name']}_{nnUnet_info['index']}.nii.gz" #naming convention for labels
else:
# f"{nnUnet_info['study name']}_{nnUnet_info['index']}_{nnUnet_info['modalities'][nnUnet_info['current_modality']]}.nii.gz"
filename = self.filename_format.format(study_name=nnUnet_info['study name'], index=nnUnet_info['index'], modality_index=nnUnet_info['modalities'][nnUnet_info['current_modality']])
out_path = self._get_path_from_subject_id(filename, label_or_image=label_or_image)
filename = self.filename_format.format(study_name=nnUnet_info['study name'], index=nnUnet_info['index'], modality_index=nnUnet_info['modalities'][nnUnet_info['current_modality']]) #naming convention for images
out_path = self._get_path_from_subject_id(filename, label_or_image=label_or_image, train_or_test=train_or_test)
else:
out_path = self._get_path_from_subject_id(self.filename_format, subject_id=subject_id)
sitk.WriteImage(image, out_path, self.compress)

def _get_path_from_subject_id(self, filename, **kwargs):
# out_filename = self.filename_format.format(subject_id=subject_id, **kwargs)
# print(subject_id, "asasa")
# print(self.root_directory)
# try:
root_directory = self.root_directory.format(**kwargs)
# except:
# if nnUnet_is_label:
# self.root_directory = self.root_directory.format(label_or_image="labels")
# else:
# self.root_directory = self.root_directory.format(label_or_image="images")
root_directory = self.root_directory.format(**kwargs) #replace the {} with the kwargs passed in from .put() (above)
out_path = pathlib.Path(root_directory, filename).as_posix()
out_dir = os.path.dirname(out_path)
if self.create_dirs and not os.path.exists(out_dir):
Expand Down
1 change: 0 additions & 1 deletion imgtools/modules/datagraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,6 @@ def _get_df(self,
folder_save["patient_ID"] = df_connections["patient_ID_x"].iloc[0]
folder_save[f"series_{df_connections['modality_x'].iloc[0]}"] = comp[j]
folder_save[f"folder_{df_connections['modality_x'].iloc[0]}"] = df_connections["folder_x"].iloc[0]
print(f"folder_{df_connections['modality_x'].iloc[0]}", df_connections["folder_x"].iloc[0], "asdf")
temp_dfconn = df_connections[["series_y", "modality_y", "folder_y"]]
for k in range(len(temp_dfconn)):
#This loop stores connection of the CT
Expand Down
24 changes: 14 additions & 10 deletions imgtools/modules/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __getitem__(self, idx):
def __repr__(self):
return f"<Segmentation with ROIs: {self.roi_names!r}>"

def generate_sparse_mask(self) -> SparseMask:
def generate_sparse_mask(self, verbose=False) -> SparseMask:
"""
Generate a sparse mask from the contours, taking the argmax of all overlaps
Expand All @@ -107,23 +107,27 @@ def generate_sparse_mask(self) -> SparseMask:
print(roi_names)

sparsemask_arr = np.zeros(mask_arr.shape[1:])

# voxels_with_overlap = {}

if verbose:
voxels_with_overlap = set()
if len(mask_arr.shape) == 4:
for i in range(mask_arr.shape[0]):
slice = mask_arr[i, :, :, :]
slice *= list(roi_names.values())[i] # everything is 0 or 1, so this is fine to convert filled voxels to label indices
# res = self._max_adder(sparsemask_arr, slice)
# sparsemask_arr = res[0]
# for e in res[1]:
# voxels_with_overlap.add(e)
sparsemask_arr = np.fmax(sparsemask_arr, slice) # elementwise maximum
if verbose:
res = self._max_adder(sparsemask_arr, slice)
sparsemask_arr = res[0]
for e in res[1]:
voxels_with_overlap.add(e)
else:
sparsemask_arr = np.fmax(sparsemask_arr, slice) # elementwise maximum
else:
sparsemask_arr = mask_arr

sparsemask = SparseMask(sparsemask_arr, roi_names)
# if len(voxels_with_overlap) != 0:
# raise Warning(f"{len(voxels_with_overlap)} voxels have overlapping contours.")
if verbose:
if len(voxels_with_overlap) != 0:
raise Warning(f"{len(voxels_with_overlap)} voxels have overlapping contours.")
return sparsemask

def _max_adder(self, arr_1: np.ndarray, arr_2: np.ndarray) -> Tuple[np.ndarray, Set[Tuple[int, int, int]]]:
Expand Down
7 changes: 3 additions & 4 deletions imgtools/ops/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,10 +337,8 @@ def __init__(self,
self.output[colname_process] = ImageSubjectFileOutput(pathlib.Path(root_directory,"{subject_id}",extension.split(".")[0]).as_posix(),
filename_format=colname_process+"{}.nii.gz".format(extension))
else:
self.output[colname_process] = ImageSubjectFileOutput(pathlib.Path(root_directory,"{label_or_image}Tr").as_posix(),
self.output[colname_process] = ImageSubjectFileOutput(pathlib.Path(root_directory,"{label_or_image}{train_or_test}").as_posix(),
filename_format="{study_name}_{index}_{modality_index}.nii.gz")
# self.output[colname_process] = ImageFileOutput(os.path.join(root_directory,extension.split(".")[0]),
# filename_format="{subject_id}_"+"{}.nrrd".format(extension))

def __call__(self,
subject_id: str,
Expand All @@ -349,9 +347,10 @@ def __call__(self,
is_mask: bool = False,
mask_label: Optional[str] = "",
label_or_image: str="images",
train_or_test: str="Tr",
nnUnet_info: Dict=None):

self.output[output_stream](subject_id, img, is_mask=is_mask, mask_label=mask_label, label_or_image=label_or_image, nnUnet_info=nnUnet_info)
self.output[output_stream](subject_id, img, is_mask=is_mask, mask_label=mask_label, label_or_image=label_or_image, train_or_test=train_or_test, nnUnet_info=nnUnet_info)

class NumpyOutput(BaseOutput):
"""NumpyOutput class processed images as NumPy files.
Expand Down
9 changes: 9 additions & 0 deletions imgtools/utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,14 @@ def parser():

parser.add_argument("--show_progress", action="store_true",
help="Whether to print progress to standard output.")

parser.add_argument("--nnunet_study_name", type=str, default=None,
help="Name of the study to be used for nn-Unet.")

parser.add_argument("train_size", type=float, default=1.0,
help="The proportion of data to be used for training, as a decimal.")

parser.add_argument("random_state", type=int, default=42,
help="The random state to be used for the train-test-split.")

return parser.parse_known_args()[0]
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ SimpleITK
tqdm
torch
torchio
scikit-learn

0 comments on commit 7bc5404

Please sign in to comment.