Skip to content

Commit

Permalink
Refactor watershed mask creation using new API (#1694)
Browse files Browse the repository at this point in the history
Remove starfish.core.image.Filter.util.bin_open as there are no more users of this method.

Depends on #1692, #1693, #1684, #1688
Test plan: ISS notebook still yields 96 cells.
  • Loading branch information
Tony Tung authored Jan 14, 2020
1 parent a018986 commit e666d82
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 73 deletions.
21 changes: 0 additions & 21 deletions starfish/core/image/Filter/util.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,10 @@
from typing import Set, Tuple, Union

import numpy as np
from skimage.morphology import binary_opening, disk

from starfish.core.types import Axes, Number


def bin_open(img: np.ndarray, disk_size: int) -> np.ndarray:
"""
Performs binary opening of an image
img : np.ndarray
Image to filter.
masking_radius : int
Radius of the disk-shaped structuring element.
Returns
-------
np.ndarray :
Filtered image, same shape as input
"""
selem = disk(disk_size)
res = binary_opening(img, selem)
return res


def gaussian_kernel(shape: Tuple[int, int]=(3, 3), sigma: float=0.5):
"""
Returns a gaussian kernel of specified shape and standard deviation.
Expand Down
128 changes: 76 additions & 52 deletions starfish/core/image/Segment/watershed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@
from showit import image
from skimage.morphology import disk, watershed

from starfish.core.image.Filter.util import bin_open
from starfish.core.imagestack.imagestack import ImageStack
from starfish.core.morphology import Filter
from starfish.core.morphology import Filter, Merge
from starfish.core.morphology.Binarize import ThresholdBinarize
from starfish.core.morphology.binary_mask import BinaryMaskCollection
from starfish.core.morphology.Filter.areafilter import AreaFilter
from starfish.core.morphology.Filter.min_distance_label import MinDistanceLabel
from starfish.core.morphology.Filter.structural_label import StructuralLabel
from starfish.core.morphology.label_image import LabelImage
from starfish.core.types import ArrayLike, Axes, Coordinates, FunctionSource, Levels, Number
from ._base import SegmentAlgorithm

Expand Down Expand Up @@ -83,24 +81,11 @@ def run(
disk_size_mask = None

self._segmentation_instance = _WatershedSegmenter(primary_images, nuclei)
label_image_array = self._segmentation_instance.segment(
return self._segmentation_instance.segment(
self.nuclei_threshold, self.input_threshold, size_lim, disk_size_markers,
disk_size_mask, self.min_distance
)

# we max-projected and squeezed the Z-plane so label_image.ndim == 2
physical_ticks: Mapping[Coordinates, ArrayLike[Number]] = {
coord: nuclei.xarray.coords[coord.value].data
for coord in (Coordinates.Y, Coordinates.X)
}

return BinaryMaskCollection.from_label_array_and_ticks(
label_image_array,
None,
physical_ticks,
None, # TODO: (ttung) this should really be logged.
)

def show(self, figsize: Tuple[int, int]=(10, 10)) -> None:
if isinstance(self._segmentation_instance, _WatershedSegmenter):
self._segmentation_instance.show(figsize=figsize)
Expand Down Expand Up @@ -135,11 +120,10 @@ def __init__(self, primary_images: ImageStack, nuclei: ImageStack) -> None:
level_method=Levels.SCALE_BY_IMAGE,
)

self.nuclei_thresholded: Optional[np.ndarray] = None # dtype: bool
self.markers: Optional[LabelImage] = None
self.markers: Optional[BinaryMaskCollection] = None
self.num_cells: Optional[int] = None
self.mask = None
self.segmented = None
self.mask: Optional[BinaryMaskCollection] = None
self.segmented: Optional[BinaryMaskCollection] = None

def segment(
self,
Expand All @@ -149,7 +133,7 @@ def segment(
disk_size_markers: Optional[int]=None, # TODO ambrosejcarr what is this doing?
disk_size_mask: Optional[int]=None, # TODO ambrosejcarr what is this doing?
min_dist: Optional[int] = None
) -> np.ndarray:
) -> BinaryMaskCollection:
"""Execute watershed cell segmentation.
Parameters
Expand All @@ -169,8 +153,8 @@ def segment(
Returns
-------
np.ndarray[int32] :
label image with same size and shape as self.nuclei_img
BinaryMaskCollection :
binary mask collection where each cell is a mask.
"""
min_allowed_size, max_allowed_size = size_lim
self.binarized_nuclei = self.filter_nuclei(nuclei_thresh, disk_size_markers)
Expand All @@ -181,9 +165,8 @@ def segment(
labeled_masks = MinDistanceLabel(min_dist, 1).run(self.binarized_nuclei)

area_filter = AreaFilter(min_area=min_allowed_size, max_area=max_allowed_size)
filtered_masks = area_filter.run(labeled_masks)
self.num_cells = len(filtered_masks)
self.markers = filtered_masks.to_label_image()
self.markers = area_filter.run(labeled_masks)
self.num_cells = len(self.markers)
self.mask = self.watershed_mask(stain_thresh, self.markers, disk_size_mask)
self.segmented = self.watershed(self.markers, self.mask)
return self.segmented
Expand Down Expand Up @@ -220,58 +203,89 @@ def filter_nuclei(self, nuclei_thresh: float, disk_size: Optional[int]) -> Binar
def watershed_mask(
self,
stain_thresh: Number,
markers: LabelImage,
markers: BinaryMaskCollection,
disk_size: Optional[int],
) -> np.ndarray:
) -> BinaryMaskCollection:
"""Create a watershed mask that is the union of the spot intensities above stain_thresh and
a marker image generated from nuclei
Parameters
----------
stain_thresh : Number
threshold to apply to the stain image
markers : LabelImage
markers : BinaryMaskCollection
markers image generated from nuclei
disk_size : Optional[int]
if provided, execute a morphological opening operation over the thresholded stain image
Returns
-------
np.ndarray[bool] :
thresholded stain image
BinaryMaskCollection :
watershed mask
"""
st = self.stain._squeezed_numpy(Axes.ROUND, Axes.CH, Axes.ZPLANE) >= stain_thresh
markers_any = (markers.xarray > 0).values.squeeze(axis=0)
watershed_mask: np.ndarray = np.logical_or(st, markers_any) # dtype bool
thresholded_stain = ThresholdBinarize(stain_thresh).run(self.stain)
markers_and_stain = Merge.SimpleMerge().run([thresholded_stain, markers])
watershed_mask = Filter.Reduce(
"logical_or",
lambda shape: np.zeros(shape=shape, dtype=np.bool)
).run(markers_and_stain)
if disk_size is not None:
watershed_mask = bin_open(watershed_mask, disk_size)
disk_img = disk(disk_size)
watershed_mask = Filter.Map(
"morphology.binary_open",
disk_img,
module=FunctionSource.skimage
).run(watershed_mask)

return watershed_mask

def watershed(self, markers: LabelImage, watershed_mask: np.ndarray) -> np.ndarray:
def watershed(
self,
markers: BinaryMaskCollection,
watershed_mask: BinaryMaskCollection,
) -> BinaryMaskCollection:
"""Run watershed on the thresholded primary_images max projection
Parameters
----------
markers : LabelImage
markers : BinaryMaskCollection
markers image generated from nuclei
watershed_mask : np.ndarray[bool]
watershed_mask : BinaryMaskCollection
Mask array. only points at which mask == True will be labeled in the output.
Returns
-------
np.ndarray[np.int32] :
labeled image, each segment has a unique integer value
BinaryMaskCollection :
binary mask collection where each cell is a mask.
"""
assert len(watershed_mask) == 1

img = 1 - self.stain._squeezed_numpy(Axes.ROUND, Axes.CH, Axes.ZPLANE)
markers_label_array = markers.to_label_image()

res = watershed(
image=img,
markers=markers_label_array.xarray.values.squeeze(axis=0),
connectivity=np.ones((3, 3), bool),
mask=watershed_mask.uncropped_mask(0).squeeze(axis=0),
)

res = watershed(image=img,
markers=markers.xarray.values.squeeze(axis=0),
connectivity=np.ones((3, 3), bool),
mask=watershed_mask
)
# we max-projected and squeezed the Z-plane so label_image.ndim == 2
pixel_ticks: Mapping[Axes, ArrayLike[int]] = {
axis: markers._pixel_ticks[axis]
for axis in (Axes.Y, Axes.X)
}
physical_ticks: Mapping[Coordinates, ArrayLike[Number]] = {
coord: markers._physical_ticks[coord]
for coord in (Coordinates.Y, Coordinates.X)
}

return res
return BinaryMaskCollection.from_label_array_and_ticks(
res,
pixel_ticks,
physical_ticks,
None, # TODO: (ttung) this should really be logged.
)

def show(self, figsize=(10, 10)):
import matplotlib.pyplot as plt
Expand All @@ -297,19 +311,29 @@ def show(self, figsize=(10, 10)):
plt.title('Nuclei Thresholded')

plt.subplot(324)
image(self.mask, bar=False, ax=plt.gca())
image(
self.mask.to_label_image().xarray.squeeze(Axes.ZPLANE.value).values,
bar=False,
ax=plt.gca(),
)
plt.title('Watershed Mask')

plt.subplot(325)
image(
self.markers.xarray.squeeze(Axes.ZPLANE.value).values,
self.markers.to_label_image().xarray.squeeze(Axes.ZPLANE.value).values,
size=20,
cmap=plt.cm.nipy_spectral,
ax=plt.gca())
ax=plt.gca(),
)
plt.title('Found: {} cells'.format(self.num_cells))

plt.subplot(326)
image(self.segmented, size=20, cmap=plt.cm.nipy_spectral, ax=plt.gca())
image(
self.segmented.to_label_image().xarray.values,
size=20,
cmap=plt.cm.nipy_spectral,
ax=plt.gca(),
)
plt.title('Segmented Cells')

return plt.gca()

0 comments on commit e666d82

Please sign in to comment.