Skip to content

Commit

Permalink
testing out sparse mask
Browse files Browse the repository at this point in the history
  • Loading branch information
fishingguy456 committed May 30, 2022
1 parent 8dab6bc commit 19b0ad5
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 62 deletions.
7 changes: 6 additions & 1 deletion examples/autotest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import glob
import pickle
import numpy as np
import sys

from argparse import ArgumentParser
import SimpleITK as sitk
Expand All @@ -11,6 +12,7 @@
from imgtools.pipeline import Pipeline
from joblib import Parallel, delayed
from imgtools.modules import Segmentation
from torch import sparse_coo_tensor

###############################################################
# Example usage:
Expand Down Expand Up @@ -154,7 +156,10 @@ def process_one_subject(self, subject_id):
# save output
print(mask.GetSize())
mask_arr = np.transpose(sitk.GetArrayFromImage(mask))

sparse_mask = mask.generate_sparse_mask()
# np.set_printoptions(threshold=sys.maxsize)
# print(sparse_mask.mask_array.shape)
# print(sparse_mask.mask_array[350:360,290:300,93])
# if there is only one ROI, sitk.GetArrayFromImage() will return a 3d array instead of a 4d array with one slice
if len(mask_arr.shape) == 3:
mask_arr = mask_arr.reshape(1, mask_arr.shape[0], mask_arr.shape[1], mask_arr.shape[2])
Expand Down
67 changes: 67 additions & 0 deletions imgtools/modules/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import numpy as np
import SimpleITK as sitk

from .sparsemask import SparseMask

from ..utils import array_to_image, image_to_array
from typing import Dict, List, Optional, Union, Tuple, Set


def accepts_segmentations(f):
Expand Down Expand Up @@ -81,3 +84,67 @@ def __getitem__(self, idx):

def __repr__(self):
return f"<Segmentation with ROIs: {self.roi_names!r}>"

def generate_sparse_mask(self) -> SparseMask:
"""
Generate a sparse mask from the contours, taking the argmax of all overlaps
Parameters
----------
mask
Segmentation object to build sparse mask from
Returns
-------
SparseMask
The sparse mask object.
"""
mask_arr = np.transpose(sitk.GetArrayFromImage(self))
if list(self.roi_names.values())[0] == 0:
roi_names = {k: v+1 for k, v in self.roi_names.items()}
else:
roi_names = self.roi_names
print(roi_names)

sparsemask_arr = np.zeros(mask_arr.shape[1:])

# voxels_with_overlap = {}
for i in range(mask_arr.shape[0]):
slice = mask_arr[i, :, :, :]
slice *= list(roi_names.values())[i] # everything is 0 or 1, so this is fine to convert filled voxels to label indices
# res = self._max_adder(sparsemask_arr, slice)
# sparsemask_arr = res[0]
# for e in res[1]:
# voxels_with_overlap.add(e)
sparsemask_arr = np.fmax(sparsemask_arr, slice) # elementwise maximum

sparsemask = SparseMask(sparsemask_arr, roi_names)
# if len(voxels_with_overlap) != 0:
# raise Warning(f"{len(voxels_with_overlap)} voxels have overlapping contours.")
return sparsemask

def _max_adder(self, arr_1: np.ndarray, arr_2: np.ndarray) -> Tuple[np.ndarray, Set[Tuple[int, int, int]]]:
"""
Takes the maximum of two 3D arrays elementwise and returns the resulting array and a list of voxels that have overlapping contours in a set
Parameters
----------
arr_1
First array to take maximum of
arr_2
Second array to take maximum of
Returns
-------
Tuple[np.ndarray, Set[Tuple[int, int, int]]]
The resulting array and a list of voxels that have overlapping contours in a set
"""
res = np.zeros(arr_1.shape)
overlaps = {} #set of tuples of the coords that have overlap
for i in range(arr_1.shape[0]):
for j in range(arr_1.shape[1]):
for k in range(arr_1.shape[2]):
if arr_1[i, j, k] != 0 and arr_2[i, j, k] != 0:
overlaps.add((i, j, k))
res[i, j, k] = max(arr_1[i, j, k], arr_2[i, j, k])
return res, overlaps
61 changes: 0 additions & 61 deletions imgtools/modules/structureset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from skimage.draw import polygon2mask

from .segmentation import Segmentation
from .sparsemask import SparseMask
from ..utils import physical_points_to_idxs


Expand Down Expand Up @@ -172,66 +171,6 @@ def to_segmentation(self, reference_image: sitk.Image,
mask = Segmentation(mask, roi_names=seg_roi_names)

return mask

def generate_sparse_mask(self, mask: Segmentation) -> SparseMask:
"""
Generate a sparse mask from the contours, taking the argmax of all overlaps
Parameters
----------
mask
Segmentation object to build sparse mask from
Returns
-------
SparseMask
The sparse mask object.
"""
mask_arr = np.transpose(sitk.GetArrayFromImage(mask))
roi_names = {k: v+1 for k, v in mask.roi_names.items()}

sparsemask_arr = np.zeros(mask_arr.shape[1:])

# voxels_with_overlap = {}
for i in len(mask_arr.shape[0]):
slice = mask_arr[i, :, :, :]
slice *= list(roi_names.values())[i] # everything is 0 or 1, so this is fine to convert filled voxels to label indices
# res = self._max_adder(sparsemask_arr, slice)
# sparsemask_arr = res[0]
# for e in res[1]:
# voxels_with_overlap.add(e)
sparsemask_arr = np.fmax(sparsemask_arr, slice) # elementwise maximum

sparsemask = SparseMask(sparsemask_arr, roi_names)
# if len(voxels_with_overlap) != 0:
# raise Warning(f"{len(voxels_with_overlap)} voxels have overlapping contours.")
return sparsemask

def _max_adder(self, arr_1: np.ndarray, arr_2: np.ndarray) -> Tuple[np.ndarray, Set[Tuple[int, int, int]]]:
"""
Takes the maximum of two 3D arrays elementwise and returns the resulting array and a list of voxels that have overlapping contours in a set
Parameters
----------
arr_1
First array to take maximum of
arr_2
Second array to take maximum of
Returns
-------
Tuple[np.ndarray, Set[Tuple[int, int, int]]]
The resulting array and a list of voxels that have overlapping contours in a set
"""
res = np.zeros(arr_1.shape)
overlaps = {} #set of tuples of the coords that have overlap
for i in range(arr_1.shape[0]):
for j in range(arr_1.shape[1]):
for k in range(arr_1.shape[2]):
if arr_1[i, j, k] != 0 and arr_2[i, j, k] != 0:
overlaps.add((i, j, k))
res[i, j, k] = max(arr_1[i, j, k], arr_2[i, j, k])
return res, overlaps

def __repr__(self):
return f"<StructureSet with ROIs: {self.roi_names!r}>"

0 comments on commit 19b0ad5

Please sign in to comment.