Skip to content

Commit

Permalink
refactoring localMaxPeakFinder
Browse files Browse the repository at this point in the history
  • Loading branch information
Shannon Axelrod committed Oct 11, 2019
1 parent e116567 commit c25bbee
Show file tree
Hide file tree
Showing 4 changed files with 321 additions and 10 deletions.
17 changes: 12 additions & 5 deletions notebooks/osmFISH.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -161,15 +161,22 @@
"metadata": {},
"outputs": [],
"source": [
"from starfish.spots import DetectSpots\n",
"from starfish.spots import DecodeSpots, FindSpots\n",
"from starfish.types import TraceBuildingStrategies\n",
"\n",
"lmp = DetectSpots.LocalMaxPeakFinder(\n",
"\n",
"lmp = FindSpots.LocalMaxPeakFinder(\n",
" min_distance=6,\n",
" stringency=0,\n",
" min_obj_area=6,\n",
" max_obj_area=600,\n",
" is_volume=True\n",
")\n",
"spot_intensities = lmp.run(mp)"
"spots = lmp.run(mp)\n",
"\n",
"decoder = DecodeSpots.PerRoundMaxChannel(codebook=experiment.codebook,\n",
" trace_building_strategy=TraceBuildingStrategies.SEQUENTIAL)\n",
"decoded_intensities = decoder.run(spots=spots)"
]
},
{
Expand Down Expand Up @@ -240,11 +247,11 @@
"outputs": [],
"source": [
"benchmark_spot_count = len(benchmark_peaks)\n",
"starfish_spot_count = len(spot_intensities)\n",
"starfish_spot_count = len(decoded_intensities)\n",
"\n",
"plt.figure(figsize=(10,10))\n",
"plt.plot(benchmark_peaks.x, -benchmark_peaks.y, \"o\")\n",
"plt.plot(spot_intensities[Axes.X.value], -spot_intensities[Axes.Y.value], \"x\")\n",
"plt.plot(decoded_intensities[Axes.X.value], -decoded_intensities[Axes.Y.value], \"x\")\n",
"\n",
"plt.legend([\"Benchmark: {} spots\".format(benchmark_spot_count),\n",
" \"Starfish: {} spots\".format(starfish_spot_count)])\n",
Expand Down
17 changes: 12 additions & 5 deletions notebooks/py/osmFISH.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,22 @@
# EPY: END markdown

# EPY: START code
from starfish.spots import DetectSpots
from starfish.spots import DecodeSpots, FindSpots
from starfish.types import TraceBuildingStrategies

lmp = DetectSpots.LocalMaxPeakFinder(

lmp = FindSpots.LocalMaxPeakFinder(
min_distance=6,
stringency=0,
min_obj_area=6,
max_obj_area=600,
is_volume=True
)
spot_intensities = lmp.run(mp)
spots = lmp.run(mp)

decoder = DecodeSpots.PerRoundMaxChannel(codebook=experiment.codebook,
trace_building_strategy=TraceBuildingStrategies.SEQUENTIAL)
decoded_intensities = decoder.run(spots=spots)
# EPY: END code

# EPY: START markdown
Expand Down Expand Up @@ -162,11 +169,11 @@ def get_benchmark_peaks(loaded_results, redo_flag=False):

# EPY: START code
benchmark_spot_count = len(benchmark_peaks)
starfish_spot_count = len(spot_intensities)
starfish_spot_count = len(decoded_intensities)

plt.figure(figsize=(10,10))
plt.plot(benchmark_peaks.x, -benchmark_peaks.y, "o")
plt.plot(spot_intensities[Axes.X.value], -spot_intensities[Axes.Y.value], "x")
plt.plot(decoded_intensities[Axes.X.value], -decoded_intensities[Axes.Y.value], "x")

plt.legend(["Benchmark: {} spots".format(benchmark_spot_count),
"Starfish: {} spots".format(starfish_spot_count)])
Expand Down
1 change: 1 addition & 0 deletions starfish/core/spots/FindSpots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._base import FindSpotsAlgorithm
from .blob import BlobDetector
from .local_max_peak_finder import LocalMaxPeakFinder
from .trackpy_local_max_peak_finder import TrackpyLocalMaxPeakFinder

# autodoc's automodule directive only captures the modules explicitly listed in __all__.
Expand Down
296 changes: 296 additions & 0 deletions starfish/core/spots/FindSpots/local_max_peak_finder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
from functools import partial
from typing import List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import xarray as xr
from scipy.ndimage import label
from skimage.feature import peak_local_max
from skimage.measure import regionprops
from sympy import Line, Point
from tqdm import tqdm

from starfish.core.config import StarfishConfig
from starfish.core.image.Filter.util import determine_axes_to_group_by
from starfish.core.imagestack.imagestack import ImageStack
from starfish.core.spots.FindSpots import spot_finding_utils
from starfish.core.types import Axes, Features, Number, SpotAttributes, SpotFindingResults
from ._base import FindSpotsAlgorithm


class LocalMaxPeakFinder(FindSpotsAlgorithm):
"""
2-dimensional local-max peak finder that wraps skimage.feature.peak_local_max
Parameters
----------
min_distance : int
Minimum number of pixels separating peaks in a region of 2 * min_distance + 1
(i.e. peaks are separated by at least min_distance). To find the maximum number of
peaks, use min_distance=1.
stringency : int
min_obj_area : int
max_obj_area : int
threshold : Optional[Number]
measurement_type : str, {'max', 'mean'}
default 'max' calculates the maximum intensity inside the object
min_num_spots_detected : int
When fewer than this number of spots are detected, spot searching for higher threshold
values. (default = 3)
is_volume : bool
Not supported. For 3d peak detection please use TrackpyLocalMaxPeakFinder.
(default=False)
verbose : bool
If True, report the percentage completed during processing
(default = False)
Notes
-----
http://scikit-image.org/docs/dev/api/skimage.feature.html#skimage.feature.peak_local_max
"""

def __init__(
self, min_distance: int, stringency: int, min_obj_area: int, max_obj_area: int,
threshold: Optional[Number]=None, measurement_type: str='max',
min_num_spots_detected: int=3, is_volume: bool=False, verbose: bool=True
) -> None:

self.min_distance = min_distance
self.stringency = stringency
self.min_obj_area = min_obj_area
self.max_obj_area = max_obj_area
self.threshold = threshold
self.min_num_spots_detected = min_num_spots_detected

self.measurement_function = self._get_measurement_function(measurement_type)

self.is_volume = is_volume
self.verbose = verbose

def _compute_num_spots_per_threshold(self, img: np.ndarray) -> Tuple[np.ndarray, List[int]]:
"""Computes the number of detected spots for each threshold
Parameters
----------
img : np.ndarray
The image in which to count spots
Returns
-------
np.ndarray :
thresholds
List[int] :
spot counts
"""

# thresholds to search over
thresholds = np.linspace(img.min(), img.max(), num=100)

# number of spots detected at each threshold
spot_counts = []

# where we stop our threshold search
stop_threshold = None

if self.verbose and StarfishConfig().verbose:
threshold_iter = tqdm(thresholds)
print('Determining optimal threshold ...')
else:
threshold_iter = thresholds

for stop_index, threshold in enumerate(threshold_iter):
spots = peak_local_max(
img,
min_distance=self.min_distance,
threshold_abs=threshold,
exclude_border=False,
indices=True,
num_peaks=np.inf,
footprint=None,
labels=None
)

# stop spot finding when the number of detected spots falls below min_num_spots_detected
if len(spots) <= self.min_num_spots_detected:
stop_threshold = threshold
if self.verbose:
print(f'Stopping early at threshold={threshold}. Number of spots fell below: '
f'{self.min_num_spots_detected}')
break
else:
spot_counts.append(len(spots))

if stop_threshold is None:
stop_threshold = thresholds.max()

if len(thresholds > 1):
thresholds = thresholds[:stop_index]
spot_counts = spot_counts[:stop_index]

return thresholds, spot_counts

def _select_optimal_threshold(self, thresholds: np.ndarray, spot_counts: List[int]) -> float:
# calculate the gradient of the number of spots
grad = np.gradient(spot_counts)
optimal_threshold_index = np.argmin(grad)

# only consider thresholds > than optimal threshold
thresholds = thresholds[optimal_threshold_index:]
grad = grad[optimal_threshold_index:]

# if all else fails, return 0.
selected_thr = 0

if len(thresholds) > 1:

distances = []

# create a line whose end points are the threshold and and corresponding gradient value
# for spot_counts corresponding to the threshold
start_point = Point(thresholds[0], grad[0])
end_point = Point(thresholds[-1], grad[-1])
line = Line(start_point, end_point)

# calculate the distance between all points and the line
for k in range(len(thresholds)):
p = Point(thresholds[k], grad[k])
dst = line.distance(p)
distances.append(dst.evalf())

# remove the end points
thresholds = thresholds[1:-1]
distances = distances[1:-1]

# select the threshold that has the maximum distance from the line
# if stringency is passed, select a threshold that is n steps higher, where n is the
# value of stringency
if distances:
thr_idx = np.argmax(np.array(distances))

if thr_idx + self.stringency < len(thresholds):
selected_thr = thresholds[thr_idx + self.stringency]
else:
selected_thr = thresholds[thr_idx]

return selected_thr

def _compute_threshold(self, img: Union[np.ndarray, xr.DataArray]) -> float:
"""Finds spots on a number of thresholds then selects and returns the optimal threshold
Parameters
----------
img: Union[np.ndarray, xr.DataArray]
data array in which spots should be detected and over which to compute different
intensity thresholds
Returns
-------
float :
The intensity threshold
"""
img = np.asarray(img)
thresholds, spot_counts = self._compute_num_spots_per_threshold(img)
if len(spot_counts) == 0:
# this only happens when we never find more spots than `self.min_num_spots_detected`
return img.min()
return self._select_optimal_threshold(thresholds, spot_counts)

def image_to_spots(self, data_image: Union[np.ndarray, xr.DataArray]) -> SpotAttributes:
"""measure attributes of spots detected by binarizing the image using the selected threshold
Parameters
----------
data_image : Union[np.ndarray, xr.DataArray]
image containing spots to be detected
Returns
-------
SpotAttributes
Attributes for each detected spot
"""

threshold = self._compute_threshold(data_image)

data_image = np.asarray(data_image)

# identify each spot's size by binarizing and calculating regionprops
masked_image = data_image[:, :] > threshold
labels = label(masked_image)[0]
spot_props = regionprops(np.squeeze(labels))

# mask spots whose areas are too small or too large
for spot_prop in spot_props:
if spot_prop.area < self.min_obj_area or spot_prop.area > self.max_obj_area:
masked_image[0, spot_prop.coords[:, 0], spot_prop.coords[:, 1]] = 0

# store re-calculated regionprops and labels based on the area-masked image
labels = label(masked_image)[0]

if self.verbose:
print('computing final spots ...')

spot_coords = peak_local_max(
data_image,
min_distance=self.min_distance,
threshold_abs=threshold,
exclude_border=False,
indices=True,
num_peaks=np.inf,
footprint=None,
labels=labels
)
res = {Axes.X.value: spot_coords[:, 2],
Axes.Y.value: spot_coords[:, 1],
Axes.ZPLANE.value: spot_coords[:, 0],
Features.SPOT_RADIUS: 1,
Features.SPOT_ID: np.arange(spot_coords.shape[0]),
Features.INTENSITY: data_image[spot_coords[:, 0],
spot_coords[:, 1],
spot_coords[:, 2]]
}

return SpotAttributes(pd.DataFrame(res))

def run(
self,
image_stack: ImageStack,
reference_image: Optional[ImageStack] = None,
n_processes: Optional[int] = None,
*args,
**kwargs
) -> SpotFindingResults:
"""
Find spots in the given ImageStack using a gaussian blob finding algorithm.
If a reference image is provided the spots will be detected there then measured
across all rounds and channels in the corresponding ImageStack. If a reference_image
is not provided spots will be detected _independently_ in each channel. This assumes
a non-multiplex imaging experiment, as only one (ch, round) will be measured for each spot.
Parameters
----------
image_stack : ImageStack
ImageStack where we find the spots in.
reference_image : xr.DataArray
(Optional) a reference image. If provided, spots will be found in this image, and then
the locations that correspond to these spots will be measured across each channel.
n_processes : Optional[int] = None,
Number of processes to devote to spot finding.
"""
spot_finding_method = partial(self.image_to_spots, *args, **kwargs)
if reference_image:
data_image = reference_image._squeezed_numpy(*{Axes.ROUND, Axes.CH})
reference_spots = spot_finding_method(data_image)
results = spot_finding_utils.measure_intensities_at_spot_locations_across_imagestack(
data_image=image_stack,
reference_spots=reference_spots,
measurement_function=self.measurement_function)
else:
spot_attributes_list = image_stack.transform(
func=spot_finding_method,
group_by=determine_axes_to_group_by(self.is_volume),
n_processes=n_processes
)
results = SpotFindingResults(imagestack_coords=image_stack.xarray.coords,
log=image_stack.log,
spot_attributes_list=spot_attributes_list)
return results

0 comments on commit c25bbee

Please sign in to comment.