Skip to content

Commit

Permalink
tried to save sparse mask, did some stuff with nnunet output format
Browse files Browse the repository at this point in the history
Former-commit-id: 6b54305
  • Loading branch information
fishingguy456 committed Jun 9, 2022
1 parent b432ee1 commit 8b3b2db
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 43 deletions.
45 changes: 33 additions & 12 deletions examples/autotest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def __init__(self,
show_progress=False,
warn_on_error=False,
overwrite=False,
generate_sparsemask=False):
generate_sparsemask=False,
nnUnet_info={}):

super().__init__(
n_jobs=n_jobs,
Expand All @@ -54,6 +55,7 @@ def __init__(self,
self.spacing = spacing
self.existing = [None] #self.existing_patients()
self.generate_sparsemask = generate_sparsemask
self.nnUnet_info = nnUnet_info

#input operations
self.input = ImageAutoInput(input_directory, modalities, n_jobs, visualize)
Expand All @@ -69,7 +71,7 @@ def __init__(self,
self.make_binary_mask = StructureSetToSegmentation(roi_names=[], continuous=False) # "GTV-.*"

# output ops
self.output = ImageAutoOutput(self.output_directory, self.output_streams)
self.output = ImageAutoOutput(self.output_directory, self.output_streams, self.nnUnet_info)

#Make a directory
if not os.path.exists(pathlib.Path(self.output_directory,".temp").as_posix()):
Expand Down Expand Up @@ -103,6 +105,7 @@ def process_one_subject(self, subject_id):
metadata = {}
subject_modalities = set()
num_rtstructs = 0
self.nnUnet_info["index"] += 1
for i, colname in enumerate(self.output_streams):
modality = colname.split("_")[0]
subject_modalities.add(modality)
Expand Down Expand Up @@ -130,7 +133,10 @@ def process_one_subject(self, subject_id):
print(image.GetSize())
image = self.resample(image)
#Saving the output
self.output(subject_id, image, output_stream)
if self.nnUnet_info == {}:
self.output(subject_id, image, output_stream)
else:
self.output(subject_id, image, output_stream, nnUnet_info=self.nnUnet_info)

if hasattr(read_results[i], "metadata") and read_results[i].metadata is not None:
metadata.update(read_results[i].metadata)
Expand Down Expand Up @@ -176,9 +182,9 @@ def process_one_subject(self, subject_id):
print(mask.GetSize())
mask_arr = np.transpose(sitk.GetArrayFromImage(mask))

if self.generate_sparsemask:
sparse_mask = mask.generate_sparse_mask()
save_path = pathlib.Path(self.output_directory, subject_id, "sparse_mask", "sparse_mask.nii.gz").as_posix()
# if self.generate_sparsemask:
# sparse_mask = mask.generate_sparse_mask()
# save_path = pathlib.Path(self.output_directory, subject_id, "sparse_mask", "sparse_mask.nii.gz").as_posix()
# sparse_mask_nifti = nib.Nifti1Image(sparse_mask.mask_array, affine=np.eye(4))
# nib.save(sparse_mask_nifti, save_path)
# self.output("sparse_mask", sparse_mask, output_stream, "sparse_mask")
Expand All @@ -194,11 +200,14 @@ def process_one_subject(self, subject_id):
new_mask.CopyInformation(mask)
new_mask = Segmentation(new_mask)
mask_to_process = new_mask
if not mult_conn:
# self.output(roi_names_list[i], mask_to_process, output_stream)
self.output(subject_id, mask_to_process, output_stream, True, roi_names_list[i])
if self.nnUnet_info == {}:
if not mult_conn:
# self.output(roi_names_list[i], mask_to_process, output_stream)
self.output(subject_id, mask_to_process, output_stream, True, roi_names_list[i])
else:
self.output(f"{subject_id}_{num}", mask_to_process, output_stream, True, roi_names_list[i])
else:
self.output(f"{subject_id}_{num}", mask_to_process, output_stream, True, roi_names_list[i])
self.output(subject_id, mask_to_process, output_stream, nnUnet_info=self.nnUnet_info, nnUnet_is_label=True)

if hasattr(structure_set, "metadata") and structure_set.metadata is not None:
metadata.update(structure_set.metadata)
Expand Down Expand Up @@ -276,13 +285,25 @@ def run(self):
modalities="CT,RTSTRUCT",
visualize=False,
overwrite=True,
generate_sparsemask=True)
generate_sparsemask=True,
nnUnet_info={"study name": "NSCLC-Radiomics-Interobserver1",
"index": 0,
"modality": "0000"})

# pipeline = AutoPipeline(input_directory="C:/Users/qukev/BHKLAB/hnscc_testing/HNSCC",
# output_directory="C:/Users/qukev/BHKLAB/hnscc_testing_output",
# modalities="CT,RTSTRUCT",
# visualize=False,
# overwrite=True)
# overwrite=True,
# generate_sparsemask=True)

# pipeline = AutoPipeline(input_directory="C:/Users/qukev/BHKLAB/dataset/manifest-1598890146597/NSCLC-Radiomics-Interobserver1",
# output_directory="C:/Users/qukev/BHKLAB/autopipelineoutput",
# modalities="CT,RTSTRUCT",
# visualize=False,
# overwrite=True,
# generate_sparsemask=True)

# pipeline = AutoPipeline(input_directory="C:/Users/qukev/BHKLAB/hnscc_pet/PET",
# output_directory="C:/Users/qukev/BHKLAB/hnscc_pet_output",
# modalities="CT,PT,RTDOSE",
Expand Down
12 changes: 9 additions & 3 deletions imgtools/io/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,22 @@ def __init__(self, root_directory, filename_format="{subject_id}.nii.gz", create

shutil.rmtree(os.path.dirname(self.root_directory))

def put(self, subject_id, image, is_mask=False, mask_label="",**kwargs):
def put(self, subject_id, image, is_mask=False, nnUnet_info={}, nnUnet_is_label=False, mask_label="",**kwargs):
if is_mask:
self.filename_format = mask_label+".nii.gz"
out_path = self._get_path_from_subject_id(subject_id, **kwargs)
if nnUnet_is_label and nnUnet_info != {}:
self.filename_format = f"{nnUnet_info['study name']}_{nnUnet_info['index']}.nii.gz"
out_path = self._get_path_from_subject_id(subject_id, nnUnet_is_label=nnUnet_is_label, **kwargs)
sitk.WriteImage(image, out_path, self.compress)

def _get_path_from_subject_id(self, subject_id, **kwargs):
def _get_path_from_subject_id(self, subject_id, nnUnet_is_label=False, **kwargs):
# out_filename = self.filename_format.format(subject_id=subject_id, **kwargs)
self.root_directory = self.root_directory.format(subject_id=subject_id,
**kwargs)
if nnUnet_is_label:
self.root_directory.format(label_or_image="labels")
else:
self.root_directory.format(label_or_image="images")
out_path = pathlib.Path(self.root_directory, self.filename_format).as_posix()
out_dir = os.path.dirname(out_path)
if self.create_dirs and not os.path.exists(out_dir):
Expand Down
21 changes: 13 additions & 8 deletions imgtools/modules/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,19 @@ def generate_sparse_mask(self) -> SparseMask:
sparsemask_arr = np.zeros(mask_arr.shape[1:])

# voxels_with_overlap = {}
for i in range(mask_arr.shape[0]):
slice = mask_arr[i, :, :, :]
slice *= list(roi_names.values())[i] # everything is 0 or 1, so this is fine to convert filled voxels to label indices
# res = self._max_adder(sparsemask_arr, slice)
# sparsemask_arr = res[0]
# for e in res[1]:
# voxels_with_overlap.add(e)
sparsemask_arr = np.fmax(sparsemask_arr, slice) # elementwise maximum
if len(mask_arr.shape) == 4:
for i in range(mask_arr.shape[0]):
slice = mask_arr[i, :, :, :]
slice *= list(roi_names.values())[i] # everything is 0 or 1, so this is fine to convert filled voxels to label indices
# res = self._max_adder(sparsemask_arr, slice)
# sparsemask_arr = res[0]
# for e in res[1]:
# voxels_with_overlap.add(e)
sparsemask_arr = np.fmax(sparsemask_arr, slice) # elementwise maximum
print("A")
else:
sparsemask_arr = mask_arr
print("B")

sparsemask = SparseMask(sparsemask_arr, roi_names)
# if len(voxels_with_overlap) != 0:
Expand Down
34 changes: 17 additions & 17 deletions imgtools/modules/structureset.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,27 +190,27 @@ def to_segmentation(self, reference_image: sitk.Image,
for contour in mask_points:
z, slice_points = np.unique(contour[:, 0]), contour[:, 1:]
# rounding errors for points on the boundary
if z == mask.shape[0]:
z -= 1
elif z == -1:
z += 1
elif z > mask.shape[0] or z < -1:
raise IndexError(f"{z} index is out of bounds for image sized {mask.shape}.")
# if z == mask.shape[0]:
# z -= 1
# elif z == -1:
# z += 1
# elif z > mask.shape[0] or z < -1:
# raise IndexError(f"{z} index is out of bounds for image sized {mask.shape}.")

# if the contour spans only 1 z-slice
if len(z) == 1:
z = int(np.floor(z[0]))
slice_mask = polygon2mask(size[1:-1], slice_points)
mask[z, :, :, label] += slice_mask
else:
raise ValueError("This contour is corrupted and spans across 2 or more slices.")

# This is the old version of z index parsing. Kept for backup
# # if the contour spans only 1 z-slice
# if len(z) == 1:
# # assert len(z) == 1, f"This contour ({name}) spreads across more than 1 slice."
# z = z[0]
# z = int(np.floor(z[0]))
# slice_mask = polygon2mask(size[1:-1], slice_points)
# mask[z, :, :, label] += slice_mask
# else:
# raise ValueError("This contour is corrupted and spans across 2 or more slices.")

# This is the old version of z index parsing. Kept for backup
if len(z) == 1:
# assert len(z) == 1, f"This contour ({name}) spreads across more than 1 slice."
z = z[0]
slice_mask = polygon2mask(size[1:-1], slice_points)
mask[z, :, :, label] += slice_mask


mask[mask > 1] = 1
Expand Down
12 changes: 9 additions & 3 deletions imgtools/ops/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,8 @@ class ImageAutoOutput:
"""
def __init__(self,
root_directory: str,
output_streams: List[str]):
output_streams: List[str],
nnUnet_info: Dict = None):

# File types
self.file_name = file_name_convention()
Expand All @@ -332,8 +333,12 @@ def __init__(self,
# Not considering colnames ending with alphanumeric
colname_process = ("_").join([item for item in colname.split("_") if item.isnumeric()==False])
extension = self.file_name[colname_process]
self.output[colname_process] = ImageSubjectFileOutput(pathlib.Path(root_directory,"{subject_id}",extension.split(".")[0]).as_posix(),
if not nnUnet_info:
self.output[colname_process] = ImageSubjectFileOutput(pathlib.Path(root_directory,"{subject_id}",extension.split(".")[0]).as_posix(),
filename_format=colname_process+"{}.nii.gz".format(extension))
else:
self.output[colname_process] = ImageSubjectFileOutput(pathlib.Path(root_directory,"{label_or_image}Tr").as_posix(),
filename_format=f"{nnUnet_info['study name']}_{nnUnet_info['index']}_{nnUnet_info['modality']}.nii.gz")
# self.output[colname_process] = ImageFileOutput(os.path.join(root_directory,extension.split(".")[0]),
# filename_format="{subject_id}_"+"{}.nrrd".format(extension))

Expand All @@ -342,9 +347,10 @@ def __call__(self,
img: sitk.Image,
output_stream,
is_mask: bool = False,
nnUnet_is_label: bool=False,
mask_label: Optional[str] = ""):

self.output[output_stream](subject_id, img, is_mask=is_mask, mask_label=mask_label)
self.output[output_stream](subject_id, img, is_mask=is_mask, mask_label=mask_label, nnUnet_is_label=nnUnet_is_label)

class NumpyOutput(BaseOutput):
"""NumpyOutput class processed images as NumPy files.
Expand Down

0 comments on commit 8b3b2db

Please sign in to comment.