Skip to content

Commit

Permalink
compliant with nnunet directory structure
Browse files Browse the repository at this point in the history
Former-commit-id: 1407e1d
  • Loading branch information
fishingguy456 committed Jun 10, 2022
1 parent 8b3b2db commit 0de879b
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 43 deletions.
64 changes: 35 additions & 29 deletions examples/autotest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import glob
import pickle
import struct
from imgtools.io.common import file_name_convention
import numpy as np
import sys

Expand All @@ -15,8 +14,8 @@
from joblib import Parallel, delayed
from imgtools.modules import Segmentation
from torch import sparse_coo_tensor
import nibabel as nib

from imgtools.io.common import file_name_convention
###############################################################
# Example usage:
# python radcure_simple.py ./data/RADCURE/data ./RADCURE_output
Expand All @@ -41,7 +40,7 @@ def __init__(self,
warn_on_error=False,
overwrite=False,
generate_sparsemask=False,
nnUnet_info={}):
nnUnet_info=None):

super().__init__(
n_jobs=n_jobs,
Expand All @@ -55,7 +54,10 @@ def __init__(self,
self.spacing = spacing
self.existing = [None] #self.existing_patients()
self.generate_sparsemask = generate_sparsemask
self.nnUnet_info = nnUnet_info
if nnUnet_info:
self.nnUnet_info = nnUnet_info
self.nnUnet_info["modalities"] = {"CT": "0000"}
self.nnUnet_info["index"] = 0

#input operations
self.input = ImageAutoInput(input_directory, modalities, n_jobs, visualize)
Expand Down Expand Up @@ -132,15 +134,22 @@ def process_one_subject(self, subject_id):
image = extractor.Execute(image)
print(image.GetSize())
image = self.resample(image)
#Saving the output
if self.nnUnet_info == {}:
self.output(subject_id, image, output_stream)
else:
self.output(subject_id, image, output_stream, nnUnet_info=self.nnUnet_info)

if hasattr(read_results[i], "metadata") and read_results[i].metadata is not None:
metadata.update(read_results[i].metadata)

#modality is MR and the user has selected to have nnUnet output
if self.nnUnet_info:
if modality == "MR":
self.nnUnet_info['current_modality'] = metadata["AcquisitionContrast"]
if not metadata["AcquisitionContrast"] in self.nnUnet_info["modalities"].keys():
self.nnUnet_info["modalities"][metadata["AcquisitionContrast"]] = str(len(self.nnUnet_info["modalities"])).zfill(4)
else:
self.nnUnet_info['current_modality'] = modality #CT
self.output(subject_id, image, output_stream, nnUnet_info=self.nnUnet_info)
else:
self.output(subject_id, image, output_stream)

metadata[f"size_{output_stream}"] = str(image.GetSize())


Expand Down Expand Up @@ -182,32 +191,31 @@ def process_one_subject(self, subject_id):
print(mask.GetSize())
mask_arr = np.transpose(sitk.GetArrayFromImage(mask))

# if self.generate_sparsemask:
# sparse_mask = mask.generate_sparse_mask()
# save_path = pathlib.Path(self.output_directory, subject_id, "sparse_mask", "sparse_mask.nii.gz").as_posix()
if self.nnUnet_info:
sparse_mask = mask.generate_sparse_mask().mask_array
sparse_mask = sitk.GetImageFromArray(sparse_mask)
# save_path = pathlib.Path(self.output_directory, subject_id, "sparse_mask", "sparse_mask.nii.gz").as_posix()
self.output(subject_id, sparse_mask, output_stream, nnUnet_info=self.nnUnet_info, nnUnet_is_label=True)
# sparse_mask_nifti = nib.Nifti1Image(sparse_mask.mask_array, affine=np.eye(4))
# nib.save(sparse_mask_nifti, save_path)
# self.output("sparse_mask", sparse_mask, output_stream, "sparse_mask")

else:
# if there is only one ROI, sitk.GetArrayFromImage() will return a 3d array instead of a 4d array with one slice
if len(mask_arr.shape) == 3:
mask_arr = mask_arr.reshape(1, mask_arr.shape[0], mask_arr.shape[1], mask_arr.shape[2])

print(mask_arr.shape)
roi_names_list = list(mask.roi_names.keys())
for i in range(mask_arr.shape[0]):
new_mask = sitk.GetImageFromArray(np.transpose(mask_arr[i]))
new_mask.CopyInformation(mask)
new_mask = Segmentation(new_mask)
mask_to_process = new_mask
if self.nnUnet_info == {}:
if len(mask_arr.shape) == 3:
mask_arr = mask_arr.reshape(1, mask_arr.shape[0], mask_arr.shape[1], mask_arr.shape[2])

print(mask_arr.shape)
roi_names_list = list(mask.roi_names.keys())
for i in range(mask_arr.shape[0]):
new_mask = sitk.GetImageFromArray(np.transpose(mask_arr[i]))
new_mask.CopyInformation(mask)
new_mask = Segmentation(new_mask)
mask_to_process = new_mask
if not mult_conn:
# self.output(roi_names_list[i], mask_to_process, output_stream)
self.output(subject_id, mask_to_process, output_stream, True, roi_names_list[i])
else:
self.output(f"{subject_id}_{num}", mask_to_process, output_stream, True, roi_names_list[i])
else:
self.output(subject_id, mask_to_process, output_stream, nnUnet_info=self.nnUnet_info, nnUnet_is_label=True)

if hasattr(structure_set, "metadata") and structure_set.metadata is not None:
metadata.update(structure_set.metadata)
Expand Down Expand Up @@ -286,9 +294,7 @@ def run(self):
visualize=False,
overwrite=True,
generate_sparsemask=True,
nnUnet_info={"study name": "NSCLC-Radiomics-Interobserver1",
"index": 0,
"modality": "0000"})
nnUnet_info={"study name": "NSCLC-Radiomics-Interobserver1"})

