Skip to content

Commit

Permalink
sparse mask global labelling for contour name: index
Browse files Browse the repository at this point in the history
  • Loading branch information
fishingguy456 committed Jun 15, 2022
1 parent eeae5b7 commit d65c6d8
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 19 deletions.
15 changes: 11 additions & 4 deletions examples/autotest.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ def __init__(self,
#Make a directory
if not os.path.exists(pathlib.Path(self.output_directory,".temp").as_posix()):
os.mkdir(pathlib.Path(self.output_directory,".temp").as_posix())

self.existing_roi_names = {}


def process_one_subject(self, subject_id):
Expand Down Expand Up @@ -244,18 +246,23 @@ def process_one_subject(self, subject_id):

# make_binary_mask relative to ct/pet
if conn_to == "CT" or conn_to == "MR":
mask = self.make_binary_mask(structure_set, image)
mask = self.make_binary_mask(structure_set, image, self.existing_roi_names)
elif conn_to == "PT":
mask = self.make_binary_mask(structure_set, pet)
mask = self.make_binary_mask(structure_set, pet, self.existing_roi_names)
else:
raise ValueError("You need to pass a reference CT or PT/PET image to map contours to.")

for name in mask.roi_names.keys():
if name not in self.existing_roi_names.keys():
self.existing_roi_names[name] = len(self.existing_roi_names)
mask.existing_roi_names = self.existing_roi_names

# save output
print(mask.GetSize())
mask_arr = np.transpose(sitk.GetArrayFromImage(mask))

if self.is_nnunet:
sparse_mask = mask.generate_sparse_mask(self.label_names).mask_array
sparse_mask = mask.generate_sparse_mask().mask_array
sparse_mask = sitk.GetImageFromArray(sparse_mask) #convert the nparray to sitk image
if "_".join(subject_id.split("_")[1::]) in self.train:
self.output(subject_id, sparse_mask, output_stream, nnunet_info=self.nnunet_info, label_or_image="labels") #rtstruct is label for nnunet
Expand Down Expand Up @@ -383,7 +390,7 @@ def run(self):
modalities="CT,RTSTRUCT",
visualize=False,
overwrite=True,
# is_nnunet=True,
is_nnunet=True,
train_size=0.5,
label_names={"GTV":"GTV.*", "Brainstem": "Brainstem.*"})

Expand Down
18 changes: 9 additions & 9 deletions imgtools/modules/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def map_over_labels(segmentation, f, include_background=False, return_segmentati


class Segmentation(sitk.Image):
def __init__(self, segmentation, roi_names=None):
def __init__(self, segmentation, roi_names=None, existing_roi_names=None):
super().__init__(segmentation)
self.num_labels = self.GetNumberOfComponentsPerPixel()
if not roi_names:
Expand All @@ -48,6 +48,7 @@ def __init__(self, segmentation, roi_names=None):
for i in range(1, self.num_labels+1):
if i not in self.roi_names.values():
self.roi_names[f"label_{i}"] = i
self.existing_roi_names = existing_roi_names

def get_label(self, label=None, name=None, relabel=False):
if label is None and name is None:
Expand Down Expand Up @@ -86,7 +87,7 @@ def __getitem__(self, idx):
def __repr__(self):
return f"<Segmentation with ROIs: {self.roi_names!r}>"

def generate_sparse_mask(self, label_names, verbose=False) -> SparseMask:
def generate_sparse_mask(self, verbose=False) -> SparseMask:
"""
Generate a sparse mask from the contours, taking the argmax of all overlaps
Expand All @@ -100,12 +101,11 @@ def generate_sparse_mask(self, label_names, verbose=False) -> SparseMask:
SparseMask
The sparse mask object.
"""
# print("asdlkfjalkfsjg", self.roi_names)
mask_arr = np.transpose(sitk.GetArrayFromImage(self))
if list(self.roi_names.values())[0] == 0:
roi_names = {k: v+1 for k, v in self.roi_names.items()}
else:
roi_names = self.roi_names
print(roi_names)
for name in self.roi_names.keys():
self.roi_names[name] = self.existing_roi_names[name]
# print(self.roi_names)

sparsemask_arr = np.zeros(mask_arr.shape[1:])

Expand All @@ -115,7 +115,7 @@ def generate_sparse_mask(self, label_names, verbose=False) -> SparseMask:
if len(mask_arr.shape) == 4:
for i in range(mask_arr.shape[0]):
slice = mask_arr[i, :, :, :]
slice *= list(roi_names.values())[i] # everything is 0 or 1, so this is fine to convert filled voxels to label indices
slice *= list(self.roi_names.values())[i] # everything is 0 or 1, so this is fine to convert filled voxels to label indices
if verbose:
res = self._max_adder(sparsemask_arr, slice)
sparsemask_arr = res[0]
Expand All @@ -126,7 +126,7 @@ def generate_sparse_mask(self, label_names, verbose=False) -> SparseMask:
else:
sparsemask_arr = mask_arr

sparsemask = SparseMask(sparsemask_arr, roi_names)
sparsemask = SparseMask(sparsemask_arr, self.roi_names)

if verbose:
if len(voxels_with_overlap) != 0:
Expand Down
8 changes: 4 additions & 4 deletions imgtools/modules/structureset.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def _assign_labels(self, names, force_missing=False):
def to_segmentation(self, reference_image: sitk.Image,
roi_names: Dict[str, str] = None,
force_missing: bool = False,
continuous: bool = True) -> Segmentation:
continuous: bool = True,
existing_roi_names: Dict[str, int] = None) -> Segmentation:
"""Convert the structure set to a Segmentation object.
Parameters
Expand Down Expand Up @@ -154,7 +155,7 @@ def to_segmentation(self, reference_image: sitk.Image,
# print(self.roi_points)

seg_roi_names = {}
print(roi_names)
# print(roi_names)
if roi_names != {} and isinstance(roi_names, dict):
for i, (name, label_list) in enumerate(labels.items()):
for label in label_list:
Expand Down Expand Up @@ -209,8 +210,7 @@ def to_segmentation(self, reference_image: sitk.Image,
mask[mask > 1] = 1
mask = sitk.GetImageFromArray(mask, isVector=True)
mask.CopyInformation(reference_image)
print("adams",seg_roi_names)
mask = Segmentation(mask, roi_names=seg_roi_names)
mask = Segmentation(mask, roi_names=seg_roi_names, existing_roi_names=existing_roi_names) #in the segmentation, pass all the existing roi names and then process is in the segmentation class

return mask

Expand Down
5 changes: 3 additions & 2 deletions imgtools/ops/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1469,7 +1469,7 @@ def __init__(self,
self.force_missing = force_missing
self.continuous = continuous

def __call__(self, structure_set: StructureSet, reference_image: sitk.Image) -> Segmentation:
def __call__(self, structure_set: StructureSet, reference_image: sitk.Image, existing_roi_names: Dict[str, int]) -> Segmentation:
"""Convert the structure set to a Segmentation object.
Parameters
Expand All @@ -1487,7 +1487,8 @@ def __call__(self, structure_set: StructureSet, reference_image: sitk.Image) ->
return structure_set.to_segmentation(reference_image,
roi_names=self.roi_names,
force_missing=self.force_missing,
continuous=self.continuous)
continuous=self.continuous,
existing_roi_names=existing_roi_names)

class MapOverLabels(BaseOp):
"""MapOverLabels operation class:
Expand Down

0 comments on commit d65c6d8

Please sign in to comment.