diff --git a/plantseg/segmentation/functional/segmentation.py b/plantseg/segmentation/functional/segmentation.py index abc3a1a5..32036335 100644 --- a/plantseg/segmentation/functional/segmentation.py +++ b/plantseg/segmentation/functional/segmentation.py @@ -1,9 +1,9 @@ +from typing import Optional + import nifty -import nifty.graph.rag as nrag import numpy as np from elf.segmentation import GaspFromAffinities -from elf.segmentation import stacked_watershed, lifted_multicut as lmc, \ - project_node_labels_to_pixels +from elf.segmentation import stacked_watershed, lifted_multicut as lmc, project_node_labels_to_pixels from elf.segmentation.features import compute_rag, lifted_problem_from_probabilities, lifted_problem_from_segmentation from elf.segmentation.multicut import multicut_kernighan_lin from elf.segmentation.watershed import distance_transform_watershed, apply_size_filter @@ -19,101 +19,99 @@ sitk_installed = False -def dt_watershed(boundary_pmaps: np.ndarray, - threshold: float = 0.5, - sigma_seeds: float = 1., - stacked: bool = False, - sigma_weights: float = 2., - min_size: int = 100, - alpha: float = 1.0, - pixel_pitch: tuple[int, ...] = None, - apply_nonmax_suppression: bool = False, - n_threads: int = None, - mask: np.ndarray = None) -> np.ndarray: - """ Wrapper around elf.distance_transform_watershed +def dt_watershed( + boundary_pmaps: np.ndarray, + threshold: float = 0.5, + sigma_seeds: float = 1.0, + stacked: bool = False, + sigma_weights: float = 2.0, + min_size: int = 100, + alpha: float = 1.0, + pixel_pitch: Optional[tuple[int, ...]] = None, + apply_nonmax_suppression: bool = False, + n_threads: Optional[int] = None, + mask: Optional[np.ndarray] = None, +) -> np.ndarray: + """Performs watershed segmentation using distance transforms on boundary probability maps. Args: - boundary_pmaps (np.ndarray): input height map. - threshold (float): value for the threshold applied before distance transform. - sigma_seeds (float): smoothing factor for the watershed seed map. - stacked (bool): if true the ws will be executed in 2D slice by slice, otherwise in 3D. - sigma_weights (float): smoothing factor for the watershed weight map (default: 2). - min_size (int): minimal size of watershed segments (default: 100) - alpha (float): alpha used to blend input_ and distance_transform in order to obtain the - watershed weight map (default: .9) - pixel_pitch (list-like[int]): anisotropy factor used to compute the distance transform (default: None) - apply_nonmax_suppression (bool): whether to apply non-maximum suppression to filter out seeds. - Needs nifty. (default: False) - n_threads (int): if not None, parallelize the 2D stacked ws. (default: None) - mask (np.ndarray) + boundary_pmaps (np.ndarray): Input height maps, typically boundary probability maps from a CNN. + threshold (float): Threshold applied to boundary maps before distance transform. + sigma_seeds (float): Smoothing factor for the watershed seed map.. + stacked (bool): If True, performs watershed slice-by-slice (2D), otherwise in 3D. + sigma_weights (float): Smoothing factor for the watershed weight map. + min_size (int): Minimal size of watershed segments. + alpha (float): Alpha blending factor used to combine the input and distance transform into the watershed weight map. + pixel_pitch (Optional[tuple[int, ...]]): Pixel pitch to use for anisotropic distance calculation. + apply_nonmax_suppression (bool): If True, applies non-maximum suppression to filter out seeds. Needs nifty. + n_threads (Optional[int]): Number of threads for parallel processing, applicable in 2D mode. + mask (Optional[np.ndarray]): Mask array to exclude certain regions from segmentation. Returns: - segmentation (np.ndarray): watershed segmentation - """ + np.ndarray: The labeled segmentation map from the watershed algorithm. + """ + # Prepare the keyword arguments for the watershed function boundary_pmaps = boundary_pmaps.astype('float32') - ws_kwargs = dict(threshold=threshold, sigma_seeds=sigma_seeds, - sigma_weights=sigma_weights, - min_size=min_size, alpha=alpha, - pixel_pitch=pixel_pitch, - apply_nonmax_suppression=apply_nonmax_suppression, - mask=mask) + ws_kwargs = { + "threshold": threshold, + "sigma_seeds": sigma_seeds, + "sigma_weights": sigma_weights, + "min_size": min_size, + "alpha": alpha, + "pixel_pitch": pixel_pitch, + "apply_nonmax_suppression": apply_nonmax_suppression, + "mask": mask, + } if stacked: - # WS in 2D - ws, _ = stacked_watershed(boundary_pmaps, - ws_function=distance_transform_watershed, - n_threads=n_threads, - **ws_kwargs) + # Apply watershed in 2D, slice by slice + segmentation, _ = stacked_watershed( + boundary_pmaps, ws_function=distance_transform_watershed, n_threads=n_threads, **ws_kwargs + ) else: - # WS in 3D - ws, _ = distance_transform_watershed(boundary_pmaps, **ws_kwargs) + # Apply watershed in 3D + segmentation, _ = distance_transform_watershed(boundary_pmaps, **ws_kwargs) - return ws + return segmentation -def gasp(boundary_pmaps: np.ndarray, - superpixels: np.ndarray = None, - gasp_linkage_criteria: str = 'average', - beta: float = 0.5, - post_minsize: int = 100, - n_threads: int = 6) -> np.ndarray: +def gasp( + boundary_pmaps: np.ndarray, + superpixels: Optional[np.ndarray] = None, + gasp_linkage_criteria: str = 'average', + beta: float = 0.5, + post_minsize: int = 100, + n_threads: int = 6, +) -> np.ndarray: """ - Implementation of the GASP algorithm for segmentation from affinities. + Perform segmentation using the GASP algorithm with affinity maps. Args: - boundary_pmaps (np.ndarray): cell boundary predictions. - superpixels (np.ndarray): superpixel segmentation. If None, GASP will be run from the pixels. (default: None) - gasp_linkage_criteria (str): Linkage criteria for GASP. (default: 'average') - beta (float): beta parameter for GASP. A small value will steer the segmentation towards under-segmentation. - While a high-value bias the segmentation towards the over-segmentation. (default: 0.5) - post_minsize (int): minimal size of the segments after GASP. (default: 100) - n_threads (int): number of threads used for GASP. (default: 6) + boundary_pmaps (np.ndarray): Cell boundary predictions. + superpixels (Optional[np.ndarray]): Superpixel segmentation. If None, GASP will be run from the pixels. Default is None. + gasp_linkage_criteria (str): Linkage criteria for GASP. Default is 'average'. + beta (float): Beta parameter for GASP. Small values steer towards under-segmentation, while high values bias towards over-segmentation. Default is 0.5. + post_minsize (int): Minimum size of the segments after GASP. Default is 100. + n_threads (int): Number of threads used for GASP. Default is 6. Returns: - segmentation (np.ndarray): GASP output segmentation - + np.ndarray: GASP output segmentation. """ if superpixels is not None: - assert boundary_pmaps.shape == superpixels.shape - - if superpixels.ndim == 2: + assert boundary_pmaps.shape == superpixels.shape, "Shape mismatch between boundary_pmaps and superpixels." + if superpixels.ndim == 2: # Ensure superpixels is 3D if provided superpixels = superpixels[None, ...] - def superpixel_gen(*args, **kwargs): - return superpixels - else: - superpixel_gen = None - - if boundary_pmaps.ndim == 2: - boundary_pmaps = boundary_pmaps[None, ...] + # Prepare the arguments for running GASP + run_GASP_kwargs = { + 'linkage_criteria': gasp_linkage_criteria, + 'add_cannot_link_constraints': False, + 'use_efficient_implementations': False, + } - run_GASP_kwargs = {'linkage_criteria': gasp_linkage_criteria, - 'add_cannot_link_constraints': False, - 'use_efficient_implementations': False} - - # pmaps are interpreted as affinities + # Interpret boundary_pmaps as affinities and prepare for GASP boundary_pmaps = boundary_pmaps.astype('float32') - affinities = np.stack([boundary_pmaps, boundary_pmaps, boundary_pmaps], axis=0) + affinities = np.stack([boundary_pmaps] * 3, axis=0) offsets = [[0, 0, 1], [0, 1, 0], [1, 0, 0]] # Shift is required to correct aligned affinities @@ -122,26 +120,30 @@ def superpixel_gen(*args, **kwargs): # invert affinities affinities = 1 - affinities - # Init and run Gasp - gasp_instance = GaspFromAffinities(offsets, - superpixel_generator=superpixel_gen, - run_GASP_kwargs=run_GASP_kwargs, - n_threads=n_threads, - beta_bias=beta) - # running gasp + # Initialize and run GASP + gasp_instance = GaspFromAffinities( + offsets, + superpixel_generator=None if superpixels is None else (lambda *args, **kwargs: superpixels), + run_GASP_kwargs=run_GASP_kwargs, + n_threads=n_threads, + beta_bias=beta, + ) segmentation, _ = gasp_instance(affinities) - # init and run size threshold + # Apply size filtering if specified if post_minsize > 0: segmentation, _ = apply_size_filter(segmentation.astype('uint32'), boundary_pmaps, post_minsize) + return segmentation -def mutex_ws(boundary_pmaps: np.ndarray, - superpixels: np.ndarray = None, - beta: float = 0.5, - post_minsize: int = 100, - n_threads: int = 6) -> np.ndarray: +def mutex_ws( + boundary_pmaps: np.ndarray, + superpixels: Optional[np.ndarray] = None, + beta: float = 0.5, + post_minsize: int = 100, + n_threads: int = 6, +) -> np.ndarray: """ Wrapper around gasp with mutex_watershed as linkage criteria. @@ -158,19 +160,19 @@ def mutex_ws(boundary_pmaps: np.ndarray, segmentation (np.ndarray): MutexWS output segmentation """ - return gasp(boundary_pmaps=boundary_pmaps, - superpixels=superpixels, - gasp_linkage_criteria='mutex_watershed', - beta=beta, - post_minsize=post_minsize, - n_threads=n_threads) - - -def multicut(boundary_pmaps: np.ndarray, - superpixels: np.ndarray, - beta: float = 0.5, - post_minsize: int = 50) -> np.ndarray: - + return gasp( + boundary_pmaps=boundary_pmaps, + superpixels=superpixels, + gasp_linkage_criteria='mutex_watershed', + beta=beta, + post_minsize=post_minsize, + n_threads=n_threads, + ) + + +def multicut( + boundary_pmaps: np.ndarray, superpixels: np.ndarray, beta: float = 0.5, post_minsize: int = 50 +) -> np.ndarray: """ Multicut segmentation from boundary predictions. @@ -201,17 +203,17 @@ def multicut(boundary_pmaps: np.ndarray, # run size threshold if post_minsize > 0: - segmentation, _ = apply_size_filter(segmentation.astype('uint32'), - boundary_pmaps, - post_minsize) + segmentation, _ = apply_size_filter(segmentation.astype('uint32'), boundary_pmaps, post_minsize) return segmentation -def lifted_multicut_from_nuclei_pmaps(boundary_pmaps: np.ndarray, - nuclei_pmaps: np.ndarray, - superpixels: np.ndarray, - beta: float = 0.5, - post_minsize: int = 50) -> np.ndarray: +def lifted_multicut_from_nuclei_pmaps( + boundary_pmaps: np.ndarray, + nuclei_pmaps: np.ndarray, + superpixels: np.ndarray, + beta: float = 0.5, + post_minsize: int = 50, +) -> np.ndarray: """ Lifted Multicut segmentation from boundary predictions and nuclei predictions. @@ -237,12 +239,12 @@ def lifted_multicut_from_nuclei_pmaps(boundary_pmaps: np.ndarray, # assert nuclei pmaps are floats nuclei_pmaps = nuclei_pmaps.astype('float32') input_maps = [nuclei_pmaps] - assignment_threshold = .9 + assignment_threshold = 0.9 # compute lifted multicut features from boundary pmaps - lifted_uvs, lifted_costs = lifted_problem_from_probabilities(rag, superpixels, - input_maps, assignment_threshold, - graph_depth=4) + lifted_uvs, lifted_costs = lifted_problem_from_probabilities( + rag, superpixels, input_maps, assignment_threshold, graph_depth=4 + ) # solve the full lifted problem using the kernighan lin approximation introduced in # http://openaccess.thecvf.com/content_iccv_2015/html/Keuper_Efficient_Decomposition_of_ICCV_2015_paper.html @@ -255,11 +257,13 @@ def lifted_multicut_from_nuclei_pmaps(boundary_pmaps: np.ndarray, return segmentation -def lifted_multicut_from_nuclei_segmentation(boundary_pmaps: np.ndarray, - nuclei_seg: np.ndarray, - superpixels: np.ndarray, - beta: float = 0.5, - post_minsize: int = 50) -> np.ndarray: +def lifted_multicut_from_nuclei_segmentation( + boundary_pmaps: np.ndarray, + nuclei_seg: np.ndarray, + superpixels: np.ndarray, + beta: float = 0.5, + post_minsize: int = 50, +) -> np.ndarray: """ Lifted Multicut segmentation from boundary predictions and nuclei segmentation. @@ -281,11 +285,15 @@ def lifted_multicut_from_nuclei_segmentation(boundary_pmaps: np.ndarray, boundary_pmaps = boundary_pmaps.astype('float32') costs = compute_mc_costs(boundary_pmaps, rag, beta) max_cost = np.abs(np.max(costs)) - lifted_uvs, lifted_costs = lifted_problem_from_segmentation(rag, superpixels, nuclei_seg, - overlap_threshold=0.2, - graph_depth=4, - same_segment_cost=5 * max_cost, - different_segment_cost=-5 * max_cost) + lifted_uvs, lifted_costs = lifted_problem_from_segmentation( + rag, + superpixels, + nuclei_seg, + overlap_threshold=0.2, + graph_depth=4, + same_segment_cost=5 * max_cost, + different_segment_cost=-5 * max_cost, + ) # solve the full lifted problem using the kernighan lin approximation introduced in # http://openaccess.thecvf.com/content_iccv_2015/html/Keuper_Efficient_Decomposition_of_ICCV_2015_paper.html @@ -299,10 +307,9 @@ def lifted_multicut_from_nuclei_segmentation(boundary_pmaps: np.ndarray, return segmentation -def simple_itk_watershed(boundary_pmaps: np.ndarray, - threshold: float = 0.5, - sigma: float = 1.0, - minsize: int = 100) -> np.ndarray: +def simple_itk_watershed( + boundary_pmaps: np.ndarray, threshold: float = 0.5, sigma: float = 1.0, minsize: int = 100 +) -> np.ndarray: """ Simple itk watershed segmentation. @@ -328,22 +335,19 @@ def simple_itk_watershed(boundary_pmaps: np.ndarray, # Itk watershed + size filtering itk_pmaps = sitk.GetImageFromArray(boundary_pmaps) - itk_segmentation = sitk.MorphologicalWatershed(itk_pmaps, - threshold, - markWatershedLine=False, - fullyConnected=False) + itk_segmentation = sitk.MorphologicalWatershed(itk_pmaps, threshold, markWatershedLine=False, fullyConnected=False) itk_segmentation = sitk.RelabelComponent(itk_segmentation, minsize) segmentation = sitk.GetArrayFromImage(itk_segmentation).astype(np.uint16) return segmentation -def simple_itk_watershed_from_markers(boundary_pmaps: np.ndarray, - seeds: np.ndarray): +def simple_itk_watershed_from_markers(boundary_pmaps: np.ndarray, seeds: np.ndarray): if not sitk_installed: raise ValueError('please install sitk before running this process') itk_pmaps = sitk.GetImageFromArray(boundary_pmaps) itk_seeds = sitk.GetImageFromArray(seeds) - segmentation = sitk.MorphologicalWatershedFromMarkers(itk_pmaps, itk_seeds, markWatershedLine=False, - fullyConnected=False) + segmentation = sitk.MorphologicalWatershedFromMarkers( + itk_pmaps, itk_seeds, markWatershedLine=False, fullyConnected=False + ) return sitk.GetArrayFromImage(segmentation).astype('uint32') diff --git a/plantseg/viewer/widget/dataprocessing.py b/plantseg/viewer/widget/dataprocessing.py index a9888e54..eca2507a 100644 --- a/plantseg/viewer/widget/dataprocessing.py +++ b/plantseg/viewer/widget/dataprocessing.py @@ -14,8 +14,8 @@ from plantseg.dataprocessing.functional.labelprocessing import set_background_to_value from plantseg.viewer.widget.predictions import widget_unet_predictions from plantseg.viewer.widget.segmentation import widget_agglomeration, widget_lifted_multicut, widget_dt_ws -from plantseg.viewer.widget.utils import return_value_if_widget from plantseg.viewer.widget.utils import ( + return_value_if_widget, start_threading_process, create_layer_name, layer_properties, @@ -24,9 +24,40 @@ from plantseg.models.zoo import model_zoo +class RescaleType(Enum): + NEAREST = (0, "Nearest") + LINEAR = (1, "Linear") + BILINEAR = (2, "Bilinear") + + def __init__(self, int_val, str_val): + self.int_val = int_val + self.str_val = str_val + + @classmethod + def to_choices(cls): + return [(mode.str_val, mode.int_val) for mode in cls] + + +class RescaleModes(Enum): + FROM_FACTOR = "From factor" + TO_LAYER_VOXEL_SIZE = "To layer voxel size" + TO_LAYER_SHAPE = "To layer shape" + TO_MODEL_VOXEL_SIZE = "To model voxel size" + TO_VOXEL_SIZE = "To voxel size" + SET_SHAPE = "To shape" + SET_VOXEL_SIZE = "Set voxel size" + + @classmethod + def to_choices(cls): + return [(mode.value, mode) for mode in RescaleModes] + + @magicgui( call_button="Run Gaussian Smoothing", - image={"label": "Image", "tooltip": "Image layer to apply the smoothing."}, + image={ + "label": "Image", + "tooltip": "Image layer to apply the smoothing.", + }, sigma={ "label": "Sigma", "widget_type": "FloatSlider", @@ -70,32 +101,15 @@ def widget_gaussian_smoothing( ) -class RescaleType(Enum): - nearest = 0 - linear = 1 - bilinear = 2 - - -class RescaleModes(Enum): - from_factor = "From factor" - to_layer_voxel_size = "To layer voxel size" - to_layer_shape = "To layer shape" - to_model_voxel_size = "To model voxel size" - to_voxel_size = "To voxel size" - set_shape = "To shape" - set_voxel_size = "Set voxel size" - - -RESCALE_MODES = [mode.value for mode in RescaleModes] - - @magicgui( call_button="Run Image Rescaling", - image={"label": "Image or Label", "tooltip": "Layer to apply the rescaling."}, + image={ + "label": "Image or Label", + "tooltip": "Layer to apply the rescaling.", + }, mode={ "label": "Rescale mode", - "choices": RESCALE_MODES, - "tooltip": f"Select the mode to rescale the image or label.", + "choices": RescaleModes.to_choices(), }, rescaling_factor={ "label": "Rescaling factor", @@ -117,30 +131,32 @@ class RescaleModes(Enum): "tooltip": "Rescale to same voxel size as selected model.", "choices": model_zoo.list_models(), }, - reference_shape={"label": "Out shape", "tooltip": "Rescale to a manually selected shape."}, + reference_shape={ + "label": "Out shape", + "tooltip": "Rescale to a manually selected shape.", + }, order={ "label": "Interpolation order", "widget_type": "ComboBox", - "choices": RescaleType, + "choices": RescaleType.to_choices(), "tooltip": "0 for nearest neighbours (default for labels), 1 for linear, 2 for bilinear.", }, ) def widget_rescaling( viewer: Viewer, image: Layer, - mode: str = RESCALE_MODES[0], + mode: RescaleModes = RescaleModes.FROM_FACTOR, rescaling_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0), out_voxel_size: Tuple[float, float, float] = (1.0, 1.0, 1.0), reference_layer: Union[Layer, None] = None, reference_model: str = model_zoo.list_models()[0], reference_shape: Tuple[int, int, int] = (1, 1, 1), - order=RescaleType.linear, + order: int = 0, ) -> Future[LayerDataTuple]: """Rescale an image or label layer to a new voxel size or shape.""" if isinstance(image, Image): layer_type = "image" - order = order.value elif isinstance(image, Labels): layer_type = "labels" @@ -153,48 +169,72 @@ def widget_rescaling( if image.data.ndim == 2: rescaling_factor = (1.0,) + rescaling_factor[1:] + assert ( + len(rescaling_factor) == 3 + ), "Rescaling factor must be a tuple of 3 elements. Please submit an issue on GitHub." + rescaling_factor = float(rescaling_factor[0]), float(rescaling_factor[1]), float(rescaling_factor[2]) current_resolution = image.scale - mode = RescaleModes(mode) # type: ignore match mode: - case RescaleModes.from_factor: - rescaling_factor = tuple(float(x) for x in rescaling_factor) # type: ignore + case RescaleModes.FROM_FACTOR: out_voxel_size = compute_scaling_voxelsize(current_resolution, scaling_factor=rescaling_factor) - case RescaleModes.to_layer_voxel_size: - assert reference_layer is not None, "Please select a reference layer to rescale to." + case RescaleModes.TO_LAYER_VOXEL_SIZE: + if reference_layer is None: + raise ValueError("Please select a reference layer to rescale to.") + out_voxel_size = reference_layer.scale rescaling_factor = compute_scaling_factor(current_resolution, out_voxel_size) - case RescaleModes.to_model_voxel_size: - out_voxel_size = model_zoo.get_model_resolution(reference_model) # type: ignore - if out_voxel_size is None: + case RescaleModes.TO_MODEL_VOXEL_SIZE: + model_voxel_size = model_zoo.get_model_resolution(reference_model) + if model_voxel_size is None: raise ValueError(f"Model {reference_model} does not have a resolution defined.") - rescaling_factor = compute_scaling_factor(current_resolution, out_voxel_size) + rescaling_factor = compute_scaling_factor(current_resolution, model_voxel_size) - case RescaleModes.to_voxel_size: - out_voxel_size = tuple(float(x) for x in out_voxel_size) # type: ignore + case RescaleModes.TO_VOXEL_SIZE: rescaling_factor = compute_scaling_factor(current_resolution, out_voxel_size) - case RescaleModes.to_layer_shape: - assert reference_layer is not None, "Please select a reference layer to rescale to." + case RescaleModes.TO_LAYER_SHAPE: + if reference_layer is None: + raise ValueError("Please select a reference layer to rescale to.") current_shape = image.data.shape out_shape = reference_layer.data.shape - rescaling_factor = tuple(o / c for o, c in zip(out_shape, current_shape)) # type: ignore - out_voxel_size = tuple(i / s for i, s in zip(current_resolution, rescaling_factor)) # type: ignore + assert len(out_shape) == 3, "Reference layer must be a 3D layer. Please submit an issue on GitHub." + assert len(current_shape) == 3, "Current layer must be a 3D layer. Please submit an issue on GitHub." + rescaling_factor = ( + out_shape[0] / current_shape[0], + out_shape[1] / current_shape[1], + out_shape[2] / current_shape[2], + ) + out_voxel_size = ( + current_resolution[0] / rescaling_factor[0], + current_resolution[1] / rescaling_factor[1], + current_resolution[2] / rescaling_factor[2], + ) - case RescaleModes.set_shape: + case RescaleModes.SET_SHAPE: current_shape = image.data.shape out_shape = reference_shape - rescaling_factor = tuple(o / c for o, c in zip(out_shape, current_shape)) # type: ignore - out_voxel_size = tuple(i / s for i, s in zip(current_resolution, rescaling_factor)) # type: ignore + assert len(out_shape) == 3, "Reference layer must be a 3D layer. Please submit an issue on GitHub." + assert len(current_shape) == 3, "Current layer must be a 3D layer. Please submit an issue on GitHub." + rescaling_factor = ( + out_shape[0] / current_shape[0], + out_shape[1] / current_shape[1], + out_shape[2] / current_shape[2], + ) + out_voxel_size = ( + current_resolution[0] / rescaling_factor[0], + current_resolution[1] / rescaling_factor[1], + current_resolution[2] / rescaling_factor[2], + ) # This is the only case where we don't need to rescale the image data # we just need to update the metadata, no need to add this to the DAG. # Maybe this will change in the future implementation of the headless mode. - case RescaleModes.set_voxel_size: - out_voxel_size = tuple(float(x) for x in out_voxel_size) # type: ignore + case RescaleModes.SET_VOXEL_SIZE: + out_voxel_size = float(out_voxel_size[0]), float(out_voxel_size[1]), float(out_voxel_size[2]) image.scale = out_voxel_size result = Future() result.set_result( @@ -252,7 +292,7 @@ def widget_rescaling( @widget_rescaling.mode.changed.connect -def _rescale_update_visibility(mode: str): +def _rescale_update_visibility(mode: RescaleModes): mode = return_value_if_widget(mode) all_widgets = [ @@ -266,27 +306,26 @@ def _rescale_update_visibility(mode: str): for widget in all_widgets: widget.hide() - mode = RescaleModes(mode) # type: ignore match mode: - case RescaleModes.from_factor: + case RescaleModes.FROM_FACTOR: widget_rescaling.rescaling_factor.show() - case RescaleModes.to_layer_voxel_size: + case RescaleModes.TO_LAYER_VOXEL_SIZE: widget_rescaling.reference_layer.show() - case RescaleModes.to_model_voxel_size: + case RescaleModes.TO_MODEL_VOXEL_SIZE: widget_rescaling.reference_model.show() - case RescaleModes.to_voxel_size: + case RescaleModes.TO_VOXEL_SIZE: widget_rescaling.out_voxel_size.show() - case RescaleModes.to_layer_shape: + case RescaleModes.TO_LAYER_SHAPE: widget_rescaling.reference_layer.show() - case RescaleModes.set_shape: + case RescaleModes.SET_SHAPE: widget_rescaling.reference_shape.show() - case RescaleModes.set_voxel_size: + case RescaleModes.SET_VOXEL_SIZE: widget_rescaling.out_voxel_size.show() case _: @@ -311,22 +350,22 @@ def _on_rescaling_image_changed(image: Layer): widget_rescaling.reference_shape[i].value = shape if isinstance(image, Labels): - widget_rescaling.order.value = RescaleType.nearest + widget_rescaling.order.value = RescaleType.NEAREST.int_val @widget_rescaling.order.changed.connect -def _on_rescale_order_changed(order: RescaleType): +def _on_rescale_order_changed(order): order = return_value_if_widget(order) current_image = widget_rescaling.image.value if current_image is None: return None - if isinstance(current_image, Labels) and order != RescaleType.nearest: + if isinstance(current_image, Labels) and order != RescaleType.NEAREST.int_val: napari_formatted_logging( "Labels can only be rescaled with nearest interpolation", thread="Rescaling", level="warning" ) - widget_rescaling.order.value = RescaleType.nearest + widget_rescaling.order.value = RescaleType.NEAREST.int_val def _compute_slices(rectangle, crop_z, shape): @@ -353,7 +392,10 @@ def _cropping(data, crop_slices): @magicgui( call_button="Run Cropping", - image={"label": "Image or Label", "tooltip": "Layer to apply the rescaling."}, + image={ + "label": "Image or Label", + "tooltip": "Layer to apply the rescaling.", + }, crop_roi={ "label": "Crop ROI", "tooltip": "This must be a shape layer with a rectangle XY overlaying the area to crop.", @@ -461,7 +503,12 @@ def _two_layers_operation(data1, data2, operation, weights: float = 0.5): "orientation": "horizontal", "choices": ["Mean", "Maximum", "Minimum"], }, - weights={"label": "Mean weights", "widget_type": "FloatSlider", "max": 1.0, "min": 0.0}, + weights={ + "label": "Mean weights", + "widget_type": "FloatSlider", + "max": 1.0, + "min": 0.0, + }, ) def widget_add_layers( viewer: Viewer, @@ -517,8 +564,14 @@ def _label_processing(segmentation, set_bg_to_0, relabel_segmentation): @magicgui( call_button="Run Label processing", - segmentation={"label": "Segmentation", "tooltip": "Segmentation can be any label layer."}, - set_bg_to_0={"label": "Set background to 0", "tooltip": "Set the largest idx in the image to zero."}, + segmentation={ + "label": "Segmentation", + "tooltip": "Segmentation can be any label layer.", + }, + set_bg_to_0={ + "label": "Set background to 0", + "tooltip": "Set the largest idx in the image to zero.", + }, relabel_segmentation={ "label": "Relabel Segmentation", "tooltip": "Relabel segmentation contiguously to avoid labels clash.", diff --git a/plantseg/viewer/widget/segmentation.py b/plantseg/viewer/widget/segmentation.py index 7de7e6b4..97ceacde 100644 --- a/plantseg/viewer/widget/segmentation.py +++ b/plantseg/viewer/widget/segmentation.py @@ -1,14 +1,13 @@ from concurrent.futures import Future -from enum import Enum -from typing import Tuple, Callable +from typing import Callable, Optional from magicgui import magicgui from napari import Viewer from napari.layers import Labels, Image, Layer from napari.types import LayerDataTuple -from plantseg.dataprocessing.functional.advanced_dataprocessing import fix_over_under_segmentation_from_nuclei, \ - remove_false_positives_by_foreground_probability +from plantseg.dataprocessing.functional.advanced_dataprocessing import fix_over_under_segmentation_from_nuclei +from plantseg.dataprocessing.functional.advanced_dataprocessing import remove_false_positives_by_foreground_probability from plantseg.dataprocessing.functional.dataprocessing import normalize_01 from plantseg.segmentation.functional import gasp, multicut, dt_watershed, mutex_ws from plantseg.segmentation.functional import lifted_multicut_from_nuclei_segmentation, lifted_multicut_from_nuclei_pmaps @@ -16,72 +15,86 @@ from plantseg.viewer.widget.proofreading.proofreading import widget_split_and_merge_from_scribbles from plantseg.viewer.widget.utils import start_threading_process, create_layer_name, layer_properties +STACKED = [('2D', True), ('3D', False)] -def _pmap_warn(thread: str): - napari_formatted_logging('Pmap/Image layer appears to be a raw image and not a pmap. For the best segmentation ' - 'results, try to use a boundaries probability layer ' - '(e.g. from the Run Prediction widget)', - thread=thread, level='warning') - - -class ClusteringOptions(Enum): - gasp = gasp - multicut = multicut - mutex_ws = mutex_ws - -def _generic_clustering(image: Image, labels: Labels, - beta: float = 0.5, - minsize: int = 100, - name: str = 'GASP', - agg_func: Callable = gasp, - viewer: Viewer = None) -> Future[LayerDataTuple]: +def _pmap_warn(thread: str): + napari_formatted_logging( + 'Pmap/Image layer appears to be a raw image and not a pmap. For the best segmentation results, ' + 'try to use a boundaries probability layer (e.g. from the Run Prediction widget)', + thread=thread, + level='warning', + ) + + +def _generic_clustering( + image: Image, + labels: Labels, + beta: float = 0.5, + minsize: int = 100, + name: str = 'GASP', + agg_func: Callable = gasp, + viewer: Optional[Viewer] = None, +) -> Future[LayerDataTuple]: if 'pmap' not in image.metadata: _pmap_warn(f'{name} Clustering Widget') out_name = create_layer_name(image.name, name) inputs_names = (image.name, labels.name) - layer_kwargs = layer_properties(name=out_name, - scale=image.scale, - metadata=image.metadata) + layer_kwargs = layer_properties(name=out_name, scale=image.scale, metadata=image.metadata) layer_type = 'labels' step_kwargs = dict(beta=beta, post_minsize=minsize) - return start_threading_process(agg_func, - runtime_kwargs={'boundary_pmaps': image.data, - 'superpixels': labels.data}, - statics_kwargs=step_kwargs, - out_name=out_name, - input_keys=inputs_names, - layer_kwarg=layer_kwargs, - layer_type=layer_type, - step_name=f'{name} Clustering', - viewer=viewer, - widgets_to_update=[ - widget_split_and_merge_from_scribbles.segmentation] - ) - - -@magicgui(call_button='Run Clustering', - image={'label': 'Pmap/Image', - 'tooltip': 'Raw or boundary image to use as input for clustering.'}, - _labels={'label': 'Over-segmentation', - 'tooltip': 'Over-segmentation labels layer to use as input for clustering.'}, - mode={'label': 'Aggl. Mode', - 'choices': ['GASP', 'MutexWS', 'MultiCut'], - 'tooltip': 'Select which agglomeration algorithm to use.' - }, - beta={'label': 'Under/Over segmentation factor', - 'tooltip': 'A low value will increase under-segmentation tendency ' - 'and a large value increase over-segmentation tendency.', - 'widget_type': 'FloatSlider', 'max': 1., 'min': 0.}, - minsize={'label': 'Min-size', - 'tooltip': 'Minimum segment size allowed in voxels.'}) -def widget_agglomeration(viewer: Viewer, - image: Image, _labels: Labels, - mode: str = "GASP", - beta: float = 0.6, - minsize: int = 100) -> Future[LayerDataTuple]: + return start_threading_process( + agg_func, + runtime_kwargs={'boundary_pmaps': image.data, 'superpixels': labels.data}, + statics_kwargs=step_kwargs, + out_name=out_name, + input_keys=inputs_names, + layer_kwarg=layer_kwargs, + layer_type=layer_type, + step_name=f'{name} Clustering', + viewer=viewer, + widgets_to_update=[widget_split_and_merge_from_scribbles.segmentation], + ) + + +@magicgui( + call_button='Run Clustering', + image={ + 'label': 'Pmap/Image', + 'tooltip': 'Raw or boundary image to use as input for clustering.', + }, + _labels={ + 'label': 'Over-segmentation', + 'tooltip': 'Over-segmentation labels layer to use as input for clustering.', + }, + mode={ + 'label': 'Aggl. Mode', + 'choices': ['GASP', 'MutexWS', 'MultiCut'], + 'tooltip': 'Select which agglomeration algorithm to use.', + }, + beta={ + 'label': 'Under/Over segmentation factor', + 'tooltip': 'A low value will increase under-segmentation tendency ' + 'and a large value increase over-segmentation tendency.', + 'widget_type': 'FloatSlider', + 'max': 1.0, + 'min': 0.0, + }, + minsize={ + 'label': 'Min-size', + 'tooltip': 'Minimum segment size allowed in voxels.', + }, +) +def widget_agglomeration( + viewer: Viewer, + image: Image, + _labels: Labels, + mode: str = "GASP", + beta: float = 0.6, + minsize: int = 100, +) -> Future[LayerDataTuple]: if mode == 'GASP': func = gasp @@ -94,24 +107,36 @@ def widget_agglomeration(viewer: Viewer, return _generic_clustering(image, _labels, beta=beta, minsize=minsize, name=mode, agg_func=func, viewer=viewer) -@magicgui(call_button='Run Lifted MultiCut', - image={'label': 'Pmap/Image', - 'tooltip': 'Raw or boundary image to use as input for clustering.'}, - nuclei={'label': 'Nuclei', - 'tooltip': 'Nuclei binary predictions or Nuclei segmentation.'}, - _labels={'label': 'Over-segmentation', - 'tooltip': 'Over-segmentation labels layer to use as input for clustering.'}, - beta={'label': 'Under/Over segmentation factor', - 'tooltip': 'A low value will increase under-segmentation tendency ' - 'and a large value increase over-segmentation tendency.', - 'widget_type': 'FloatSlider', 'max': 1., 'min': 0.}, - minsize={'label': 'Min-size', - 'tooltip': 'Minimum segment size allowed in voxels.'}) -def widget_lifted_multicut(image: Image, - nuclei: Layer, - _labels: Labels, - beta: float = 0.5, - minsize: int = 100) -> Future[LayerDataTuple]: +@magicgui( + call_button='Run Lifted MultiCut', + image={ + 'label': 'Pmap/Image', + 'tooltip': 'Raw or boundary image to use as input for clustering.', + }, + nuclei={ + 'label': 'Nuclei', + 'tooltip': 'Nuclei binary predictions or Nuclei segmentation.', + }, + _labels={ + 'label': 'Over-segmentation', + 'tooltip': 'Over-segmentation labels layer to use as input for clustering.', + }, + beta={ + 'label': 'Under/Over segmentation factor', + 'tooltip': 'A low value will increase under-segmentation tendency ' + 'and a large value increase over-segmentation tendency.', + 'widget_type': 'FloatSlider', + 'max': 1.0, + 'min': 0.0, + }, + minsize={ + 'label': 'Min-size', + 'tooltip': 'Minimum segment size allowed in voxels.', + }, +) +def widget_lifted_multicut( + image: Image, nuclei: Layer, _labels: Labels, beta: float = 0.5, minsize: int = 100 +) -> Future[LayerDataTuple]: if 'pmap' not in image.metadata: _pmap_warn('Lifted MultiCut Widget') @@ -126,126 +151,138 @@ def widget_lifted_multicut(image: Image, out_name = create_layer_name(image.name, 'LiftedMultiCut') inputs_names = (image.name, nuclei.name, _labels.name) - layer_kwargs = layer_properties(name=out_name, - scale=image.scale, - metadata=image.metadata) + layer_kwargs = layer_properties(name=out_name, scale=image.scale, metadata=image.metadata) layer_type = 'labels' step_kwargs = dict(beta=beta, post_minsize=minsize) - return start_threading_process(lmc, - runtime_kwargs={'boundary_pmaps': image.data, - extra_key: nuclei.data, - 'superpixels': _labels.data}, - statics_kwargs=step_kwargs, - out_name=out_name, - input_keys=inputs_names, - layer_kwarg=layer_kwargs, - layer_type=layer_type, - step_name=f'Lifted Multicut Clustering', - ) - - -def dtws_wrapper(boundary_pmaps, - stacked: bool = True, - threshold: float = 0.5, - min_size: int = 100, - sigma_seeds: float = .2, - sigma_weights: float = 2., - alpha: float = 1., - pixel_pitch: Tuple[int, int, int] = (1, 1, 1), - apply_nonmax_suppression: bool = False, - nuclei: bool = False): + return start_threading_process( + lmc, + runtime_kwargs={'boundary_pmaps': image.data, extra_key: nuclei.data, 'superpixels': _labels.data}, + statics_kwargs=step_kwargs, + out_name=out_name, + input_keys=inputs_names, + layer_kwarg=layer_kwargs, + layer_type=layer_type, + step_name='Lifted Multicut Clustering', + ) + + +def dtws_wrapper( + boundary_pmaps, + stacked: bool = True, + threshold: float = 0.5, + min_size: int = 100, + sigma_seeds: float = 0.2, + sigma_weights: float = 2.0, + alpha: float = 1.0, + pixel_pitch: tuple[int, int, int] = (1, 1, 1), + apply_nonmax_suppression: bool = False, + nuclei: bool = False, +): if nuclei: boundary_pmaps = normalize_01(boundary_pmaps) - boundary_pmaps = 1. - boundary_pmaps + boundary_pmaps = 1.0 - boundary_pmaps mask = boundary_pmaps < threshold else: mask = None - return dt_watershed(boundary_pmaps=boundary_pmaps, - threshold=threshold, - min_size=min_size, - stacked=stacked, - sigma_seeds=sigma_seeds, - sigma_weights=sigma_weights, - alpha=alpha, - pixel_pitch=pixel_pitch, - apply_nonmax_suppression=apply_nonmax_suppression, - mask=mask - ) - - -@magicgui(call_button='Run Watershed', - image={'label': 'Pmap/Image', - 'tooltip': 'Raw or boundary image to use as input for Watershed.'}, - stacked={'label': 'Stacked', - 'tooltip': 'Define if the Watershed will run slice by slice (faster) ' - 'or on the full volume (slower).', - 'widget_type': 'RadioButtons', - 'orientation': 'horizontal', - 'choices': ['2D', '3D']}, - threshold={'label': 'Threshold', - 'tooltip': 'A low value will increase over-segmentation tendency ' - 'and a large value increase under-segmentation tendency.', - 'widget_type': 'FloatSlider', 'max': 1., 'min': 0.}, - min_size={'label': 'Min-size', - 'tooltip': 'Minimum segment size allowed in voxels.'}, - # Advanced parameters - show_advanced={'label': 'Show Advanced Parameters', - 'tooltip': 'Show advanced parameters for the Watershed algorithm.', - 'widget_type': 'CheckBox'}, - sigma_seeds={'label': 'Sigma seeds'}, - sigma_weights={'label': 'Sigma weights'}, - alpha={'label': 'Alpha'}, - use_pixel_pitch={'label': 'Use pixel pitch'}, - pixel_pitch={'label': 'Pixel pitch'}, - apply_nonmax_suppression={'label': 'Apply nonmax suppression'}, - nuclei={'label': 'Is image Nuclei'} - ) -def widget_dt_ws(image: Image, - stacked: str = '2D', - threshold: float = 0.5, - min_size: int = 100, - show_advanced: bool = False, - sigma_seeds: float = .2, - sigma_weights: float = 2., - alpha: float = 1., - use_pixel_pitch: bool = False, - pixel_pitch: Tuple[int, int, int] = (1, 1, 1), - apply_nonmax_suppression: bool = False, - nuclei: bool = False) -> Future[LayerDataTuple]: + return dt_watershed( + boundary_pmaps=boundary_pmaps, + threshold=threshold, + min_size=min_size, + stacked=stacked, + sigma_seeds=sigma_seeds, + sigma_weights=sigma_weights, + alpha=alpha, + pixel_pitch=pixel_pitch, + apply_nonmax_suppression=apply_nonmax_suppression, + mask=mask, + ) + + +@magicgui( + call_button='Run Watershed', + image={ + 'label': 'Pmap/Image', + 'tooltip': 'Raw or boundary image to use as input for Watershed.', + }, + stacked={ + 'label': 'Stacked', + 'tooltip': 'Define if the Watershed will run slice by slice (faster) ' 'or on the full volume (slower).', + 'widget_type': 'RadioButtons', + 'orientation': 'horizontal', + 'choices': STACKED, + }, + threshold={ + 'label': 'Threshold', + 'tooltip': 'A low value will increase over-segmentation tendency ' + 'and a large value increase under-segmentation tendency.', + 'widget_type': 'FloatSlider', + 'max': 1.0, + 'min': 0.0, + }, + min_size={ + 'label': 'Min-size', + 'tooltip': 'Minimum segment size allowed in voxels.', + }, + # Advanced parameters + show_advanced={ + 'label': 'Show Advanced Parameters', + 'tooltip': 'Show advanced parameters for the Watershed algorithm.', + 'widget_type': 'CheckBox', + }, + sigma_seeds={'label': 'Sigma seeds'}, + sigma_weights={'label': 'Sigma weights'}, + alpha={'label': 'Alpha'}, + use_pixel_pitch={'label': 'Use pixel pitch'}, + pixel_pitch={'label': 'Pixel pitch'}, + apply_nonmax_suppression={'label': 'Apply nonmax suppression'}, + nuclei={'label': 'Is image Nuclei'}, +) +def widget_dt_ws( + image: Image, + stacked: bool = False, + threshold: float = 0.5, + min_size: int = 100, + show_advanced: bool = False, + sigma_seeds: float = 0.2, + sigma_weights: float = 2.0, + alpha: float = 1.0, + use_pixel_pitch: bool = False, + pixel_pitch: tuple[int, int, int] = (1, 1, 1), + apply_nonmax_suppression: bool = False, + nuclei: bool = False, +) -> Future[LayerDataTuple]: if 'pmap' not in image.metadata: _pmap_warn("Watershed Widget") out_name = create_layer_name(image.name, 'dtWS') inputs_names = (image.name,) - layer_kwargs = layer_properties(name=out_name, - scale=image.scale, - metadata=image.metadata) + layer_kwargs = layer_properties(name=out_name, scale=image.scale, metadata=image.metadata) layer_type = 'labels' - stacked = False if stacked == '3D' else True - pixel_pitch = pixel_pitch if use_pixel_pitch else None - step_kwargs = dict(threshold=threshold, - min_size=min_size, - stacked=stacked, - sigma_seeds=sigma_seeds, - sigma_weights=sigma_weights, - alpha=alpha, - pixel_pitch=pixel_pitch, - apply_nonmax_suppression=apply_nonmax_suppression, - nuclei=nuclei) - - return start_threading_process(dtws_wrapper, - runtime_kwargs={ - 'boundary_pmaps': image.data}, - statics_kwargs=step_kwargs, - out_name=out_name, - input_keys=inputs_names, - layer_kwarg=layer_kwargs, - layer_type=layer_type, - step_name=f'Watershed Segmentation', - ) + step_kwargs = { + 'threshold': threshold, + 'min_size': min_size, + 'stacked': stacked, + 'sigma_seeds': sigma_seeds, + 'sigma_weights': sigma_weights, + 'alpha': alpha, + 'pixel_pitch': pixel_pitch if use_pixel_pitch else None, + 'apply_nonmax_suppression': apply_nonmax_suppression, + 'nuclei': nuclei, + } + + return start_threading_process( + dtws_wrapper, + runtime_kwargs={'boundary_pmaps': image.data}, + statics_kwargs=step_kwargs, + out_name=out_name, + input_keys=inputs_names, + layer_kwarg=layer_kwargs, + layer_type=layer_type, + step_name='Watershed Segmentation', + ) widget_dt_ws.sigma_seeds.hide() @@ -277,67 +314,87 @@ def _on_show_advanced_changed(state: bool): widget_dt_ws.nuclei.hide() -@magicgui(call_button='Run Watershed', - image={'label': 'Pmap/Image', - 'tooltip': 'Raw or boundary image to use as input for Watershed.'}, - stacked={'label': 'Stacked', - 'tooltip': 'Define if the Watershed will run slice by slice (faster) ' - 'or on the full volume (slower).', - 'widget_type': 'RadioButtons', - 'orientation': 'horizontal', - 'choices': ['2D', '3D']}, - threshold={'label': 'Threshold', - 'tooltip': 'A low value will increase over-segmentation tendency ' - 'and a large value increase under-segmentation tendency.', - 'widget_type': 'FloatSlider', 'max': 1., 'min': 0.}, - min_size={'label': 'Min-size', - 'tooltip': 'Minimum segment size allowed in voxels.'}, - ) -def widget_simple_dt_ws(image: Image, - stacked: str = '2D', - threshold: float = 0.5, - min_size: int = 100) -> Future[LayerDataTuple]: +@magicgui( + call_button='Run Watershed', + image={ + 'label': 'Pmap/Image', + 'tooltip': 'Raw or boundary image to use as input for Watershed.', + }, + stacked={ + 'label': 'Stacked', + 'tooltip': 'Define if the Watershed will run slice by slice (faster) ' 'or on the full volume (slower).', + 'widget_type': 'RadioButtons', + 'orientation': 'horizontal', + 'choices': STACKED, + }, + threshold={ + 'label': 'Threshold', + 'tooltip': 'A low value will increase over-segmentation tendency ' + 'and a large value increase under-segmentation tendency.', + 'widget_type': 'FloatSlider', + 'max': 1.0, + 'min': 0.0, + }, + min_size={ + 'label': 'Min-size', + 'tooltip': 'Minimum segment size allowed in voxels.', + }, +) +def widget_simple_dt_ws( + image: Image, + stacked: bool = False, + threshold: float = 0.5, + min_size: int = 100, +) -> Future[LayerDataTuple]: if 'pmap' not in image.metadata: _pmap_warn("Watershed Widget") out_name = create_layer_name(image.name, 'dtWS') inputs_names = (image.name,) - layer_kwargs = layer_properties(name=out_name, - scale=image.scale, - metadata=image.metadata) + layer_kwargs = layer_properties(name=out_name, scale=image.scale, metadata=image.metadata) layer_type = 'labels' - stacked = False if stacked == '3D' else True - step_kwargs = dict(threshold=threshold, - min_size=min_size, - stacked=stacked, - pixel_pitch=None) - - return start_threading_process(dtws_wrapper, - runtime_kwargs={ - 'boundary_pmaps': image.data}, - statics_kwargs=step_kwargs, - out_name=out_name, - input_keys=inputs_names, - layer_kwarg=layer_kwargs, - layer_type=layer_type, - step_name=f'Watershed Segmentation', - ) - - -@magicgui(call_button='Run Segmentation Fix from Nuclei', - cell_segmentation={'label': 'Cell Segmentation'}, - nuclei_segmentation={'label': 'Nuclei Segmentation'}, - boundary_pmaps={'label': 'Boundary Pmap/Image'}, - threshold={'label': 'Threshold', - 'widget_type': 'FloatRangeSlider', 'max': 100, 'min': 0, 'step': 0.1}, - quantile={'label': 'Nuclei Quantile', - 'widget_type': 'FloatRangeSlider', 'max': 100, 'min': 0, 'step': 0.1}) -def widget_fix_over_under_segmentation_from_nuclei(cell_segmentation: Labels, - nuclei_segmentation: Labels, - boundary_pmaps: Image, - threshold=(33, 66), - quantile=(0.1, 99.9)) -> Future[LayerDataTuple]: + step_kwargs = dict(threshold=threshold, min_size=min_size, stacked=stacked, pixel_pitch=None) + + return start_threading_process( + dtws_wrapper, + runtime_kwargs={'boundary_pmaps': image.data}, + statics_kwargs=step_kwargs, + out_name=out_name, + input_keys=inputs_names, + layer_kwarg=layer_kwargs, + layer_type=layer_type, + step_name='Watershed Segmentation', + ) + + +@magicgui( + call_button='Run Segmentation Fix from Nuclei', + cell_segmentation={'label': 'Cell Segmentation'}, + nuclei_segmentation={'label': 'Nuclei Segmentation'}, + boundary_pmaps={'label': 'Boundary Pmap/Image'}, + threshold={ + 'label': 'Threshold', + 'widget_type': 'FloatRangeSlider', + 'max': 100, + 'min': 0, + 'step': 0.1, + }, + quantile={ + 'label': 'Nuclei Quantile', + 'widget_type': 'FloatRangeSlider', + 'max': 100, + 'min': 0, + 'step': 0.1, + }, +) +def widget_fix_over_under_segmentation_from_nuclei( + cell_segmentation: Labels, + nuclei_segmentation: Labels, + boundary_pmaps: Image, + threshold=(33, 66), + quantile=(0.1, 99.9), +) -> Future[LayerDataTuple]: out_name = create_layer_name(cell_segmentation.name, 'NucleiSegFix') threshold_merge, threshold_split = threshold threshold_merge, threshold_split = threshold_merge / 100, threshold_split / 100 @@ -346,60 +403,62 @@ def widget_fix_over_under_segmentation_from_nuclei(cell_segmentation: Labels, if boundary_pmaps is not None: if 'pmap' not in boundary_pmaps.metadata: _pmap_warn("Fix Over/Under Segmentation from Nuclei Widget") - inputs_names = (cell_segmentation.name, - nuclei_segmentation.name, boundary_pmaps.name) - func_kwargs = {'cell_seg': cell_segmentation.data, - 'nuclei_seg': nuclei_segmentation.data, - 'boundary': boundary_pmaps.data} + inputs_names = (cell_segmentation.name, nuclei_segmentation.name, boundary_pmaps.name) + func_kwargs = { + 'cell_seg': cell_segmentation.data, + 'nuclei_seg': nuclei_segmentation.data, + 'boundary': boundary_pmaps.data, + } else: inputs_names = (cell_segmentation.name, nuclei_segmentation.name) - func_kwargs = {'cell_seg': cell_segmentation.data, - 'nuclei_seg': nuclei_segmentation.data} + func_kwargs = {'cell_seg': cell_segmentation.data, 'nuclei_seg': nuclei_segmentation.data} - layer_kwargs = layer_properties(name=out_name, - scale=cell_segmentation.scale, - metadata=cell_segmentation.metadata) + layer_kwargs = layer_properties(name=out_name, scale=cell_segmentation.scale, metadata=cell_segmentation.metadata) layer_type = 'labels' - step_kwargs = dict(threshold_merge=threshold_merge, - threshold_split=threshold_split, quantiles_nuclei=quantile) - - return start_threading_process(fix_over_under_segmentation_from_nuclei, - runtime_kwargs=func_kwargs, - statics_kwargs=step_kwargs, - out_name=out_name, - input_keys=inputs_names, - layer_kwarg=layer_kwargs, - layer_type=layer_type, - step_name='Fix Over / Under Segmentation', - ) - - -@magicgui(call_button='Run Segmentation Fix from Foreground Pmap', - segmentation={'label': 'Segmentation'}, - foreground={'label': 'Foreground Pmap'}, - threshold={'label': 'Threshold', - 'widget_type': 'FloatSlider', 'max': 1., 'min': 0.}) -def widget_fix_false_positive_from_foreground_pmap(segmentation: Labels, - foreground: Image, # TODO: maybe also allow labels - threshold=0.6) -> Future[LayerDataTuple]: + step_kwargs = dict(threshold_merge=threshold_merge, threshold_split=threshold_split, quantiles_nuclei=quantile) + + return start_threading_process( + fix_over_under_segmentation_from_nuclei, + runtime_kwargs=func_kwargs, + statics_kwargs=step_kwargs, + out_name=out_name, + input_keys=inputs_names, + layer_kwarg=layer_kwargs, + layer_type=layer_type, + step_name='Fix Over / Under Segmentation', + ) + + +@magicgui( + call_button='Run Segmentation Fix from Foreground Pmap', + segmentation={'label': 'Segmentation'}, + foreground={'label': 'Foreground Pmap'}, + threshold={ + 'label': 'Threshold', + 'widget_type': 'FloatSlider', + 'max': 1.0, + 'min': 0.0, + }, +) +def widget_fix_false_positive_from_foreground_pmap( + segmentation: Labels, foreground: Image, threshold=0.6 # TODO: maybe also allow labels +) -> Future[LayerDataTuple]: out_name = create_layer_name(segmentation.name, 'FGPmapFix') inputs_names = (segmentation.name, foreground.name) - func_kwargs = {'segmentation': segmentation.data, - 'foreground': foreground.data} + func_kwargs = {'segmentation': segmentation.data, 'foreground': foreground.data} - layer_kwargs = layer_properties(name=out_name, - scale=segmentation.scale, - metadata=segmentation.metadata) + layer_kwargs = layer_properties(name=out_name, scale=segmentation.scale, metadata=segmentation.metadata) layer_type = 'labels' step_kwargs = dict(threshold=threshold) - return start_threading_process(remove_false_positives_by_foreground_probability, - runtime_kwargs=func_kwargs, - statics_kwargs=step_kwargs, - out_name=out_name, - input_keys=inputs_names, - layer_kwarg=layer_kwargs, - layer_type=layer_type, - step_name='Reduce False Positives', - ) + return start_threading_process( + remove_false_positives_by_foreground_probability, + runtime_kwargs=func_kwargs, + statics_kwargs=step_kwargs, + out_name=out_name, + input_keys=inputs_names, + layer_kwarg=layer_kwargs, + layer_type=layer_type, + step_name='Reduce False Positives', + )