From c25bbee8a6ab8657e3cc5fdd06b4d60953d852f5 Mon Sep 17 00:00:00 2001 From: Shannon Axelrod Date: Thu, 26 Sep 2019 11:31:22 -0700 Subject: [PATCH] refactoring localMaxPeakFinder --- notebooks/osmFISH.ipynb | 17 +- notebooks/py/osmFISH.py | 17 +- starfish/core/spots/FindSpots/__init__.py | 1 + .../spots/FindSpots/local_max_peak_finder.py | 296 ++++++++++++++++++ 4 files changed, 321 insertions(+), 10 deletions(-) create mode 100644 starfish/core/spots/FindSpots/local_max_peak_finder.py diff --git a/notebooks/osmFISH.ipynb b/notebooks/osmFISH.ipynb index 900428e02..896e33896 100644 --- a/notebooks/osmFISH.ipynb +++ b/notebooks/osmFISH.ipynb @@ -161,15 +161,22 @@ "metadata": {}, "outputs": [], "source": [ - "from starfish.spots import DetectSpots\n", + "from starfish.spots import DecodeSpots, FindSpots\n", + "from starfish.types import TraceBuildingStrategies\n", "\n", - "lmp = DetectSpots.LocalMaxPeakFinder(\n", + "\n", + "lmp = FindSpots.LocalMaxPeakFinder(\n", " min_distance=6,\n", " stringency=0,\n", " min_obj_area=6,\n", " max_obj_area=600,\n", + " is_volume=True\n", ")\n", - "spot_intensities = lmp.run(mp)" + "spots = lmp.run(mp)\n", + "\n", + "decoder = DecodeSpots.PerRoundMaxChannel(codebook=experiment.codebook,\n", + " trace_building_strategy=TraceBuildingStrategies.SEQUENTIAL)\n", + "decoded_intensities = decoder.run(spots=spots)" ] }, { @@ -240,11 +247,11 @@ "outputs": [], "source": [ "benchmark_spot_count = len(benchmark_peaks)\n", - "starfish_spot_count = len(spot_intensities)\n", + "starfish_spot_count = len(decoded_intensities)\n", "\n", "plt.figure(figsize=(10,10))\n", "plt.plot(benchmark_peaks.x, -benchmark_peaks.y, \"o\")\n", - "plt.plot(spot_intensities[Axes.X.value], -spot_intensities[Axes.Y.value], \"x\")\n", + "plt.plot(decoded_intensities[Axes.X.value], -decoded_intensities[Axes.Y.value], \"x\")\n", "\n", "plt.legend([\"Benchmark: {} spots\".format(benchmark_spot_count),\n", " \"Starfish: {} spots\".format(starfish_spot_count)])\n", diff --git a/notebooks/py/osmFISH.py b/notebooks/py/osmFISH.py index 79a564bb1..de261df16 100644 --- a/notebooks/py/osmFISH.py +++ b/notebooks/py/osmFISH.py @@ -102,15 +102,22 @@ # EPY: END markdown # EPY: START code -from starfish.spots import DetectSpots +from starfish.spots import DecodeSpots, FindSpots +from starfish.types import TraceBuildingStrategies -lmp = DetectSpots.LocalMaxPeakFinder( + +lmp = FindSpots.LocalMaxPeakFinder( min_distance=6, stringency=0, min_obj_area=6, max_obj_area=600, + is_volume=True ) -spot_intensities = lmp.run(mp) +spots = lmp.run(mp) + +decoder = DecodeSpots.PerRoundMaxChannel(codebook=experiment.codebook, + trace_building_strategy=TraceBuildingStrategies.SEQUENTIAL) +decoded_intensities = decoder.run(spots=spots) # EPY: END code # EPY: START markdown @@ -162,11 +169,11 @@ def get_benchmark_peaks(loaded_results, redo_flag=False): # EPY: START code benchmark_spot_count = len(benchmark_peaks) -starfish_spot_count = len(spot_intensities) +starfish_spot_count = len(decoded_intensities) plt.figure(figsize=(10,10)) plt.plot(benchmark_peaks.x, -benchmark_peaks.y, "o") -plt.plot(spot_intensities[Axes.X.value], -spot_intensities[Axes.Y.value], "x") +plt.plot(decoded_intensities[Axes.X.value], -decoded_intensities[Axes.Y.value], "x") plt.legend(["Benchmark: {} spots".format(benchmark_spot_count), "Starfish: {} spots".format(starfish_spot_count)]) diff --git a/starfish/core/spots/FindSpots/__init__.py b/starfish/core/spots/FindSpots/__init__.py index 7e206c8b5..4b62b0db3 100644 --- a/starfish/core/spots/FindSpots/__init__.py +++ b/starfish/core/spots/FindSpots/__init__.py @@ -1,5 +1,6 @@ from ._base import FindSpotsAlgorithm from .blob import BlobDetector +from .local_max_peak_finder import LocalMaxPeakFinder from .trackpy_local_max_peak_finder import TrackpyLocalMaxPeakFinder # autodoc's automodule directive only captures the modules explicitly listed in __all__. diff --git a/starfish/core/spots/FindSpots/local_max_peak_finder.py b/starfish/core/spots/FindSpots/local_max_peak_finder.py new file mode 100644 index 000000000..ee9a55bfa --- /dev/null +++ b/starfish/core/spots/FindSpots/local_max_peak_finder.py @@ -0,0 +1,296 @@ +from functools import partial +from typing import List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import xarray as xr +from scipy.ndimage import label +from skimage.feature import peak_local_max +from skimage.measure import regionprops +from sympy import Line, Point +from tqdm import tqdm + +from starfish.core.config import StarfishConfig +from starfish.core.image.Filter.util import determine_axes_to_group_by +from starfish.core.imagestack.imagestack import ImageStack +from starfish.core.spots.FindSpots import spot_finding_utils +from starfish.core.types import Axes, Features, Number, SpotAttributes, SpotFindingResults +from ._base import FindSpotsAlgorithm + + +class LocalMaxPeakFinder(FindSpotsAlgorithm): + """ + 2-dimensional local-max peak finder that wraps skimage.feature.peak_local_max + + Parameters + ---------- + min_distance : int + Minimum number of pixels separating peaks in a region of 2 * min_distance + 1 + (i.e. peaks are separated by at least min_distance). To find the maximum number of + peaks, use min_distance=1. + stringency : int + min_obj_area : int + max_obj_area : int + threshold : Optional[Number] + measurement_type : str, {'max', 'mean'} + default 'max' calculates the maximum intensity inside the object + min_num_spots_detected : int + When fewer than this number of spots are detected, spot searching for higher threshold + values. (default = 3) + is_volume : bool + Not supported. For 3d peak detection please use TrackpyLocalMaxPeakFinder. + (default=False) + verbose : bool + If True, report the percentage completed during processing + (default = False) + + Notes + ----- + http://scikit-image.org/docs/dev/api/skimage.feature.html#skimage.feature.peak_local_max + """ + + def __init__( + self, min_distance: int, stringency: int, min_obj_area: int, max_obj_area: int, + threshold: Optional[Number]=None, measurement_type: str='max', + min_num_spots_detected: int=3, is_volume: bool=False, verbose: bool=True + ) -> None: + + self.min_distance = min_distance + self.stringency = stringency + self.min_obj_area = min_obj_area + self.max_obj_area = max_obj_area + self.threshold = threshold + self.min_num_spots_detected = min_num_spots_detected + + self.measurement_function = self._get_measurement_function(measurement_type) + + self.is_volume = is_volume + self.verbose = verbose + + def _compute_num_spots_per_threshold(self, img: np.ndarray) -> Tuple[np.ndarray, List[int]]: + """Computes the number of detected spots for each threshold + + Parameters + ---------- + img : np.ndarray + The image in which to count spots + + Returns + ------- + np.ndarray : + thresholds + List[int] : + spot counts + """ + + # thresholds to search over + thresholds = np.linspace(img.min(), img.max(), num=100) + + # number of spots detected at each threshold + spot_counts = [] + + # where we stop our threshold search + stop_threshold = None + + if self.verbose and StarfishConfig().verbose: + threshold_iter = tqdm(thresholds) + print('Determining optimal threshold ...') + else: + threshold_iter = thresholds + + for stop_index, threshold in enumerate(threshold_iter): + spots = peak_local_max( + img, + min_distance=self.min_distance, + threshold_abs=threshold, + exclude_border=False, + indices=True, + num_peaks=np.inf, + footprint=None, + labels=None + ) + + # stop spot finding when the number of detected spots falls below min_num_spots_detected + if len(spots) <= self.min_num_spots_detected: + stop_threshold = threshold + if self.verbose: + print(f'Stopping early at threshold={threshold}. Number of spots fell below: ' + f'{self.min_num_spots_detected}') + break + else: + spot_counts.append(len(spots)) + + if stop_threshold is None: + stop_threshold = thresholds.max() + + if len(thresholds > 1): + thresholds = thresholds[:stop_index] + spot_counts = spot_counts[:stop_index] + + return thresholds, spot_counts + + def _select_optimal_threshold(self, thresholds: np.ndarray, spot_counts: List[int]) -> float: + # calculate the gradient of the number of spots + grad = np.gradient(spot_counts) + optimal_threshold_index = np.argmin(grad) + + # only consider thresholds > than optimal threshold + thresholds = thresholds[optimal_threshold_index:] + grad = grad[optimal_threshold_index:] + + # if all else fails, return 0. + selected_thr = 0 + + if len(thresholds) > 1: + + distances = [] + + # create a line whose end points are the threshold and and corresponding gradient value + # for spot_counts corresponding to the threshold + start_point = Point(thresholds[0], grad[0]) + end_point = Point(thresholds[-1], grad[-1]) + line = Line(start_point, end_point) + + # calculate the distance between all points and the line + for k in range(len(thresholds)): + p = Point(thresholds[k], grad[k]) + dst = line.distance(p) + distances.append(dst.evalf()) + + # remove the end points + thresholds = thresholds[1:-1] + distances = distances[1:-1] + + # select the threshold that has the maximum distance from the line + # if stringency is passed, select a threshold that is n steps higher, where n is the + # value of stringency + if distances: + thr_idx = np.argmax(np.array(distances)) + + if thr_idx + self.stringency < len(thresholds): + selected_thr = thresholds[thr_idx + self.stringency] + else: + selected_thr = thresholds[thr_idx] + + return selected_thr + + def _compute_threshold(self, img: Union[np.ndarray, xr.DataArray]) -> float: + """Finds spots on a number of thresholds then selects and returns the optimal threshold + + Parameters + ---------- + img: Union[np.ndarray, xr.DataArray] + data array in which spots should be detected and over which to compute different + intensity thresholds + + Returns + ------- + float : + The intensity threshold + """ + img = np.asarray(img) + thresholds, spot_counts = self._compute_num_spots_per_threshold(img) + if len(spot_counts) == 0: + # this only happens when we never find more spots than `self.min_num_spots_detected` + return img.min() + return self._select_optimal_threshold(thresholds, spot_counts) + + def image_to_spots(self, data_image: Union[np.ndarray, xr.DataArray]) -> SpotAttributes: + """measure attributes of spots detected by binarizing the image using the selected threshold + + Parameters + ---------- + data_image : Union[np.ndarray, xr.DataArray] + image containing spots to be detected + + Returns + ------- + SpotAttributes + Attributes for each detected spot + """ + + threshold = self._compute_threshold(data_image) + + data_image = np.asarray(data_image) + + # identify each spot's size by binarizing and calculating regionprops + masked_image = data_image[:, :] > threshold + labels = label(masked_image)[0] + spot_props = regionprops(np.squeeze(labels)) + + # mask spots whose areas are too small or too large + for spot_prop in spot_props: + if spot_prop.area < self.min_obj_area or spot_prop.area > self.max_obj_area: + masked_image[0, spot_prop.coords[:, 0], spot_prop.coords[:, 1]] = 0 + + # store re-calculated regionprops and labels based on the area-masked image + labels = label(masked_image)[0] + + if self.verbose: + print('computing final spots ...') + + spot_coords = peak_local_max( + data_image, + min_distance=self.min_distance, + threshold_abs=threshold, + exclude_border=False, + indices=True, + num_peaks=np.inf, + footprint=None, + labels=labels + ) + res = {Axes.X.value: spot_coords[:, 2], + Axes.Y.value: spot_coords[:, 1], + Axes.ZPLANE.value: spot_coords[:, 0], + Features.SPOT_RADIUS: 1, + Features.SPOT_ID: np.arange(spot_coords.shape[0]), + Features.INTENSITY: data_image[spot_coords[:, 0], + spot_coords[:, 1], + spot_coords[:, 2]] + } + + return SpotAttributes(pd.DataFrame(res)) + + def run( + self, + image_stack: ImageStack, + reference_image: Optional[ImageStack] = None, + n_processes: Optional[int] = None, + *args, + **kwargs + ) -> SpotFindingResults: + """ + Find spots in the given ImageStack using a gaussian blob finding algorithm. + If a reference image is provided the spots will be detected there then measured + across all rounds and channels in the corresponding ImageStack. If a reference_image + is not provided spots will be detected _independently_ in each channel. This assumes + a non-multiplex imaging experiment, as only one (ch, round) will be measured for each spot. + + Parameters + ---------- + image_stack : ImageStack + ImageStack where we find the spots in. + reference_image : xr.DataArray + (Optional) a reference image. If provided, spots will be found in this image, and then + the locations that correspond to these spots will be measured across each channel. + n_processes : Optional[int] = None, + Number of processes to devote to spot finding. + """ + spot_finding_method = partial(self.image_to_spots, *args, **kwargs) + if reference_image: + data_image = reference_image._squeezed_numpy(*{Axes.ROUND, Axes.CH}) + reference_spots = spot_finding_method(data_image) + results = spot_finding_utils.measure_intensities_at_spot_locations_across_imagestack( + data_image=image_stack, + reference_spots=reference_spots, + measurement_function=self.measurement_function) + else: + spot_attributes_list = image_stack.transform( + func=spot_finding_method, + group_by=determine_axes_to_group_by(self.is_volume), + n_processes=n_processes + ) + results = SpotFindingResults(imagestack_coords=image_stack.xarray.coords, + log=image_stack.log, + spot_attributes_list=spot_attributes_list) + return results