# pipeline = AutoPipeline(input_directory="C:/Users/qukev/BHKLAB/hnscc_testing/HNSCC",
# output_directory="C:/Users/qukev/BHKLAB/hnscc_testing_output",
Expand Down
32 changes: 20 additions & 12 deletions imgtools/io/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,28 +51,36 @@ def __init__(self, root_directory, filename_format="{subject_id}.nii.gz", create
self.filename_format = filename_format
self.create_dirs = create_dirs
self.compress = compress
if os.path.exists(self.root_directory)\
and os.path.basename(os.path.dirname(self.root_directory)) == "{subject_id}":
if os.path.exists(self.root_directory):
if os.path.basename(os.path.dirname(self.root_directory)) == "{subject_id}":
shutil.rmtree(os.path.dirname(self.root_directory))
elif "{label_or_image}" in os.path.basename(self.root_directory):
shutil.rmtree(self.root_directory)
#delete the folder called {subject_id} that was made in the original BaseWriter

shutil.rmtree(os.path.dirname(self.root_directory))

def put(self, subject_id, image, is_mask=False, nnUnet_info={}, nnUnet_is_label=False, mask_label="",**kwargs):
def put(self, subject_id, image, is_mask=False, nnUnet_info=None, nnUnet_is_label=False, mask_label="",**kwargs):
if is_mask:
self.filename_format = mask_label+".nii.gz"
if nnUnet_is_label and nnUnet_info != {}:
self.filename_format = f"{nnUnet_info['study name']}_{nnUnet_info['index']}.nii.gz"
if nnUnet_info:
if nnUnet_is_label:
self.filename_format = f"{nnUnet_info['study name']}_{nnUnet_info['index']}.nii.gz"
else:
# f"{nnUnet_info['study name']}_{nnUnet_info['index']}_{nnUnet_info['modalities'][nnUnet_info['current_modality']]}.nii.gz"
self.filename_format = self.filename_format.format(study_name=nnUnet_info['study name'], index=nnUnet_info['index'], modality_index=nnUnet_info['modalities'][nnUnet_info['current_modality']])
out_path = self._get_path_from_subject_id(subject_id, nnUnet_is_label=nnUnet_is_label, **kwargs)
sitk.WriteImage(image, out_path, self.compress)

def _get_path_from_subject_id(self, subject_id, nnUnet_is_label=False, **kwargs):
# out_filename = self.filename_format.format(subject_id=subject_id, **kwargs)
self.root_directory = self.root_directory.format(subject_id=subject_id,
**kwargs)
if nnUnet_is_label:
self.root_directory.format(label_or_image="labels")
else:
self.root_directory.format(label_or_image="images")
try:
self.root_directory = self.root_directory.format(subject_id=subject_id,
**kwargs)
except:
if nnUnet_is_label:
self.root_directory = self.root_directory.format(label_or_image="labels")
else:
self.root_directory = self.root_directory.format(label_or_image="images")
out_path = pathlib.Path(self.root_directory, self.filename_format).as_posix()
out_dir = os.path.dirname(out_path)
if self.create_dirs and not os.path.exists(out_dir):
Expand Down
5 changes: 3 additions & 2 deletions imgtools/ops/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def __init__(self,
filename_format=colname_process+"{}.nii.gz".format(extension))
else:
self.output[colname_process] = ImageSubjectFileOutput(pathlib.Path(root_directory,"{label_or_image}Tr").as_posix(),
filename_format=f"{nnUnet_info['study name']}_{nnUnet_info['index']}_{nnUnet_info['modality']}.nii.gz")
filename_format="{study_name}_{index}_{modality_index}.nii.gz")
# self.output[colname_process] = ImageFileOutput(os.path.join(root_directory,extension.split(".")[0]),
# filename_format="{subject_id}_"+"{}.nrrd".format(extension))

Expand All @@ -348,9 +348,10 @@ def __call__(self,
output_stream,
is_mask: bool = False,
nnUnet_is_label: bool=False,
nnUnet_info: Dict=None,
mask_label: Optional[str] = ""):

self.output[output_stream](subject_id, img, is_mask=is_mask, mask_label=mask_label, nnUnet_is_label=nnUnet_is_label)
self.output[output_stream](subject_id, img, is_mask=is_mask, mask_label=mask_label, nnUnet_is_label=nnUnet_is_label, nnUnet_info=nnUnet_info)

class NumpyOutput(BaseOutput):
"""NumpyOutput class processed images as NumPy files.
Expand Down

0 comments on commit 0de879b

Please sign in to comment.