diff --git a/README.md b/README.md index 8993de11b..f87766398 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ We have recorded workshop talks which complement the repository. [MIBI Workshop #### 1. Segmentation The [**segmentation notebook**](./templates/1_Segment_Image_Data.ipynb) will walk you through the process of using [Mesmer](https://www.nature.com/articles/s41587-021-01094-0) to segment your image data. This includes selecting the appropriate channel(s) for segmentation, running your data through the network, and then extracting single-cell statistics from the resulting segmentation mask. [Workshop Talk - Session V - Part 1: Segmentation](https://youtu.be/4_AJxrxPYlk?t=231) - *Note:* It is assumed that the cell table uses the default column names as in `ark/settings.py`. Refer to the [docs](docs/_rtd/data_types.md) to get descriptions of the cell table columns, and methods to adjust them if necessary. + - If you plan to segment out non-traditional cellular structures such as protein aggregates or cytoplasmic projections often found in brain cells (e.g. microglia, astrocytes, and neuropil), try out the companion notebook [**ezSegmenter**](./templates/ez_segmenter.ipynb) either as a stand-alone or in combination with the above standard cell segmentation process. #### 2. Pixel clustering with Pixie The first step in the [Pixie](https://doi.org/10.1038/s41467-023-40068-5) pipeline is to run the [**pixel clustering notebook**](./templates/2_Pixie_Cluster_Pixels.ipynb). The notebook walks you through the process of generating pixel clusters for your data, and lets you specify what markers to use for the clustering, train a model, use it to classify your entire dataset, and generate pixel cluster overlays. The notebook includes a GUI for manual cluster adjustment and annotation. [Workshop Talk - Session IV - Pixel Level Analysis](https://youtu.be/e7C1NvaPLaY) diff --git a/conftest.py b/conftest.py index c37a4978f..4a15b53e5 100644 --- a/conftest.py +++ b/conftest.py @@ -16,7 +16,7 @@ def dataset_cache_dir() -> Iterator[Union[str, None]]: yield cache_dir -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def rng() -> Generator[np.random.Generator, None, None]: """ Create a new Random Number Generator for tests which require randomized data. diff --git a/src/ark/segmentation/ez_seg/__init__.py b/src/ark/segmentation/ez_seg/__init__.py new file mode 100644 index 000000000..711b05c08 --- /dev/null +++ b/src/ark/segmentation/ez_seg/__init__.py @@ -0,0 +1,10 @@ +from . 
import composites, merge_masks, ez_seg_display, ez_seg_utils +from .ez_object_segmentation import create_object_masks + +__all__ = [ + "composites", + "merge_masks", + "ez_seg_display", + "ez_seg_utils", + "create_object_masks", +] diff --git a/src/ark/segmentation/ez_seg/composites.py b/src/ark/segmentation/ez_seg/composites.py new file mode 100644 index 000000000..04dc64a84 --- /dev/null +++ b/src/ark/segmentation/ez_seg/composites.py @@ -0,0 +1,177 @@ +import pathlib +from typing import List, Union +import numpy as np +import xarray as xr +from alpineer import misc_utils, image_utils, load_utils +from ark.segmentation.ez_seg.ez_seg_utils import log_creator + + +def composite_builder( + image_data_dir: Union[str, pathlib.Path], + img_sub_folder: str, + fov_list: list[str], + images_to_add: list[str], + images_to_subtract: list[str], + image_type: str, + composite_method: str, + composite_directory: Union[str, pathlib.Path], + composite_name: str, + log_dir: Union[str, pathlib.Path], +) -> None: + """ + Adds tiffs together, either pixel clusters or base signal tiffs and returns a composite channel or mask. + + Args: + image_data_dir (Union[str, pathlib.Path]): The path to dir containing the set of all images + which get filtered out with `images_to_add` and `images_to_subtract`. + img_sub_folder (str): A name for sub-folders within each fov in the image_data location. + fov_list: A list of fov's to create composite channels through. + images_to_add (List[str]): A list of channels or pixel cluster names to add together. + images_to_subtract (List[str]): A list of channels or pixel cluster names to subtract + from the composite. + image_type (str): Either "signal" or "pixel_cluster" data. + composite_method (str): Binarized mask returns ("binary") or intensity, gray-scale tiffs + returned ("total"). + composite_directory (Union[str, pathlib.Path]): The directory to save the composite array. + composite_name (str): The name of the composite array to save. + log_dir: The directory to save log information to. + + Returns: + np.ndarray: Saves the composite array, either as a binary mask, or as a scaled intensity array. 
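+
+    Example:
+        A minimal sketch of a call. The directories below are placeholders, and the
+        channel names are the amyloid channels from the example dataset; substitute
+        your own paths and channels:
+
+            composite_builder(
+                image_data_dir="data/example_dataset/image_data",
+                img_sub_folder=None,
+                fov_list=["fov0", "fov1"],
+                images_to_add=["Amyloidbeta140", "Amyloidbeta142", "PanAmyloidbeta1724"],
+                images_to_subtract=["HistoneH3Lyo", "Background"],
+                image_type="signal",
+                composite_method="total",
+                composite_directory="data/example_dataset/composites",
+                composite_name="amyloid",
+                log_dir="data/example_dataset/logs",
+            )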
+ """ + for fov in fov_list: + # load in tiff images and verify channels are present + fov_data = load_utils.load_imgs_from_tree( + data_dir=image_data_dir, img_sub_folder=img_sub_folder, fovs=fov + ) + + image_shape = fov_data.shape[1:3] + + misc_utils.verify_in_list( + images_to_add=images_to_add, image_names=fov_data.channels.values + ) + misc_utils.verify_in_list( + images_to_subtract=images_to_subtract, image_names=fov_data.channels.values + ) + misc_utils.verify_in_list( + composite_method=composite_method, options=["binary", "total"] + ) + + # make composite dir if not there already + if isinstance(composite_directory, str): + composite_directory = pathlib.Path(composite_directory) + composite_directory.mkdir(parents=True, exist_ok=True) + + # Initialize composite array, and add & subtract channels + composite_array = np.zeros(shape=image_shape) + if images_to_add: + composite_array = add_to_composite( + fov_data, composite_array, images_to_add, image_type, composite_method + ) + if images_to_subtract: + composite_array = subtract_from_composite( + fov_data, composite_array, images_to_subtract, image_type, composite_method + ) + + # Create the fov dir within the composite dir + composite_fov_dir = composite_directory / fov + composite_fov_dir.mkdir(parents=True, exist_ok=True) + + # Save the composite image + image_utils.save_image( + fname=composite_directory / fov / f"{composite_name}.tiff", + data=composite_array.astype(np.uint32) + ) + + # Write a log saving composite builder info + variables_to_log = { + "image_data_dir": image_data_dir, + "fov_list": fov_list, + "images_to_add": images_to_add, + "images_to_subtract": images_to_subtract, + "image_type": image_type, + "composite_method": composite_method, + "composite_directory": composite_directory, + "composite_name": composite_name, + } + log_creator(variables_to_log, log_dir, f"{composite_name}_composite_log.txt") + + print("Composites built and saved") + + +def add_to_composite( + data: xr.DataArray, + composite_array: np.ndarray, + images_to_add: List[str], + image_type: str, + composite_method: str, +) -> np.ndarray: + """ + Adds tiffs together to form a composite array. + + Args: + data (xr.DataArray): The data array containing the set of all images which get filtered out + with `images_to_add`. + composite_array (np.ndarray): The array to add channels to. + images_to_add (List[str]): A list of channels or pixel cluster names to add together. + image_type (str): Either "signal" or "pixel_cluster" data. + composite_method (str): Binarized mask returns ("binary") or intensity, gray-scale tiffs + returned ("total"). + + Returns: + np.ndarray: The composite array, either as a binary mask, or as a scaled intensity array. + """ + + filtered_channels: xr.DataArray = data.sel( + {"channels": images_to_add}).squeeze().astype(np.int32) + if len(images_to_add) > 1: + composite_array: np.ndarray = filtered_channels.sum(dim="channels").values + else: + composite_array: np.ndarray = filtered_channels + if image_type == "pixel_cluster" or composite_method == "binary": + composite_array = composite_array.clip(min=None, max=1) + + return composite_array + + +def subtract_from_composite( + data: xr.DataArray, + composite_array: np.ndarray, + images_to_subtract: List[str], + image_type: str, + composite_method: str, +) -> np.ndarray: + """ + Subtracts tiffs from a composite array. + + Args: + data (xr.DataArray): The data array containing the set of all images which get + filtered out with `images_to_subtract`. 
+ composite_array (np.ndarray): An array to subtract channels from. + images_to_subtract (List[str]): A list of channels or pixel cluster names to subtract + from the composite. + image_type (str): Either "signal" or "pixel_cluster" data. + composite_method (str): Binarized mask returns ('binary') or intensity, gray-scale tiffs + returned ('total'). + + Returns: + np.ndarray: The composite array, either as a binary mask, or as a scaled intensity array. + """ + + filtered_channels: xr.DataArray = data.sel( + {"channels": images_to_subtract}).squeeze().astype(np.int32) + if len(images_to_subtract) > 1: + composite_array2sub: np.ndarray = filtered_channels.sum(dim="channels").values + else: + composite_array2sub: np.ndarray = filtered_channels + + if image_type == "signal" and composite_method == "binary": + mask_2_zero = composite_array2sub > 0 + composite_array[mask_2_zero] = 0 + composite_array[composite_array > 1] = 1 + + else: + composite_array -= composite_array2sub + composite_array = composite_array.clip(min=0, max=None) + + return composite_array diff --git a/src/ark/segmentation/ez_seg/ez_object_segmentation.py b/src/ark/segmentation/ez_seg/ez_object_segmentation.py new file mode 100644 index 000000000..d7b4d7982 --- /dev/null +++ b/src/ark/segmentation/ez_seg/ez_object_segmentation.py @@ -0,0 +1,292 @@ +import pathlib +from typing import Optional, Tuple, Union, Literal +import numpy as np +from skimage import measure, filters, morphology +from skimage.util import map_array +import pandas as pd +from alpineer import misc_utils, load_utils, image_utils, io_utils +from ark.segmentation.ez_seg.ez_seg_utils import log_creator +import xarray as xr +import warnings + + +def create_object_masks( + image_data_dir: Union[str, pathlib.Path], + img_sub_folder: Optional[str], + fov_list: list[str], + mask_name: str, + channel_to_segment: str, + masks_dir: Union[str, pathlib.Path], + log_dir: Union[str, pathlib.Path], + object_shape_type: str = "blob", + sigma: int = 1, + thresh: Optional[np.float32] = None, + hole_size: Optional[int] = None, + fov_dim: int = 400, + min_object_area: int = 100, + max_object_area: int = 100000, +) -> None: + """ + Calculates a mask for each channel in the FOV for circular or 'blob'-like objects such as: single large cells or amyloid + plaques. It will blur the input image, then threshold the blurred image on either a given + fixed value, or an adaptive thresholding method. In addition, it removes small holes using + that same thresholding input and filters out objects which are either too small or too large. + + Args: + image_data_dir (Union[str, pathlib.Path]): The directory to pull images from to perform segmentation on. + img_sub_folder (str): A name for sub-folders within each fov in the image_data location. + fov_list: A list of fov names to segment on. + mask_name (str): The name of the masks you are creating. + channel_to_segment: The channel on which to perform segmentation. + masks_dir (Union[str, pathlib.Path]): The directory to save segmented images to. + object_shape_type (str, optional): Specify whether the object is either "blob" or + "projection" shaped. Defaults to "blob". + sigma (int): The standard deviation for Gaussian kernel, used for blurring the + image. Defaults to 1. + thresh (np.float32, optional): The global threshold value for image thresholding if + desired. Defaults to None. + hole_size (int, optional): A specific area to close small holes over in object masks. + Defaults to None. + fov_dim (int): The dimension in μm of the FOV. 
+ min_object_area (int): The minimum size (area) of an object to capture in + pixels. Defaults to 100. + max_object_area (int): The maximum size (area) of an object to capture in + pixels. Defaults to 100000. + log_dir (Union[str, pathlib.Path]): The directory to save log information to. + """ + + # Input validation + io_utils.validate_paths([image_data_dir, masks_dir, log_dir]) + + misc_utils.verify_in_list( + object_shape=[object_shape_type], object_shape_options=["blob", "projection"] + ) + + for fov in fov_list: + fov_xr: xr.DataArray = load_utils.load_imgs_from_tree( + data_dir=image_data_dir, img_sub_folder=img_sub_folder, fovs=fov + ).squeeze() + + # handles folders where only 1 channel is loaded, often for composites + try: + len(fov_xr.channels) + channel: xr.DataArray = fov_xr.sel({"channels": channel_to_segment}).astype( + int + ) + except TypeError: + channel: xr.DataArray = fov_xr.astype(int) + + object_masks: np.ndarray = _create_object_mask( + input_image=channel, + object_shape_type=object_shape_type, + sigma=sigma, + thresh=thresh, + hole_size=hole_size, + fov_dim=fov_dim, + min_object_area=min_object_area, + max_object_area=max_object_area, + ) + + # save the channel overlay + save_path = pathlib.Path(masks_dir) / f"{fov}_{mask_name}.tiff" + image_utils.save_image(fname=save_path, data=object_masks) + + # Write a log saving ez segment info + variables_to_log = { + "image_data_dir": image_data_dir, + "fov_list": fov_list, + "mask_name": mask_name, + "channel_to_segment": channel_to_segment, + "masks_dir": masks_dir, + "object_shape_type": object_shape_type, + "sigma": sigma, + "thresh": thresh, + "hole_size": hole_size, + "fov_dim": fov_dim, + "min_object_area": min_object_area, + "max_object_area": max_object_area, + } + log_creator(variables_to_log, log_dir, f"{mask_name}_segmentation_log.txt") + print("ez masks built and saved") + + +def _create_object_mask( + input_image: xr.DataArray, + object_shape_type: Union[Literal["blob"], Literal["projection"]] = "blob", + sigma: int = 1, + thresh: Union[int, Literal["auto"]] = None, + hole_size: Union[int, Literal["auto"]] = "auto", + fov_dim: int = 400, + min_object_area: int = 10, + max_object_area: int = 100000, +) -> np.ndarray: + """ + Calculates a mask for circular or 'blob'-like objects such as: single large cells or amyloid + plaques. It will blur the input image, then threshold the blurred image on either a given + fixed value, or an adaptive thresholding method. In addition, it removes small holes using + that same thresholding input and filters out objects which are either too small or too large. + + Args: + input_image (xr.DataArray): The numpy array (image) to perform segmentation on. + object_shape_type (str, optional): Specify whether the object is either "blob" or + "projection" shaped. Defaults to "blob". + sigma (int): The standard deviation for Gaussian kernel, used for blurring the + image. Defaults to 1. + thresh (int, str, optional): The global threshold value for image thresholding if + desired. Defaults to "auto". + hole_size (int, str, optional): A specific area to close small holes over in object masks. + Defaults to None. + fov_dim (int): The dimension in μm of the FOV. + min_object_area (int): The minimum size (area) of an object to capture in + pixels. Defaults to 100. + max_object_area (int): The maximum size (area) of an object to capture in + pixels. Defaults to 100000. + + Returns: + np.ndarray: The object mask. + """ + + # Do not display any UserWarning msg's about boolean arrays here. 
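+    # (This warning is raised by skimage.morphology.remove_small_holes further down,
+    # because the thresholded mask is passed to it as an integer array rather than a
+    # boolean one.)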
+ warnings.filterwarnings( + "ignore", message="Any labeled images will be returned as a boolean array. Did you mean to use a boolean array?") + + # Input validation + misc_utils.verify_in_list(object_shape_type=[object_shape_type], object_shape_options=["blob", "projection"]) + + # Copy the input image, and get its shape + img2mask: np.ndarray = input_image.copy().to_numpy() + img2mask_shape: Tuple[int, int] = img2mask.shape + + # Blur the input mask using given sigma value + if sigma is None: + img2mask_blur = img2mask + else: + img2mask_blur: np.ndarray = filters.gaussian( + img2mask, sigma=sigma, preserve_range=True + ) + + # Apply binary thresholding to the blurred image + if isinstance(thresh, int): + # Find the threshold value based on the given percentile number + img_nonzero = img2mask_blur[img2mask_blur != 0] + thresh_percentile = np.percentile(img_nonzero, thresh) + + # Threshold the array where values below the threshold are set to 0 + img2mask_thresh = np.where(img2mask_blur < thresh_percentile, 0, img2mask_blur) + + elif thresh == "auto": + local_thresh_block_size: int = get_block_size( + block_type="local_thresh", fov_dim=fov_dim, img_shape=img2mask_shape[0] + ) + img2mask_thresh: np.ndarray = img2mask_blur > filters.threshold_local( + img2mask_blur, block_size=local_thresh_block_size + ) + elif thresh is None: + img2mask_thresh = img2mask_blur + else: + raise ValueError( + f"Invalid `threshold` value: {thresh}. Must be either `auto`, `None` or an integer." + ) + + # Binarize the image in prep for removing holes. + img2mask_thresh_binary = img2mask_thresh > 0 + img2mask_thresh[img2mask_thresh_binary] = 1 + img2mask_thresh = img2mask_thresh.astype(int) + + # Remove small holes within the objects + if isinstance(hole_size, int): + img2mask_rm_holes: np.ndarray = morphology.remove_small_holes( + img2mask_thresh, area_threshold=hole_size + ) + elif hole_size == "auto": + small_holes_block_size: int = get_block_size( + block_type="small_holes", fov_dim=fov_dim, img_shape=img2mask_shape[0] + ) + img2mask_rm_holes: np.ndarray = morphology.remove_small_holes( + img2mask_thresh, area_threshold=small_holes_block_size + ) + elif hole_size is None: + img2mask_rm_holes = img2mask_thresh + else: + raise ValueError( + f"Invalid `hole_size` value: {hole_size}. Must be either `auto`, `None` or an integer." + ) + + # Filter projections + if object_shape_type == "projection": + img2mask_filtered: np.ndarray = filters.meijering( + img2mask_rm_holes, sigmas=range(1, 5, 1), black_ridges=False + ) + else: + img2mask_filtered: np.ndarray = img2mask_rm_holes + + # Binarize the image in prep for labeling. 
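+    # (For "projection" masks the Meijering ridge filter returns a float response
+    # image, so positive responses are re-binarized before connected-component
+    # labeling; "blob" masks are already 0/1 and pass through unchanged.)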
+ img2mask_filtered_binary = img2mask_filtered > 0 + img2mask_filtered[img2mask_filtered_binary] = 1 + + # Extract `label` and `area` from regionprops + labeled_object_masks = measure.label(img2mask_filtered, connectivity=2) + + # Convert dictionary of region properties to DataFrame + object_masks_df: pd.DataFrame = pd.DataFrame( + measure.regionprops_table( + label_image=labeled_object_masks, + cache=True, + properties=[ + "label", + "area", + ], + ) + ) + + # zero out objects not meeting size requirements + keep_labels_bool = (object_masks_df["area"] >= min_object_area) & ( + object_masks_df["area"] <= max_object_area + ) + all_labels = object_masks_df["label"] + labels_to_keep = all_labels * keep_labels_bool + + # Map filtered objects to the object mask + objects_filtered_by_size = map_array( + labeled_object_masks, all_labels.to_numpy(), labels_to_keep.to_numpy() + ) + + return objects_filtered_by_size.astype(np.int32) + + +def get_block_size(block_type: str, fov_dim: int, img_shape: int) -> int: + """ + Computes the approximate local otsu threshold based on fov size (in μm) and pixel resolution. + + Args: + block_type (str): Either "small_holes" or "local_thresh" + fov_dim (int, optional): The size in μm for the FOV. + img_shape (int, optional): The shape of the image. + + Returns: + int: Returns the approximate block area + """ + + # Input validation + misc_utils.verify_in_list( + block_type=[block_type], block_types=["small_holes", "local_thresh"] + ) + # Get the size of the pixel + + pixel_size = fov_dim / img_shape + # grab block size for removing small holes + if block_type == "small_holes": + size = (np.pi * 5) ** 2 / pixel_size + # round up above value to make it into an integer + area: int = round(size) + # grab local threshold block size + else: + # use this to calculate out how many pixels it takes to get to roughly 10 μm + # (roughly a cell soma diameter) + size: float = 10 / pixel_size + + # round the area up to the nearest odd number + area: int = round(size) + if area % 2 == 0: + area += 1 + return area diff --git a/src/ark/segmentation/ez_seg/ez_seg_display.py b/src/ark/segmentation/ez_seg/ez_seg_display.py new file mode 100644 index 000000000..405577aff --- /dev/null +++ b/src/ark/segmentation/ez_seg/ez_seg_display.py @@ -0,0 +1,224 @@ +import pathlib +from typing import Union +from matplotlib.axes import Axes +from skimage.io import imread +from skimage import feature, color, filters +from skimage.util import img_as_ubyte +import numpy as np +import matplotlib.pyplot as plt +import os +from matplotlib.figure import Figure +from matplotlib import gridspec +from alpineer import io_utils + + +def display_channel_image( + base_image_path: Union[str, pathlib.Path], + sub_folder_name: str, + test_fov_name: str, + channel_name: str, + composite: bool = False, +) -> None: + """ + Displays a channel or a composite image. + + Args: + base_image_path (Union[str, pathlib.Path]): The path to the image. + sub_folder_name (str): If a subfolder name for the channel data exists. + test_fov_name (str): The name of the fov you wish to display. + channel_name (str): The name of the channel you wish to display. + composite (bool): Whether the image to be viewed is a composite image. 
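+
+    Example:
+        The fov and channel names below are taken from the companion notebook; the
+        base path is a placeholder:
+
+            display_channel_image(
+                base_image_path="data/example_dataset/image_data",
+                sub_folder_name=None,
+                test_fov_name="fov0",
+                channel_name="Iba1",
+            )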
+ """ + # Show test composite image + if composite or (sub_folder_name is None): + sub_folder_name = "" + + image_path = ( + pathlib.Path(base_image_path) + / test_fov_name + / sub_folder_name + / f"{channel_name}.tiff" + ) + + if isinstance(image_path, str): + image_path = pathlib.Path(image_path) + io_utils.validate_paths(image_path) + + base_image: np.ndarray = imread(image_path, as_gray=True) + + base_image_scaled = img_as_ubyte(base_image) + + # Plot + fig: Figure = plt.figure(dpi=300, figsize=(6, 6)) + fig.set_layout_engine(layout="constrained") + gs = gridspec.GridSpec(1, 1, figure=fig) + fig.suptitle(f"{image_path.name}") + + ax: Axes = fig.add_subplot(gs[0, 0]) + ax.imshow(base_image_scaled) + ax.axis("off") + + +# for displaying segmentation masks overlaid upon a base channel or composite +def overlay_mask_outlines( + fov: str, + channel: str, + image_dir: Union[str, os.PathLike], + sub_folder_name: str, + mask_name: str, + mask_dir: Union[str, os.PathLike], +) -> None: + """ + Displays a segmentation mask overlaid on a base image (channel or composite). + + Args: + fov (str): name of fov to be viewed + channel (str): name of channel to view + image_dir (Union[str, os.PathLike]): The Path to channel for viewing. + sub_folder_name (str): If a subfolder name for the channel data exists. + mask_name (str): The name of mask to view + mask_dir (Union[str, os.PathLike]): The path to the directory containing the mask. + """ + if sub_folder_name is None: + sub_folder_name = "" + + if isinstance(image_dir, str): + image_dir = pathlib.Path(image_dir) + if isinstance(mask_dir, str): + mask_dir = pathlib.Path(mask_dir) + + image_dir = image_dir / sub_folder_name + + io_utils.validate_paths([image_dir, mask_dir]) + + # Get ezseg and channel image paths + channel_image_path = pathlib.Path(image_dir) / fov / f"{channel}.tiff" + mask_image_path = pathlib.Path(mask_dir) / f"{fov}_{mask_name}.tiff" + + # Validate paths + io_utils.validate_paths(paths=[channel_image_path, mask_image_path]) + + # Load the base image and mask image + # Autoscale the base image + channel_image: np.ndarray = imread(channel_image_path, as_gray=True) + mask_image: np.ndarray = imread(mask_image_path, as_gray=True) + + # Auto-scale the base image + channel_image_scaled = img_as_ubyte(channel_image) + + # Apply Canny edge detection to extract outlines + edges: np.ndarray = feature.canny( + image=mask_image, low_threshold=0, high_threshold=1 + ) + + # Set the outline color to red + outline_color = (255, 0, 0) + + # Convert the base image to RGB + rgb_channel_image_scaled = color.gray2rgb(channel_image_scaled) + + # Overlay the outlines on the copy of the base image + rgb_channel_image_scaled[edges != 0] = outline_color + + # Create a new figure + fig: Figure = plt.figure(dpi=300, figsize=(6, 6)) + fig.set_layout_engine(layout="constrained") + gs = gridspec.GridSpec(1, 1, figure=fig) + fig.suptitle(f"Mask: {mask_name}") + ax: Axes = fig.add_subplot(gs[0, 0]) + ax.imshow(channel_image) + # Display color mask with transparency + ax.imshow(rgb_channel_image_scaled, alpha=0.3) + ax.axis("off") + + +def multiple_mask_display( + fov: str, + mask_name: str, + object_mask_dir: Union[str, os.PathLike], + cell_mask_dir: Union[str, os.PathLike], + merged_mask_dir: Union[str, os.PathLike], +) -> None: + """ + Create a grid to display the object, cell, and merged masks for a given fov. 
+ + Args: + fov (str): Name of the fov to view + mask_name (str): Name of mask to view + object_mask_dir (Union[str, os.PathLike]): Directory where the object masks are stored. + cell_mask_dir (Union[str, os.PathLike]): Directory where the cell masks are stored. + merged_mask_dir (Union[str, os.PathLike]): Directory where the merged masks are stored. + """ + if isinstance(object_mask_dir, str): + object_mask_dir = pathlib.Path(object_mask_dir) + if isinstance(cell_mask_dir, str): + cell_mask_dir = pathlib.Path(cell_mask_dir) + if isinstance(merged_mask_dir, str): + merged_mask_dir = pathlib.Path(merged_mask_dir) + io_utils.validate_paths([object_mask_dir, cell_mask_dir, merged_mask_dir]) + + modified_overlay_mask: np.ndarray = create_overlap_and_merge_visual( + fov, mask_name, object_mask_dir, cell_mask_dir, merged_mask_dir + ) + + # Create a new figure + fig: Figure = plt.figure(dpi=300, figsize=(6, 6)) + fig.set_layout_engine(layout="constrained") + gs = gridspec.GridSpec(1, 1, figure=fig) + fig.suptitle(f"Merged Mask: {mask_name}") + ax: Axes = fig.add_subplot(gs[0, 0]) + # Display color mask with transparency + ax.imshow(modified_overlay_mask) + ax.axis("off") + + +def create_overlap_and_merge_visual( + fov: str, + mask_name: str, + object_mask_dir: pathlib.Path, + cell_mask_dir: pathlib.Path, + merged_mask_dir: pathlib.Path, +) -> np.ndarray: + """ + Generate the NumPy Array representing the overlap between two masks + + Args: + fov (str): Name of the fov to view + mask_name (str): Name of mask to view + object_mask_dir (pathlib.Path): Directory where the object masks are stored. + cell_mask_dir (pathlib.Path): Directory where the cell masks are stored. + merged_mask_dir (pathlib.Path): Directory where the merged masks are stored. + + Returns: + np.ndarray: + Contains an overlap image of the two masks + """ + # read in masks + object_mask: np.ndarray = imread(object_mask_dir / f"{fov}_{mask_name}.tiff") + cell_mask: np.ndarray = imread( + cell_mask_dir / f"{fov}_whole_cell.tiff", as_gray=True + ) + merged_mask: np.ndarray = imread( + merged_mask_dir / f"{fov}_{mask_name}_merged.tiff", as_gray=True + ) + + # Assign colors to the non-overlapping areas of each mask + # Object masks in red + red_array = np.zeros(shape=object_mask.shape, dtype=np.uint8) + red_array[object_mask > 0] = 225 + + # Cell masks in blue + blue_array = np.zeros(shape=object_mask.shape, dtype=np.uint8) + blue_array[cell_mask > 0] = 255 + + # Merged mask edges in green + merge_bool = merged_mask > 0 + edges = filters.sobel(merge_bool) + green_array = np.zeros(shape=object_mask.shape, dtype=np.uint8) + green_array[edges > 0] = 255 + + # Combine red, green, and blue channels to create the final image + image = np.stack([red_array, green_array, blue_array], axis=-1) + + # return this image to the multi_merge_mask_display function. 
+ return image diff --git a/src/ark/segmentation/ez_seg/ez_seg_utils.py b/src/ark/segmentation/ez_seg/ez_seg_utils.py new file mode 100644 index 000000000..d0fc3510f --- /dev/null +++ b/src/ark/segmentation/ez_seg/ez_seg_utils.py @@ -0,0 +1,184 @@ +from typing import Generator, List, Union +from skimage.io import imread +from alpineer.image_utils import save_image +from alpineer import io_utils +import os +import re +import shutil +from tqdm.auto import tqdm +import numpy as np +import pathlib +import pandas as pd + + +def find_and_copy_files(mask_names: List[str], source_folder: Union[str, pathlib.Path], + destination_folder: Union[str, pathlib.Path]): + """ + Creates a new directory of masks for relabeling and cell table generation. Useful if more than + one mask type is needed for cell table generation. E.g. merged cells and proteopathy objects. + + Args: + mask_names (List[str]): + List of mask names to be merged. Can be partial names. + source_folder (Union[str, pathlib.Path]): + The parent segmentation folder all masks are found in. + destination_folder (Union[str, pathlib.Path]): + New dir where final masks will be copied to. + """ + # Ensure the destination folder exists, create it if not + if not os.path.exists(destination_folder): + os.makedirs(destination_folder) + + # Iterate through each name in the list + for mn in mask_names: + # Compile a regex pattern to match files containing the name anywhere in the file name + pattern = re.compile(f".*{re.escape(mn)}.*", re.IGNORECASE) + + # Search for files associated with the current name in the source folder using regex + files_to_copy = [] + for root, dirs, files in os.walk(source_folder): + for file in files: + if pattern.match(file) and str(destination_folder) not in str(root): + files_to_copy.append(os.path.join(root, file)) + + # Copy the found files to the destination folder + for file_path in files_to_copy: + shutil.copy(file_path, os.path.join(destination_folder, os.path.basename(file_path))) + + +def renumber_masks( + mask_dir: Union[pathlib.Path, str] +): + """ + Relabels all masks in mask tiffs so each label is unique across all mask images + in entire dataset. + + Args: + mask_dir (Union[pathlib.Path, str]): Directory that points to parent directory of all + segmentation masks to be relabeled. + """ + mask_dir_path = pathlib.Path(mask_dir) + io_utils.validate_paths(mask_dir_path) + + all_images: Generator[pathlib.Path, None, None] = mask_dir_path.rglob("*.tiff") + + global_unique_labels = 1 + + # First pass - get total number of unique masks + for image in all_images: + img: np.ndarray = imread(image) + unique_labels: np.ndarray = np.unique(img) + non_zero_labels: np.ndarray = unique_labels[unique_labels != 0] + global_unique_labels += len(non_zero_labels) + + all_images: Generator[pathlib.Path, None, None] = mask_dir_path.rglob("*.tiff") + + # Second pass - relabel all masks starting at unique num of masks +1 + for image in all_images: + img: np.ndarray = imread(image) + unique_labels: np.ndarray = np.unique(img) + for label in unique_labels: + if label != 0: + img[img == label] = global_unique_labels + global_unique_labels += 1 + save_image(fname=image, data=img) + print("Relabeling Complete.") + + +def create_mantis_project( + fovs: Union[str, List[str]], + image_data_dir: Union[str, pathlib.Path], + segmentation_dir: Union[str, pathlib.Path], + mantis_dir: Union[str, pathlib.Path], +) -> None: + """ + Creates a folder for viewing FOVs in Mantis. 
+ + Args: + fovs (Union[str, List[str]]): + A list of FOVs to use for creating the mantis project + image_data_dir (Union[str, pathlib.Path]): + The path to the directory containing the raw image data. + segmentation_dir (Union[str, pathlib.Path]): + The path to the directory containing masks. + mantis_dir: + The path to the directory containing housing the ez_seg specific mantis project. + """ + for fov in tqdm(io_utils.list_folders(image_data_dir, substrs=fovs)): + shutil.copytree(os.path.join(image_data_dir, fov), dst=os.path.join(mantis_dir, fov)) + + for seg_type in io_utils.list_folders(segmentation_dir): + for mask in io_utils.list_files(os.path.join(segmentation_dir, seg_type), substrs=fov): + shutil.copy(os.path.join(segmentation_dir, seg_type, mask), + dst=os.path.join(mantis_dir, fov) + ) + + +def log_creator(variables_to_log: dict, base_dir: str, log_name: str = "config_values.txt"): + """Logs the variables in `variables_to_log` to the file at `base_dir/log_name` + + Args: + variables_to_log (dict): + The name of each variable along with their associated value + base_dir (str): + Where the log will be written to + log_name (str): + The name of the log file to write the variables to + """ + # Define the filename for the text file + output_file = os.path.join(base_dir, log_name) + + # Open the file in write mode and write the variable values + with open(output_file, "w") as file: + for variable_name, variable_value in variables_to_log.items(): + file.write(f"{variable_name}: {variable_value}\n") + + print(f"Values saved to {output_file}") + + +def filter_csvs_by_mask(csv_path_name: Union[str, pathlib.Path], csv_substr_replace: str, + column_to_filter: str = "mask_type") -> None: + """Function to take in and separate a single cell table into multiple + based on the mask_type parameter. + + Args: + csv_path_name (Union[str, pathlib.Path]): + The path to the directory containing the cell table CSVs. 
+ csv_substr_replace (str): + The substring in the CSV file name to replace in favor of the mask name + column_to_filter (str): + The name of the column to split on, defaults to `"mask_type"` + """ + # Load the CSV file as a DataFrame (replace 'input.csv' with your CSV file) + csv_files = io_utils.list_files(csv_path_name, substrs=".csv") + for item in csv_files: + if csv_substr_replace not in item: + continue + + input_csv_file = os.path.join(csv_path_name, item) + df = pd.read_csv(input_csv_file) + + # Get unique values from the specified column + filter_values = df[column_to_filter].unique() + + # Create a dictionary to store filtered DataFrames + filtered_dfs = {} + + # Filter the DataFrame for each unique value and save as separate CSV files + for filter_value in filter_values: + filtered_df = df[df[column_to_filter] == filter_value] + + # Define the output CSV file name based on the filtered value + table_type_str = item.replace(csv_substr_replace, '') + output_csv_file = os.path.join( + csv_path_name, ''.join([f'filtered_{filter_value}', table_type_str]) + ) + + # Save the filtered DataFrame to a new CSV file + filtered_df.to_csv(output_csv_file, index=False) + + # Store the filtered DataFrame in the dictionary + filtered_dfs[filter_value] = filtered_df + + # Print msg + print("Filtering of csv's complete.") diff --git a/src/ark/segmentation/ez_seg/merge_masks.py b/src/ark/segmentation/ez_seg/merge_masks.py new file mode 100644 index 000000000..5651cb088 --- /dev/null +++ b/src/ark/segmentation/ez_seg/merge_masks.py @@ -0,0 +1,164 @@ +import pathlib +from typing import List, Union +import xarray as xr +import numpy as np +import os +from skimage.io import imread +from scipy.ndimage import label +from alpineer import load_utils, image_utils +from ark.segmentation.ez_seg.ez_seg_utils import log_creator + + +def merge_masks_seq( + fov_list: List[str], + object_list: List[str], + object_mask_dir: Union[pathlib.Path, str], + cell_mask_path: Union[pathlib.Path, str], + overlap_percent_threshold: int, + save_path: Union[pathlib.Path, str], + log_dir: Union[pathlib.Path, str] +) -> None: + """ + Sequentially merge object masks with cell masks. Object list is ordered enforced, e.g. object_list[i] will merge + overlapping object masks with cell masks from the initial cell segmentation. Remaining, un-merged cell masks will + then be used to merge with object_list[i+1], etc. + + Args: + fov_list (List[str]): A list of fov names to merge masks over. + object_list (List[str]): A list of names representing previously generated object masks. Note, order matters. + object_mask_dir (Union[pathlib.Path, str]): Directory where object (ez) segmented masks are located + cell_mask_path (Union[str, pathlib.Path]): Path to where the original cell masks are located. + overlap_percent_threshold (int): Percent overlap of total pixel area needed fo object to be merged to a cell. + save_path (Union[str, pathlib.Path]): The directory where merged masks and remaining cell mask will be saved. + log_dir (Union[str, pathlib.Path]): The directory to save log information to. 
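+
+    Example:
+        A minimal sketch; the object names refer to masks previously written by
+        `create_object_masks`, and the directories and the 30% threshold below are
+        illustrative placeholders:
+
+            merge_masks_seq(
+                fov_list=["fov0", "fov1"],
+                object_list=["microglia-projections", "plaques"],
+                object_mask_dir="segmentation/ez_masks",
+                cell_mask_path="segmentation/cell_masks",
+                overlap_percent_threshold=30,
+                save_path="segmentation/merged_masks_dir",
+                log_dir="logs",
+            )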
+ """ + # validate paths + if isinstance(object_mask_dir, str): + object_mask_dir = pathlib.Path(object_mask_dir) + if isinstance(cell_mask_path, str): + cell_mask_path = pathlib.Path(cell_mask_path) + if isinstance(save_path, str): + save_path = pathlib.Path(save_path) + + # for each fov, import cell and object masks (multiple mask types into single xr.DataArray) + for fov in fov_list: + curr_cell_mask = imread(fname=os.path.join( + cell_mask_path, '_'.join([f'{fov}', 'whole_cell.tiff'])) + ) + + fov_object_names = [f'{fov}_' + obj + '.tiff' for obj in object_list] + + objects: xr.DataArray = load_utils.load_imgs_from_dir( + object_mask_dir, files=fov_object_names).drop_vars("compartments").squeeze() + + # sort the imported objects w.r.t the object_list + objects.reindex(indexers={ + "fovs": fov_object_names + }) + + # for each object type in the fov, merge with cell masks + for obj in fov_object_names: + curr_object_mask = imread(fname=(object_mask_dir / obj)) + remaining_cells = merge_masks_single( + object_mask=curr_object_mask, + cell_mask=curr_cell_mask, + overlap_thresh=overlap_percent_threshold, + object_name=obj, + mask_save_path=save_path, + ) + curr_cell_mask = remaining_cells + + # save the unmerged cells as a tiff. + image_utils.save_image(fname=save_path / (fov + "_final_cells_remaining.tiff"), data=curr_cell_mask.astype(np.int32)) + + # Write a log saving mask merging info + variables_to_log = { + "fov_list": fov_list, + "object_list": object_list, + "object_mask_dir": object_mask_dir, + "cell_mask_path": cell_mask_path, + "overlap_percent_threshold": overlap_percent_threshold, + "save_path": save_path + } + log_creator(variables_to_log, log_dir, "mask_merge_log.txt") + print("Merged masks built and saved") + + +def merge_masks_single( + object_mask: np.ndarray, + cell_mask: np.ndarray, + overlap_thresh: int, + object_name: str, + mask_save_path: str, +) -> np.ndarray: + """ + Combines overlapping object and cell masks. For any combination which represents has at least `overlap` percentage + of overlap, the combined mask is kept and incorporated into the original object masks to generate a new set of masks. + + Args: + object_mask (np.ndarray): The object mask numpy array. + cell_mask (np.ndarray): The cell mask numpy array. + overlap_thresh (int): The percentage overlap required for a cell to be merged. + object_name (str): The name of the object. + mask_save_path (str): The path to save the mask. + + Returns: + np.ndarray: The cells remaining mask, which will be used for the next cycle in merging while there are objects. + When no more cells and objects are left to merge, the final, non-merged cells are returned. + """ + + if cell_mask.shape != object_mask.shape: + raise ValueError("Both masks must have the same shape") + + # Relabel cell, object masks + cell_labels, num_cell_labels = label(cell_mask) + object_labels, num_object_labels = label(object_mask) + + # Instantiate new array for merging + merged_mask = object_labels.copy() + + # Set up list to store merged cell labels + remove_cells_list = [0] + + # Find connected components in object and cell masks. Merge only those with highest overlap that meets threshold. 
+ for obj_label in range(1, num_object_labels + 1): + # Extract a connected component from object_mask + object_mask_component = object_labels == obj_label + + best_overlap = 0 + best_cell_mask_component = None + cell_to_merge_label = None + + for cell_label in range(1, num_cell_labels + 1): + # Extract a connected component from cell_mask + cell_mask_component = cell_labels == cell_label + + # Calculate the overlap between cell_mask_component and object_mask_component + intersection = np.logical_and(cell_mask_component, object_mask_component) + overlap = intersection.sum() + + # Calculate cell-object overlap percent threshold + meets_overlap_thresh = overlap / cell_mask_component.sum() > overlap_thresh / 100 + + # Ensure cell overlap meets percent threshold and has the highest relative cell-object overlap + if overlap > best_overlap and meets_overlap_thresh: + best_overlap = overlap + best_cell_mask_component = cell_mask_component + cell_to_merge_label = cell_label + + # If best merge has been found, assign the merged cell+object into the new mask and record the cell label + if best_cell_mask_component is not None: + merged_mask[best_cell_mask_component == True] = obj_label + remove_cells_list.append(cell_to_merge_label) + + # Assign any unmerged cells into a remaining cell mask array. + non_merged_cell_mask = np.isin(cell_labels, remove_cells_list, invert=True) + cell_labels[non_merged_cell_mask == False] = 0 + + # Save the merged mask tiff. + image_utils.save_image( + fname=os.path.join(mask_save_path, object_name.removesuffix(".tiff") + "_merged.tiff"), + data=merged_mask) + + # Return unmerged cells + return cell_labels diff --git a/src/ark/segmentation/marker_quantification.py b/src/ark/segmentation/marker_quantification.py index 92e80f187..92d895094 100644 --- a/src/ark/segmentation/marker_quantification.py +++ b/src/ark/segmentation/marker_quantification.py @@ -1,5 +1,6 @@ import copy import warnings +from typing import List import numpy as np import pandas as pd @@ -124,8 +125,9 @@ def assign_single_compartment_features(marker_counts, compartment, cell_props, c # add counts of each marker to appropriate column # Only include the marker_count features up to the last filtered feature. - marker_counts.loc[compartment, cell_id, - marker_counts.features[1]:filtered_regionprops_names[-1]] = cell_features + marker_counts.loc[ + compartment, cell_id, marker_counts.features[1]:filtered_regionprops_names[-1] + ] = cell_features # add cell size to first column marker_counts.loc[compartment, cell_id, marker_counts.features[0]] = cell_coords.shape[0] @@ -399,7 +401,6 @@ def create_marker_count_matrices(segmentation_labels, image_data, nuclear_counts # define the FOV associated with this segmentation label fov = segmentation_labels.fovs.values[0] - print("extracting data from {}".format(fov)) # current mask label = segmentation_labels.loc[fov, :, :, :] @@ -528,48 +529,83 @@ def generate_cell_table(segmentation_dir, tiff_dir, img_sub_folder="TIFs", whole_cell_file = fov_name + '_whole_cell.tiff' nuclear_file = fov_name + '_nuclear.tiff' - # load the segmentation labels in - current_labels_cell = load_utils.load_imgs_from_dir(data_dir=segmentation_dir, - files=[whole_cell_file], - xr_dim_name='compartments', - xr_channel_names=['whole_cell'], - trim_suffix='_whole_cell') - - compartments = ['whole_cell'] - segmentation_labels = current_labels_cell.values + # for each label given in the argument. 
read in that mask for the fov, and proceed with label and table appending + mask_files = io_utils.list_files(segmentation_dir, substrs=fov_name) + mask_types = process_lists(fov_names=fovs, mask_names=mask_files) + + for mask_type in mask_types: + # load the segmentation labels in + fov_mask_name = fov_name + '_' + mask_type + ".tiff" + current_labels_cell = load_utils.load_imgs_from_dir(data_dir=segmentation_dir, + files=[fov_mask_name], + xr_dim_name='compartments', + xr_channel_names=[mask_type], + trim_suffix='_' + mask_type) + + compartments = ['whole_cell'] + segmentation_labels = current_labels_cell.values + + if nuclear_counts: + current_labels_nuc = load_utils.load_imgs_from_dir(data_dir=segmentation_dir, + files=[nuclear_file], + xr_dim_name='compartments', + xr_channel_names=['nuclear'], + trim_suffix='_nuclear') + compartments = ['whole_cell', 'nuclear'] + segmentation_labels = np.concatenate((current_labels_cell.values, + current_labels_nuc.values), axis=-1) + + current_labels = xr.DataArray(segmentation_labels, + coords=[current_labels_cell.fovs, + current_labels_cell.rows, + current_labels_cell.cols, + compartments], + dims=current_labels_cell.dims) + + # segment the imaging data + cell_table_size_normalized, cell_table_arcsinh_transformed = create_marker_count_matrices( + segmentation_labels=current_labels, + image_data=image_data, + extraction=extraction, + nuclear_counts=nuclear_counts, + fast_extraction=fast_extraction, + **kwargs + ) - if nuclear_counts: - current_labels_nuc = load_utils.load_imgs_from_dir(data_dir=segmentation_dir, - files=[nuclear_file], - xr_dim_name='compartments', - xr_channel_names=['nuclear'], - trim_suffix='_nuclear') - compartments = ['whole_cell', 'nuclear'] - segmentation_labels = np.concatenate((current_labels_cell.values, - current_labels_nuc.values), axis=-1) - - current_labels = xr.DataArray(segmentation_labels, - coords=[current_labels_cell.fovs, - current_labels_cell.rows, - current_labels_cell.cols, - compartments], - dims=current_labels_cell.dims) - - # segment the imaging data - cell_table_size_normalized, cell_table_arcsinh_transformed = create_marker_count_matrices( - segmentation_labels=current_labels, - image_data=image_data, - extraction=extraction, - nuclear_counts=nuclear_counts, - fast_extraction=fast_extraction, - **kwargs - ) + # add mask type column to the data frame + if mask_type == "final_cells_remaining": + mask_type_str = "whole_cell" + else: + mask_type_str = mask_type + cell_table_size_normalized['mask_type'] = mask_type_str + cell_table_arcsinh_transformed['mask_type'] = mask_type_str - normalized_tables.append(cell_table_size_normalized) - arcsinh_tables.append(cell_table_arcsinh_transformed) + # add to larger dataframe + normalized_tables.append(cell_table_size_normalized) + arcsinh_tables.append(cell_table_arcsinh_transformed) # now append to the final dfs to return combined_cell_table_size_normalized = pd.concat(normalized_tables) combined_cell_table_arcsinh_transformed = pd.concat(arcsinh_tables) return combined_cell_table_size_normalized, combined_cell_table_arcsinh_transformed + + +def process_lists(fov_names: List[str], mask_names: List[str]) -> List[str]: + """ + Function to strip prefixes from list: fov_names, strip '.tiff' suffix from list: mask names, + and remove underscore prefixes, returning unique mask values (i.e. categories of masks). + + Args: + fov_names (List[str]): list of fov names. Matching fov names in mask names will be returned without fov prefix. 
+ mask_names (List[str]): list of mask names. Mask names will be returned without tif suffix. + + Returns: + List[str]: Unique mask names (i.e. categories of masks) + """ + stripped_mask_names = io_utils.remove_file_extensions(mask_names) + result = [itemB[len(prefix):] for itemB in stripped_mask_names for prefix in fov_names if itemB.startswith(prefix)] + # Remove underscore prefixes and return unique values + cleaned_result = [item.lstrip('_') for item in result] + unique_result = list(set(cleaned_result)) + return unique_result diff --git a/src/ark/utils/example_dataset.py b/src/ark/utils/example_dataset.py index 9cfb1b1f2..685450f2c 100644 --- a/src/ark/utils/example_dataset.py +++ b/src/ark/utils/example_dataset.py @@ -51,7 +51,8 @@ def __init__(self, dataset: str, overwrite_existing: bool = True, cache_dir: str "example_cell_output_dir": "pixie/example_cell_output_dir", "spatial_lda": "spatial_analysis/spatial_lda", "post_clustering": "post_clustering", - "ome_tiff": "ome_tiff" + "ome_tiff": "ome_tiff", + "ez_seg_data": "ez_seg_data" } """ Path suffixes for mapping each downloaded dataset partition to it's appropriate @@ -176,7 +177,8 @@ def get_example_dataset(dataset: str, save_dir: Union[str, pathlib.Path], "LDA_training_inference", "neighborhood_analysis", "pairwise_spatial_enrichment", - "ome_tiff"] + "ome_tiff", + "ez_seg_data"] # Check the appropriate dataset name try: diff --git a/start_jupyter.sh b/start_jupyter.sh old mode 100644 new mode 100755 index e0f3d1a03..4848b33c4 --- a/start_jupyter.sh +++ b/start_jupyter.sh @@ -1,7 +1,29 @@ -if [ $UPDATE_ARK -ne 0 ] +#!/usr/bin/env bash + +# check for template developer flag +JUPYTER_DIR='scripts' +update=0 + +while test $# -gt 0 +do + case "$1" in + -u|--update) + update=1 + shift + ;; + *) + echo "$1 is not an accepted option..." + echo "-u, --update : Update default scripts" + exit + ;; + esac +done + +if [ $update -ne 0 ] then - cd /opt/ark-analysis && python -m pip install . 
+ bash update_notebooks.sh -u + else + bash update_notebooks.sh fi -cd /scripts -jupyter lab --ip=0.0.0.0 --allow-root --no-browser --port=$JUPYTER_PORT --notebook-dir=/$JUPYTER_DIR \ No newline at end of file +jupyter lab --notebook-dir $JUPYTER_DIR diff --git a/templates/1_Segment_Image_Data.ipynb b/templates/1_Segment_Image_Data.ipynb index 632692ca3..7cadf662a 100644 --- a/templates/1_Segment_Image_Data.ipynb +++ b/templates/1_Segment_Image_Data.ipynb @@ -1,7 +1,6 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -32,7 +31,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -40,7 +38,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -64,7 +61,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -89,7 +85,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -155,7 +150,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -180,7 +174,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -228,7 +221,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -267,7 +259,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -321,7 +312,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -347,7 +337,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -417,7 +406,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.10.11" }, "vscode": { "interpreter": { diff --git a/templates/ez_segmenter.ipynb b/templates/ez_segmenter.ipynb new file mode 100644 index 000000000..380a7ee69 --- /dev/null +++ b/templates/ez_segmenter.ipynb @@ -0,0 +1,902 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ***ezSegmenter***" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "-- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --\n", + "#### * This segmentation tool enables creation of masks for objects not easily picked up by primary cell segmentation methods on multiplexed imaging data.\n", + "#### * In addition this tool can be used to create composites of channels as well as merge object masks with cell masks.\n", + "#### * Final Output : Image masks and cell+object tables.\n", + "-- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- --" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "import" + ] + }, + "outputs": [], + "source": [ + "import os\n", + "from ark.segmentation.ez_seg import (\n", + " ez_object_segmentation,\n", + " ez_seg_display,\n", + " merge_masks,\n", + " composites,\n", + " ez_seg_utils,\n", + ")\n", + "from alpineer import io_utils\n", + "from ark.utils import example_dataset\n", + "from ark.segmentation import marker_quantification" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 0: Set root directory and (Optional) download example dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we are using the example data located in `/data/example_dataset/input_data`. 
To modify this notebook to run using your own data, simply change `base_dir` to point to your own data directory.\n", + "\n", + "* `base_dir`: the path to all of your imaging data. This directory will contain all of the data generated by this notebook." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "base_dir" + ] + }, + "outputs": [], + "source": [ + "# set up the base directory\n", + "base_dir = \"../data/example_dataset\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you would like to test the ez segmeneter features in Ark with an example dataset, run the cell below. Otherwise skip to section 1.\n", + "\n", + "The cell below will download a dataset consisting of 10 FOVs with 47 channels, along with corresponding composite microglia images. You may find more information about the example dataset in the [README](../README.md#example-dataset).\n", + "\n", + "If you are using your own data, skip the cell below.\n", + "\n", + "* `overwrite_existing`: If set to `False`, it will not overwrite existing data in the `data/example_dataset`. Recommended leaving as `True` if you are doing a clean run of the `ark` pipeline using this dataset from the start. If you already have the dataset downloaded, set to `False`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "ex_data_download" + ] + }, + "outputs": [], + "source": [ + "example_dataset.get_example_dataset(dataset=\"ez_seg_data\", save_dir = base_dir, overwrite_existing = True)\n", + "\n", + "# example data gets written to an ez_seg_data sub-folder\n", + "base_dir = os.path.join(base_dir, \"ez_seg_data\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1: Set file paths & Get image paths" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define directory paths\n", + "\n", + "- `image_data_dir`: Channel data directory - MUST already exist.\n", + "- `subfolder_name`: If your image data is nested (e.g. ..fov0/rescaled/*) then replace None with the name of the dir (e.g. rescaled)\n", + "- `segmentation_dir`: Segmentation directory. Either not created yet or filled with cell masks - may already exist.\n", + "- `cell_table_dir`: Directory to store cell + object tables - may already exist.\n", + "- `ez_visualization_dir`: Directory to store masks in a mantis viewer - friendly format - may already exist\n", + "- `composite_dir`: Directory to store composite images. Created by the notebook.\n", + "- `ez_masks_dir`: Sub-directory of segmentation directory that will store ez segmenter masks. Created by the notebook.\n", + "- `merged_masks_dir`: Directory to store merged cell + object masks. Created by the notebook.\n", + "- `log_dir`: Directory to store log info as .txt files. Created by the notebook." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "file_path" + ] + }, + "outputs": [], + "source": [ + "image_data_dir = os.path.join(base_dir, \"image_data\")\n", + "sub_folder_name = None\n", + "segmentation_dir = os.path.join(base_dir, \"segmentation\")\n", + "cell_table_dir = os.path.join(base_dir, \"cell_table\")\n", + "ez_visualization_dir = os.path.join(base_dir, \"mantis_visualization\")\n", + "\n", + "# automatically created by the notebook\n", + "composite_dir = os.path.join(base_dir, \"composites\")\n", + "ez_masks_dir = os.path.join(segmentation_dir, \"ez_masks\")\n", + "merged_masks_dir = os.path.join(segmentation_dir, \"merged_masks_dir\")\n", + "log_dir = os.path.join(base_dir, \"logs\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2023-09-21T19:52:25.519740Z", + "iopub.status.busy": "2023-09-21T19:52:25.519471Z", + "iopub.status.idle": "2023-09-21T19:52:25.813044Z", + "shell.execute_reply": "2023-09-21T19:52:25.812481Z", + "shell.execute_reply.started": "2023-09-21T19:52:25.519720Z" + }, + "tags": [ + "create_dirs" + ] + }, + "outputs": [], + "source": [ + "# Create above directories if they do not exist\n", + "for directory in [\n", + " segmentation_dir,\n", + " cell_table_dir,\n", + " ez_visualization_dir,\n", + " composite_dir,\n", + " ez_masks_dir,\n", + " merged_masks_dir,\n", + " log_dir\n", + "]:\n", + " if not os.path.exists(directory):\n", + " os.makedirs(directory)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2023-09-21T19:52:26.240714Z", + "iopub.status.busy": "2023-09-21T19:52:26.240448Z", + "iopub.status.idle": "2023-09-21T19:52:26.279053Z", + "shell.execute_reply": "2023-09-21T19:52:26.278366Z", + "shell.execute_reply.started": "2023-09-21T19:52:26.240696Z" + }, + "tags": [ + "validate_path" + ] + }, + "outputs": [], + "source": [ + "# Validate paths of the directories.\n", + "io_utils.validate_paths(\n", + " [\n", + " image_data_dir,\n", + " segmentation_dir,\n", + " cell_table_dir,\n", + " ez_visualization_dir,\n", + " composite_dir,\n", + " ez_masks_dir,\n", + " merged_masks_dir,\n", + " log_dir\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Compute and filter fov paths\n", + "\n", + "FOV names should be stored as folders within the channel data directory as outlined earlier" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2023-09-21T19:52:28.068136Z", + "iopub.status.busy": "2023-09-21T19:52:28.067866Z", + "iopub.status.idle": "2023-09-21T19:52:28.524234Z", + "shell.execute_reply": "2023-09-21T19:52:28.523636Z", + "shell.execute_reply.started": "2023-09-21T19:52:28.068119Z" + }, + "tags": [ + "load_fovs" + ] + }, + "outputs": [], + "source": [ + "# either get all fovs in the folder...\n", + "fovs = io_utils.list_folders(image_data_dir)\n", + "\n", + "# ... or optionally, select a specific set of fovs manually\n", + "# fovs = [\"fov0\", \"fov1\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Composite Builder (Optional)\n", + "\n", + "Here, you can combine channels to produce a single channel for later segmentation, through addition and / or subtraction of different channels." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Set composite values\n", + "\n", + "- Set the name in `composite_name`, and the channels to combine in `to_add` and `to_subtract`.\n", + "\n", + "Specify `image_type`:\n", + " - \"signal\" = intensity or count-based images.\n", + " - \"pixel_cluster\" = pixels individually labeled by cluster label.\n", + "\n", + "Specify `composite_method`:\n", + " - `\"total\"` = return an image with summed values in each pixel.\n", + " - `\"binary\"` = return an image with either filled (1) or empty (0) values in each pixel." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2023-09-21T19:52:29.652939Z", + "iopub.status.busy": "2023-09-21T19:52:29.652669Z", + "iopub.status.idle": "2023-09-21T19:52:29.697565Z", + "shell.execute_reply": "2023-09-21T19:52:29.696818Z", + "shell.execute_reply.started": "2023-09-21T19:52:29.652920Z" + }, + "tags": [ + "composite_set" + ] + }, + "outputs": [], + "source": [ + "# What would you like to name your composite image?\n", + "composite_name = \"amyloid\"\n", + "\n", + "# What channels would you like to add together?\n", + "to_add = [\"Amyloidbeta140\", \"Amyloidbeta142\", \"PanAmyloidbeta1724\"]\n", + "# What channels would you like to subtract?\n", + "to_subtract = [\"HistoneH3Lyo\", \"Background\"]\n", + "\n", + "# What image type do you want returned?\n", + "image_type = \"signal\"\n", + "# What combination method do you want to use?\n", + "composite_method = \"total\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Create your composite channel\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2023-09-21T19:53:27.405633Z", + "iopub.status.busy": "2023-09-21T19:53:27.405333Z", + "iopub.status.idle": "2023-09-21T19:54:24.645416Z", + "shell.execute_reply": "2023-09-21T19:54:24.644657Z", + "shell.execute_reply.started": "2023-09-21T19:53:27.405606Z" + }, + "tags": [ + "composite_build" + ] + }, + "outputs": [], + "source": [ + "# Run composite builder\n", + "composites.composite_builder(\n", + " image_data_dir=image_data_dir,\n", + " img_sub_folder=sub_folder_name,\n", + " fov_list=fovs,\n", + " images_to_add=to_add,\n", + " images_to_subtract=to_subtract,\n", + " image_type=image_type,\n", + " composite_name=composite_name,\n", + " composite_directory=composite_dir,\n", + " composite_method=composite_method,\n", + " log_dir=log_dir\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### View a composite test image\n", + "- `fov_name`: Specify which FOV you'd like to see for visual testing purposes.\n", + "- `composite_name`: This should be the composite name you specified above." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2023-09-21T19:55:53.511155Z", + "iopub.status.busy": "2023-09-21T19:55:53.510882Z", + "iopub.status.idle": "2023-09-21T19:55:54.333815Z", + "shell.execute_reply": "2023-09-21T19:55:54.333254Z", + "shell.execute_reply.started": "2023-09-21T19:55:53.511136Z" + }, + "tags": [ + "display_composite" + ] + }, + "outputs": [], + "source": [ + "fov_name = \"fov0\"\n", + "composite_name = \"amyloid\"\n", + "\n", + "# Show test composite image\n", + "ez_seg_display.display_channel_image(composite_dir, sub_folder_name, fov_name, composite_name, composite=True)" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "#### View a channel image\n", + "- `fov_name`: Specify which FOV you'd like to see for visual testing purposes.\n", + "- `channel_name`: This should be the channel name you wish to view." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "display_channel" + ] + }, + "outputs": [], + "source": [ + "fov_name = \"fov0\"\n", + "channel_name = \"Iba1\"\n", + "\n", + "# Show test channel image\n", + "ez_seg_display.display_channel_image(image_data_dir, sub_folder_name, fov_name, channel_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Create Object Masks" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Create your object segmentation masks.\n", + "Here you will specify which channel to use as the base for segmenting individual object masks. Return to this section to create masks for new object types for the same data.\n", + "\n", + "Additionally, set the following segmentation parameters below:\n", + "\n", + "###### Channel params\n", + "\n", + "- `channel_to_segment`: The name of the channel you wish to segment on.\n", + "- `channel_to_segment_path`: `image_data_dir` if segmenting a stand-alone channel, or `composite_dir` if segmenting a composite channel.\n", + "- `path_sub_folder_name`: Set to the sub-folder name in which the channel to segment is located, otherwise write `None`.\n", + "\n", + "###### Mask params\n", + "\n", + "- `mask_name`: The name you want to label these masks as, e.g. `\"plaques\"` or `\"microglia-projections\"`.\n", + "- `object_shape`: The general shape of the object, can be either `\"blob\"` or `\"projection\"`.\n", + "\n", + "###### Blur/threshold params\n", + "\n", + "- `blur`: The standard deviation of the Gaussian kernel used to blur the image. Default is `1`.\n", + "- `threshold`: The per-FOV percentile threshold value (integer) for image thresholding, if desired. Set to `\"auto\"` to determine the threshold locally in each image, or `None` to apply no threshold.\n", + "- `hole_size`: Any holes smaller than `hole_size` are closed. Otherwise leave as `\"auto\"` (or `None`) to determine the hole size based on image dimensions.\n", + "\n", + "###### FOV params\n", + "\n", + "- `fov_size`: The length of one side of your FOV in μm.\n", + "- `min_pixels`: The minimum number of pixels required in a segmented object.\n", + "- `max_pixels`: The maximum number of pixels allowed in a segmented object.\n", + "\n", + "A text log will be saved with the values used to segment."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "set_ez_seg_params" + ] + }, + "outputs": [], + "source": [ + "# channel params\n", + "channel_to_segment = \"astrocyte\"\n", + "channel_to_segment_path = composite_dir\n", + "path_sub_folder_name = None\n", + "\n", + "# mask params\n", + "mask_name = \"astrocyte-arms\"\n", + "object_shape = \"projection\"\n", + "\n", + "# blur/threshold params\n", + "blur = 1\n", + "threshold = 99\n", + "hole_size = \"auto\"\n", + "\n", + "# fov params\n", + "fov_size = 400\n", + "min_pixels = 100\n", + "max_pixels = 100000" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Create your object masks & view a test image" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "gen_obj_masks" + ] + }, + "outputs": [], + "source": [ + "# Segment images.\n", + "ez_object_segmentation.create_object_masks(\n", + " image_data_dir=channel_to_segment_path,\n", + " img_sub_folder=path_sub_folder_name,\n", + " fov_list=fovs,\n", + " channel_to_segment=channel_to_segment,\n", + " masks_dir=ez_masks_dir,\n", + " log_dir=log_dir,\n", + " mask_name=mask_name,\n", + " object_shape_type=object_shape,\n", + " sigma=blur,\n", + " thresh=threshold,\n", + " hole_size=hole_size,\n", + " fov_dim=fov_size,\n", + " min_object_area=min_pixels,\n", + " max_object_area=max_pixels,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### View a mask test image\n", + "- `fov_name`: Specify which FOV you'd like to see for visual testing purposes.\n", + "- `channel_to_view`: This should be the channel or composite name you segmented on above.\n", + "- `channel_to_view_dir`: The directory (usually the composite or channel tiff directory) your `channel_to_view` resides in.\n", + "- `path_sub_folder_name`: Set to the sub-folder name in which the channel to view is located, otherwise write `None`.\n", + "- `mask_to_view`: This should be the mask name you specified above.\n", + "- `mask_to_view_dir`: The directory (usually `ez_masks_dir`) your `mask_to_view` resides in." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "view_obj_mask" + ] + }, + "outputs": [], + "source": [ + "fov_name = \"fov1\"\n", + "channel_to_view = \"astrocyte\"\n", + "channel_to_view_dir = composite_dir\n", + "path_sub_folder_name = None\n", + "mask_to_view = \"astrocyte-arms\"\n", + "mask_to_view_dir = ez_masks_dir\n", + "\n", + "# Show test segmentation image\n", + "ez_seg_display.overlay_mask_outlines(fov_name, channel_to_view, channel_to_view_dir, path_sub_folder_name, mask_to_view, mask_to_view_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5a. Mask Merger (Optional)\n", + "\n", + "Merging enables connecting traditional circular or oval-shaped nuclear-based cell masks with anuclear cell projections (e.g. microglia arms with microglia soma).\n", + "**Note:** Requires the Deepcell outputs from `1_Segment_Image_Data.ipynb`, or another whole_cell segmentation mask."
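Under the hood, merging hinges on an overlap test: each object is compared against the cells it touches, and it is attached to at most one cell, the one covering the largest share of the object's area, provided that share clears `percent_overlap`. The sketch below illustrates only that per-object test under those assumptions; it is not the actual `merge_masks.merge_masks_seq`, which additionally relabels and saves the merged and remaining-cell masks per FOV.

```python
import numpy as np

# Minimal sketch of the per-object overlap test, assuming integer-labeled 2D masks
# where 0 is background.
def sketch_best_merge_candidate(object_mask, cell_mask, object_id, percent_overlap=30):
    obj_pixels = object_mask == object_id
    cell_ids, counts = np.unique(cell_mask[obj_pixels], return_counts=True)
    best_frac, best_cell = 0.0, None
    for cell_id, count in zip(cell_ids, counts):
        if cell_id == 0:
            continue  # skip background
        frac = 100.0 * count / obj_pixels.sum()  # % of the object's area covered by this cell
        if frac > best_frac:
            best_frac, best_cell = frac, cell_id
    # merge only with the single best-overlapping cell, and only above the threshold
    return best_cell if best_frac >= percent_overlap else None
```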
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Here you can merge object segmentation masks with cell masks (or any other type of mask).\n", + "Provide a list of the objects you would like to merge with previously segmented cell masks (or another base mask).\n", + "\n", + "**LIST ORDER IS IMPORTANT**: The first mask listed is merged first; the second mask is then merged with the cells left unmerged after the first merge, and so on.\n", + "\n", + "Additionally, set the percent of an object's area that must overlap a cell mask for the two to be merged.\n", + "\n", + "* `merge_masks_list`: list of object masks to merge into the base (cell) image.\n", + "* `percent_overlap`: percent of an object's area that must overlap a cell mask for the object to be merged into it.\n", + "* `cell_dir`: the directory containing the whole-cell segmentation masks (e.g. `deepcell_output`).\n", + "* `merged_masks_dir`: the directory to store the merged masks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "set_mask_dirs" + ] + }, + "outputs": [], + "source": [ + "merge_masks_list = [\"microglia-arms\", \"astrocyte-arms\"]\n", + "percent_overlap = 30\n", + "\n", + "# Overwrite if different from above\n", + "cell_dir = os.path.join(segmentation_dir, \"deepcell_output\")\n", + "merged_masks_dir = os.path.join(segmentation_dir, \"merged_masks_dir\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "validate_mask_dirs" + ] + }, + "outputs": [], + "source": [ + "# validate paths\n", + "io_utils.validate_paths(\n", + " [\n", + " cell_dir,\n", + " merged_masks_dir\n", + " ]\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "#### Merge ez segmentation & whole_cell masks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "merge_seg_cell" + ] + }, + "outputs": [], + "source": [ + "# Merge your masks across all FOVs\n", + "merge_masks.merge_masks_seq(\n", + " fov_list=fovs,\n", + " object_list=merge_masks_list,\n", + " object_mask_dir=ez_masks_dir,\n", + " cell_mask_path=cell_dir,\n", + " overlap_percent_threshold=percent_overlap,\n", + " save_path=merged_masks_dir,\n", + " log_dir=log_dir\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### View a merged mask test image\n", + "- `fov_name`: Specify which FOV you'd like to see for visual testing purposes.\n", + "- `merge_mask_view`: This should be one of the object (i.e. ez) mask names you merged above.\n", + "- `object_mask_dir`: The directory (usually `ez_masks_dir`) your object mask resides in.\n", + "- `cell_mask_dir`: The directory (usually `cell_dir`) your cell mask resides in.\n", + "- `merged_mask_dir`: The directory your merged mask resides in." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "view_merged_mask" + ] + }, + "outputs": [], + "source": [ + "fov_name = \"fov3\"\n", + "merge_mask_view = \"microglia-arms\"\n", + "object_mask_dir = ez_masks_dir\n", + "cell_mask_dir = cell_dir\n", + "merged_mask_dir = merged_masks_dir\n", + "\n", + "# Show test merged mask image\n", + "ez_seg_display.multiple_mask_display(fov_name, merge_mask_view, object_mask_dir, cell_mask_dir, merged_mask_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5b. 
Consolidate masks (Optional - Needed if using merged and unmerged objects in analysis)\n", + "\n", + "Run this step to ensure all masks are in the same folder if you want to combine merged and non-merged mask sources, e.g. merged cells and proteopathy objects." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "tags": [ + "consolidate_mask" + ] + }, + "outputs": [], + "source": [ + "# Enter the names of masks you would like to include in the final mask directory, e.g. [\"merged\", \"final_cells_remaining\", \"plaques\", \"tangles\"].\n", + "mask_names = [\"merged\", \"final_cells_remaining\", \"plaques\"]\n", + "\n", + "# Name of the final mask destination folder\n", + "final_mask_dir = os.path.join(segmentation_dir, \"final_mask_dir\")\n", + "\n", + "if not os.path.exists(final_mask_dir):\n", + " os.makedirs(final_mask_dir)\n", + "\n", + "# Create and fill the final mask folder.\n", + "ez_seg_utils.find_and_copy_files(mask_names, segmentation_dir, final_mask_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 6. Relabel all masks\n", + "\n", + "Run this step to ensure all mask ids across all segmentations (deepcell, ez, or other inputs) are relabeled from 1 to n total masks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "tags": [ + "relabel_mask" + ] + }, + "outputs": [], + "source": [ + "# Enter the root directory where all masks are located that you want to include in your cell table. E.g. final_mask_dir, or merged_masks_dir\n", + "root_mask_dir = final_mask_dir\n", + "\n", + "# Run this cell to relabel all masks in all folders.\n", + "ez_seg_utils.renumber_masks(root_mask_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 7b. Generate single cell and/or object expression table" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "nuc_props_set" + ] + }, + "outputs": [], + "source": [ + "# set to True to bypass expensive cell or object property calculations\n", + "# only cell or object label, size, and centroid will be extracted if True\n", + "fast_extraction = False\n", + "\n", + "# Override to give your cell label an alternative name (e.g. 
plaques.csv)\n", + "table_name = \"cell_and_objects\"\n", + "\n", + "# set to True to add nuclear cell properties to the expression matrix\n", + "nuclear_counts = False" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For a full list of features extracted, please refer to the cell table section of: https://ark-analysis.readthedocs.io/en/latest/_rtd/data_types.html\n", + "\n", + "**NOTE: if you're loading your own dataset, please make sure all the imaging data is in the same folder, with each FOV given its own sub-folder and all FOVs having the same channels.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "create_exp_mat" + ] + }, + "outputs": [], + "source": [ + "# Combine any merged objects, any remaining unmerged cell masks, and any object masks that were not merged.\n", + "(\n", + " cell_table_size_normalized,\n", + " cell_table_arcsinh_transformed,\n", + ") = marker_quantification.generate_cell_table(\n", + " segmentation_dir=root_mask_dir,\n", + " tiff_dir=image_data_dir,\n", + " img_sub_folder=None,\n", + " fovs=fovs,\n", + " batch_size=5,\n", + " nuclear_counts=nuclear_counts,\n", + " fast_extraction=fast_extraction,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "save_exp_mat" + ] + }, + "outputs": [], + "source": [ + "# Set the compression level if desired; ZSTD compression can offer up to a 60-70% reduction in file size.\n", + "# NOTE: Compressed `csv` files cannot be opened in Excel. They must be uncompressed beforehand.\n", + "compression = None\n", + "\n", + "# Uncomment the line below to allow for compressed `csv` files.\n", + "# compression = {\"method\": \"zstd\", \"level\": 3}\n", + "\n", + "cell_table_size_normalized.to_csv(\n", + " os.path.join(cell_table_dir, table_name + \"_table_size_normalized.csv\"),\n", + " compression=compression,\n", + " index=False,\n", + ")\n", + "cell_table_arcsinh_transformed.to_csv(\n", + " os.path.join(cell_table_dir, table_name + \"_table_arcsinh_transformed.csv\"),\n", + " compression=compression,\n", + " index=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If desired, save split CSVs based upon mask name." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "save_mat_by_mask" + ] + }, + "outputs": [], + "source": [ + "ez_seg_utils.filter_csvs_by_mask(csv_path_name=cell_table_dir, csv_substr_replace=table_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8. Create a Mantis Viewer-friendly set of masks (Optional)\n", + "\n", + "If you would like to view your masks in Mantis Viewer, the code below creates a file structure that arranges the masks so they can easily be loaded into the viewer.\n", + "NOTE: Only one mask can be viewed at a time in Mantis, so reloading the mask in the project options will be necessary."
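A Mantis-friendly project is essentially one folder per FOV that holds that FOV's channel TIFFs alongside its mask TIFFs. The snippet below is only an illustration of that layout using plain file copies, under the assumption that masks are matched to FOVs by filename prefix; the actual work in this notebook is done by `ez_seg_utils.create_mantis_project` in the next cell.

```python
import os
import shutil

# Illustrative layout only -- the real work is done by ez_seg_utils.create_mantis_project.
def sketch_mantis_layout(fovs, image_data_dir, mask_dir, mantis_dir):
    for fov in fovs:
        fov_out = os.path.join(mantis_dir, fov)
        os.makedirs(fov_out, exist_ok=True)
        # copy the FOV's channel images
        for tiff in os.listdir(os.path.join(image_data_dir, fov)):
            shutil.copy(os.path.join(image_data_dir, fov, tiff), fov_out)
        # copy any mask whose filename starts with the FOV name
        for mask in os.listdir(mask_dir):
            if mask.startswith(fov):
                shutil.copy(os.path.join(mask_dir, mask), fov_out)
```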
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "mantis_mask" + ] + }, + "outputs": [], + "source": [ + "# mantis\n", + "fovs = io_utils.list_folders(image_data_dir)\n", + "\n", + "# Change segmentation directory to merged_masks_dir if merging performed or just ez_mask_dir if no merging done.\n", + "ez_seg_utils.create_mantis_project(\n", + " fovs=fovs,\n", + " image_data_dir=image_data_dir,\n", + "\n", + " segmentation_dir=segmentation_dir,\n", + " mantis_dir=ez_visualization_dir,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tests/segmentation/ez_seg/composites_test.py b/tests/segmentation/ez_seg/composites_test.py new file mode 100644 index 000000000..9614a404b --- /dev/null +++ b/tests/segmentation/ez_seg/composites_test.py @@ -0,0 +1,160 @@ +import numpy as np +import os +import pathlib +import pytest +import skimage.io as io +import xarray as xr + +from alpineer.load_utils import load_imgs_from_tree +from ark.segmentation.ez_seg import composites + + +@pytest.fixture(scope="session") +def image_dir(tmpdir_factory: pytest.TempPathFactory) -> pathlib.Path: + image_dir_name: pathlib.Path = tmpdir_factory.mktemp("image_dir") + fovs: List[str] = [f"fov{i}" for i in np.arange(3)] + chans: List[str] = [f"chan{i}" for i in np.arange(2)] + + example_img_0 = np.array( + [[0] * 4, + [1] * 4, + [2] * 4, + [3] * 4] + ) + example_img_1 = np.array( + [[0, 0, 1, 1], + [1, 1, 2, 2], + [2, 2, 3, 3], + [3, 3, 4, 4]] + ) + example_imgs = [example_img_0, example_img_1] + + for fov in fovs: + fov_dir: pathlib.Path = image_dir_name / fov + os.mkdir(fov_dir) + for i, chan in enumerate(chans): + io.imsave(str(fov_dir / chan + ".tiff"), example_imgs[i]) + + yield image_dir_name + + +@pytest.fixture(scope="session") +def image_data(image_dir: pathlib.Path) -> xr.DataArray: + yield load_imgs_from_tree( + data_dir=image_dir, img_sub_folder=None, fovs=["fov0"] + ) + + +@pytest.fixture(scope="session") +def composite_array_add() -> np.ndarray: + yield np.array( + [[0] * 4, + [1] * 4, + [2] * 4, + [3] * 4] + ) + + +@pytest.fixture(scope="session") +def composite_array_subtract() -> np.ndarray: + yield np.array( + [[3] * 4, + [2] * 4, + [1] * 4, + [0] * 4] + ) + + +def test_add_to_composite_signal(image_data: xr.DataArray, composite_array_add: np.ndarray): + composite_array_added: np.ndarray = composites.add_to_composite( + data=image_data, + composite_array=composite_array_add, + images_to_add=["chan0", "chan1"], + image_type="signal", + composite_method="total" + ) + + result: np.ndarray = np.array( + [[0, 0, 1, 1], + [2, 2, 3, 3], + [4, 4, 5, 5], + [6, 6, 7, 7]] + ) + assert np.all(composite_array_added == result) + + +def test_add_to_composite_signal_binary(image_data: xr.DataArray, composite_array_add: np.ndarray): + composite_array_added: np.ndarray = composites.add_to_composite( + data=image_data, + composite_array=composite_array_add, + images_to_add=["chan0", "chan1"], + image_type="signal", + composite_method="binary" + ) + + result: np.ndarray = np.array( + [[0, 0, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1]] + ) + assert 
np.all(composite_array_added == result) + + +def test_add_to_composite_pixel_cluster(image_data: xr.DataArray, composite_array_add: np.ndarray): + composite_array_added: np.ndarray = composites.add_to_composite( + data=image_data, + composite_array=composite_array_add, + images_to_add=["chan0"], + image_type="pixel_cluster", + composite_method="binary" + ) + + result: np.ndarray = np.array( + [[0, 0, 0, 0], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1]] + ) + assert np.all(composite_array_added == result) + + +def test_subtract_from_composite_signal_binary( + image_data: xr.DataArray, composite_array_subtract: np.ndarray +): + composite_array_subtracted: np.ndarray = composites.subtract_from_composite( + data=image_data, + composite_array=composite_array_subtract.copy(), + images_to_subtract=["chan0", "chan1"], + image_type="signal", + composite_method="binary" + ) + + result: np.ndarray = np.array( + [[1, 1, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]] + ) + assert np.all(composite_array_subtracted == result) + + +def test_subtract_from_composite_general( + image_data: xr.DataArray, composite_array_subtract: np.ndarray +): + # also handles other casees that aren't image_type="signal" + composite_method="binary" + composite_array_subtracted: np.ndarray = composites.subtract_from_composite( + data=image_data, + composite_array=composite_array_subtract.copy(), + images_to_subtract=["chan1"], + image_type="pixel_cluster", + composite_method="total" + ) + + result: np.ndarray = np.array( + [[3, 3, 2, 2], + [1, 1, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]] + ) + assert np.all(composite_array_subtracted == result) diff --git a/tests/segmentation/ez_seg/ez_object_segmentation_test.py b/tests/segmentation/ez_seg/ez_object_segmentation_test.py new file mode 100644 index 000000000..09c37137a --- /dev/null +++ b/tests/segmentation/ez_seg/ez_object_segmentation_test.py @@ -0,0 +1,236 @@ +import pathlib +from scipy import ndimage +from ark.segmentation.ez_seg import ez_object_segmentation +import pytest +from pytest_cases import param_fixture +import numpy as np +from skimage import draw +import xarray as xr +from skimage.io import imread +from alpineer import image_utils +from skimage.util import img_as_int +from sklearn.preprocessing import minmax_scale + + +@pytest.fixture(scope="session") +def ez_fov( + tmp_path_factory: pytest.TempPathFactory, rng: np.random.Generator +) -> pathlib.Path: + """ + Creates 2 FOVs with 3 channels each, with 60 spots per channel. + + data + ├── ez_seg_masks + └── image_data + ├── fov_0 + │ ├── chan_0.tiff + │ ├── chan_1.tiff + │ └── chan_2.tiff + └── fov_1 + ├── chan_0.tiff + ├── chan_1.tiff + └── chan_2.tiff + + Yields: + pathlib.Path: The path to the temporary directory containing the image_data + and the masks directory. 
+ """ + fov_count: int = 2 + channel_count: int = 3 + image_size: int = 1024 + spot_count: int = 60 + spot_radius: int = 40 + cloud_noise_size: int = 4 + + image: np.ndarray = rng.normal( + loc=0.25, scale=0.25, size=(channel_count, image_size, image_size) + ) + output_image: np.ndarray = np.zeros_like(a=image, dtype=np.int64) + + # Make temporary path + tmp_path: pathlib.Path = tmp_path_factory.mktemp("data") + + # Make the temporary directory for the image data + tmp_image_dir = tmp_path / "image_data" + tmp_image_dir.mkdir(parents=True, exist_ok=True) + + # Make the temporary directory for the masks dir + tmp_masks_dir = tmp_path / "ez_seg_masks" + tmp_masks_dir.mkdir(parents=True, exist_ok=True) + + # Make the temporary directory for the log dir + tmp_log_dir = tmp_path / "ez_logs" + tmp_log_dir.mkdir(parents=True, exist_ok=True) + + for fov_idx in range(fov_count): + for channel_idx in range(channel_count): + channel: np.ndarray = image[channel_idx] + for _ in range(spot_count): + rr, cc = draw.disk( + center=( + rng.integers(channel.shape[0]), + rng.integers(channel.shape[1]), + ), + radius=spot_radius, + shape=channel.shape, + ) + channel[rr, cc] = 1 + + channel *= rng.normal(loc=1.0, scale=0.1, size=channel.shape) + + channel *= ndimage.zoom( + rng.normal( + loc=1.0, scale=0.5, size=(cloud_noise_size, cloud_noise_size) + ), + image_size / cloud_noise_size, + ) + + int_channel = img_as_int( + minmax_scale( + ndimage.gaussian_filter( + channel, sigma=2.0 + ), + feature_range=(-1., 0.999) + ) + ).astype(np.int64) + + output_image[channel_idx]: np.ndarray = int_channel + + fov_dir = tmp_image_dir / f"fov_{fov_idx}" + fov_dir.mkdir(parents=True, exist_ok=True) + + for idx, output_channel in enumerate(output_image): + image_utils.save_image( + fname=fov_dir / f"chan_{idx}.tiff", data=output_channel + ) + + yield tmp_path + + +@pytest.mark.parametrize( + "_min_object_area, _max_object_area, _object_shape_type, _thresh, _fov_dim", + [(100, 100000, "blob", None, 400), (200, 2000, "projection", 20, 800)], +) +def test_create_object_masks( + ez_fov: pathlib.Path, + _object_shape_type: str, + _thresh: float, + _fov_dim: int, + _min_object_area: int, + _max_object_area: int, +) -> None: + _sigma = 1 + _hole_size = 10 + + with pytest.raises(ValueError): + ez_object_segmentation.create_object_masks( + image_data_dir=ez_fov / "image_data", + img_sub_folder="wrong_sub_folder", + fov_list=["fov_0", "fov_1"], + mask_name="test_mask", + object_shape_type="wrong_shape", + channel_to_segment="chan_0", + masks_dir=ez_fov / "ez_seg_masks", + log_dir=ez_fov / "ez_logs", + ) + with pytest.raises(FileNotFoundError): + ez_object_segmentation.create_object_masks( + image_data_dir="wrong_path", + img_sub_folder="wrong_sub_folder", + fov_list=["fov_0", "fov_1"], + mask_name="test_mask", + object_shape_type="blob", + channel_to_segment="chan_0", + masks_dir=ez_fov / "ez_seg_masks", + log_dir=ez_fov / "ez_logs", + ) + # Test the function (succeeds) + ez_object_segmentation.create_object_masks( + image_data_dir=ez_fov / "image_data", + img_sub_folder=None, + fov_list=["fov_0", "fov_1"], + mask_name="test_mask", + object_shape_type=_object_shape_type, + channel_to_segment="chan_0", + masks_dir=ez_fov / "ez_seg_masks", + log_dir=ez_fov / "ez_logs", + sigma=_sigma, + thresh=_thresh, + hole_size=_hole_size, + fov_dim=_fov_dim, + min_object_area=_min_object_area, + max_object_area=_max_object_area, + ) + assert (ez_fov / "ez_seg_masks" / "fov_0_test_mask.tiff").exists() + assert (ez_fov / "ez_seg_masks" / 
"fov_1_test_mask.tiff").exists() + assert (ez_fov / "ez_logs" / "test_mask_segmentation_log.txt").exists() + with open(ez_fov / "ez_logs" / "test_mask_segmentation_log.txt", "r") as f: + log_contents = f.read() + assert "fov_0" in log_contents + assert "fov_1" in log_contents + assert "test_mask" in log_contents + assert "chan_0" in log_contents + assert _object_shape_type in log_contents + assert str(_hole_size) in log_contents + assert str(_sigma) in log_contents + assert str(_thresh) in log_contents + assert str(_fov_dim) in log_contents + assert str(_min_object_area) in log_contents + assert str(_max_object_area) in log_contents + + +@pytest.mark.parametrize( + "_min_object_area, _max_object_area, _object_shape_type, _thresh, _fov_dim", + [(100, 100000, "blob", None, 400), (200, 2000, "projection", 100, 800)], +) +def test_create_object_mask( + ez_fov: pathlib.Path, + _object_shape_type: str, + _thresh: float, + _fov_dim: int, + _min_object_area: int, + _max_object_area: int, +) -> None: + fov0_chan0: np.ndarray = imread(ez_fov / "image_data" / "fov_0" / "chan_0.tiff") + + object_mask = ez_object_segmentation._create_object_mask( + input_image=xr.DataArray(fov0_chan0), + object_shape_type=_object_shape_type, + sigma=1, + thresh=_thresh, + hole_size=10, + fov_dim=_fov_dim, + min_object_area=_min_object_area, + max_object_area=_max_object_area, + ) + + assert object_mask.shape == fov0_chan0.shape + assert object_mask.dtype == np.int32 + + +@pytest.mark.parametrize( + "_block_type, _fov_dim, _img_shape, _block_size", + [("small_holes", 400, 512, 316), ("local_thresh", 800, 1024, 13)] +) +def test_get_block_size( + _block_type: str, + _fov_dim: int, + _img_shape: int, + _block_size: int, +) -> None: + assert isinstance( + ez_object_segmentation.get_block_size( + block_type=_block_type, fov_dim=_fov_dim, img_shape=_img_shape + ), + int, + ) + # Fails with an invalid `block_type` + with pytest.raises(ValueError): + ez_object_segmentation.get_block_size( + block_type="incorrect_block_type", fov_dim=_fov_dim, img_shape=_img_shape + ) + assert ez_object_segmentation.get_block_size( + block_type=_block_type, + fov_dim=_fov_dim, + img_shape=_img_shape, + ) == _block_size diff --git a/tests/segmentation/ez_seg/ez_seg_display_test.py b/tests/segmentation/ez_seg/ez_seg_display_test.py new file mode 100644 index 000000000..fce13ce34 --- /dev/null +++ b/tests/segmentation/ez_seg/ez_seg_display_test.py @@ -0,0 +1,145 @@ +import pathlib +from dataclasses import dataclass + +import numpy as np +import pytest +from alpineer import image_utils +from skimage.io import imread + +from ark.segmentation.ez_seg import ez_seg_display + + +@dataclass +class MaskDataPaths: + image_data_path: pathlib.Path + mask_path: pathlib.Path + fov0_dir: pathlib.Path + object_mask_dir: pathlib.Path + cell_mask_dir: pathlib.Path + merged_mask_dir: pathlib.Path + + +@pytest.fixture(scope="module") +def mask_data( + tmp_path_factory: pytest.TempPathFactory, rng: np.random.Generator +) -> tuple[pathlib.Path, pathlib.Path]: + img_data_path = tmp_path_factory.mktemp("image_data") + mask_path = tmp_path_factory.mktemp("mask_data") + + fov0 = "fov_0" + mask0 = "mask_0" + + # Create directories + fov0_dir = img_data_path / fov0 + object_mask_dir = mask_path / "object_mask_dir" + cell_mask_dir = mask_path / "cell_mask_dir" + merged_mask_dir = mask_path / "merged_mask_dir" + + for p in [fov0_dir, object_mask_dir, cell_mask_dir, merged_mask_dir]: + p.mkdir(parents=True, exist_ok=True) + + fov0_chan0_img = fov0_dir / "chan_0.tiff" + 
object_mask_img = object_mask_dir / f"{fov0}_{mask0}.tiff" + cell_mask_img = cell_mask_dir / f"{fov0}_whole_cell.tiff" + merged_mask_img = merged_mask_dir / f"{fov0}_{mask0}_merged.tiff" + + image_utils.save_image(fname=fov0_chan0_img, data=rng.random(size=(1024, 1024))) + image_utils.save_image(fname=object_mask_img, data=rng.random(size=(1024, 1024))) + image_utils.save_image(fname=cell_mask_img, data=rng.random(size=(1024, 1024))) + image_utils.save_image(fname=merged_mask_img, data=rng.random(size=(1024, 1024))) + + yield MaskDataPaths( + image_data_path=img_data_path, + mask_path=mask_path, + fov0_dir=fov0_dir, + object_mask_dir=object_mask_dir, + cell_mask_dir=cell_mask_dir, + merged_mask_dir=merged_mask_dir, + ) + + +def test_display_channel_image(mask_data: MaskDataPaths): + ez_seg_display.display_channel_image( + base_image_path=mask_data.image_data_path, + sub_folder_name=None, + test_fov_name="fov_0", + channel_name="chan_0", + ) + + with pytest.raises(FileNotFoundError): + ez_seg_display.display_channel_image( + base_image_path=mask_data.image_data_path, + sub_folder_name=None, + test_fov_name="fov_0", + channel_name="bad_chan_name", + ) + + +def test_overlay_mask_outlines(mask_data: MaskDataPaths): + ez_seg_display.overlay_mask_outlines( + fov="fov_0", + channel="chan_0", + image_dir=mask_data.image_data_path, + sub_folder_name=None, + mask_name="mask_0", + mask_dir=mask_data.object_mask_dir, + ) + + with pytest.raises(FileNotFoundError): + ez_seg_display.overlay_mask_outlines( + fov="fov_0", + channel="chan_0", + image_dir=mask_data.image_data_path, + sub_folder_name=None, + mask_name="bad_mask_name", + mask_dir=mask_data.object_mask_dir, + ) + + with pytest.raises(FileNotFoundError): + ez_seg_display.overlay_mask_outlines( + fov="fov_0", + channel="bad_chan_name", + image_dir=mask_data.image_data_path, + sub_folder_name=None, + mask_name="mask_0", + mask_dir=mask_data.object_mask_dir, + ) + + +def test_multiple_mask_display(mask_data: MaskDataPaths): + fov0 = "fov_0" + mask0 = "mask_0" + + ez_seg_display.multiple_mask_display( + fov=fov0, + mask_name=mask0, + object_mask_dir=mask_data.object_mask_dir, + cell_mask_dir=mask_data.cell_mask_dir, + merged_mask_dir=mask_data.merged_mask_dir, + ) + + with pytest.raises(FileNotFoundError): + ez_seg_display.multiple_mask_display( + fov=fov0, + mask_name="bad_mask_name", + object_mask_dir=mask_data.object_mask_dir, + cell_mask_dir=mask_data.cell_mask_dir, + merged_mask_dir=mask_data.merged_mask_dir, + ) + + +def test_create_overlap_and_merge_visual(mask_data: MaskDataPaths): + overlap_visual: np.ndarray = ez_seg_display.create_overlap_and_merge_visual( + fov="fov_0", + mask_name="mask_0", + object_mask_dir=mask_data.object_mask_dir, + cell_mask_dir=mask_data.cell_mask_dir, + merged_mask_dir=mask_data.merged_mask_dir, + ) + + assert ( + overlap_visual.shape[:2] + == imread(mask_data.image_data_path / "fov_0" / "chan_0.tiff").shape + ) + assert overlap_visual.shape[-1] == 3 # rgb channels + assert overlap_visual.dtype == np.uint8 diff --git a/tests/segmentation/ez_seg/ez_seg_utils_test.py b/tests/segmentation/ez_seg/ez_seg_utils_test.py new file mode 100644 index 000000000..6dfb53c71 --- /dev/null +++ b/tests/segmentation/ez_seg/ez_seg_utils_test.py @@ -0,0 +1,191 @@ +import numpy as np +import os +import pandas as pd +import pathlib +import pytest +import skimage.io as io +import tempfile + +from alpineer import io_utils, image_utils +from ark.segmentation.ez_seg import ez_seg_utils +from typing import List, Union + + 
+@pytest.fixture(scope="module") +def tiff_dir(tmpdir_factory: pytest.TempPathFactory) -> pathlib.Path: + tiff_dir_name: pathlib.Path = tmpdir_factory.mktemp("tiff_dir") + num_fovs: int = 3 + + for nf in range(num_fovs): + os.mkdir(tiff_dir_name / f"fov{nf}") + img_path: pathlib.Path = tiff_dir_name / f"fov{nf}" / f"fov{nf}.tiff" + tiff_data: np.ndarray = np.random.rand(32, 32) + image_utils.save_image(fname=img_path, data=tiff_data) + + yield tiff_dir_name + + +@pytest.fixture(scope="module") +def seg_dir(tmpdir_factory: pytest.TempPathFactory) -> pathlib.Path: + seg_dir_name: pathlib.Path = tmpdir_factory.mktemp("seg_dir") + seg_subdirs: List[str] = ["seg1", "seg2"] + num_fovs: int = 3 + + for ss in seg_subdirs: + os.mkdir(seg_dir_name / ss) + + for nf in range(num_fovs): + mask_path: pathlib.Path = seg_dir_name / ss / f"fov{nf}_{ss}.tiff" + mask_data: np.ndarray = np.random.randint(0, 2, (32, 32)) + image_utils.save_image(fname=mask_path, data=mask_data) + + yield seg_dir_name + + +@pytest.fixture(scope="module") +def mantis_dir(tmpdir_factory: pytest.TempPathFactory) -> pathlib.Path: + mantis_dir_name: pathlib.Path = tmpdir_factory.mktemp("mantis_dir") + yield mantis_dir_name + + +@pytest.fixture(scope="module") +def mask_dir(tmpdir_factory: pytest.TempPathFactory) -> pathlib.Path: + mask_dir_name: pathlib.Path = tmpdir_factory.mktemp("mask_dir") + + mask_suffixes: List[str] = ["type1.tiff", "type2.tiff"] + fov_count: int = 3 + + mask_files: List[pathlib.Path] = [ + mask_dir_name / f"fov{fov_num}_{ms}" + for fov_num in range(fov_count) + for ms in mask_suffixes + ] + + for mf in mask_files: + mask_data: np.ndarray = np.random.randint(0, 3, (32, 32)) + image_utils.save_image(fname=mf, data=mask_data) + + yield mask_dir_name + + +@pytest.fixture(scope="module") +def nested_mask_dir(tmpdir_factory: pytest.TempPathFactory) -> pathlib.Path: + nested_mask_dir_name: pathlib.Path = tmpdir_factory.mktemp("nested_mask_dir") + + mask_suffixes: List[str] = ["type1.tiff", "type2.tiff", "type3.tiff"] + fov_count: int = 3 + + nested_mask_subdir = nested_mask_dir_name / "nested_mask_subdir" + os.makedirs(nested_mask_subdir) + + mask_files_root: List[pathlib.Path] = [ + nested_mask_dir_name / f"fov{fov_num}_{ms}" + for fov_num in range(4) + for ms in mask_suffixes + ] + mask_files_subdir: List[pathlib.Path] = [ + nested_mask_subdir / f"fov{fov_num}_{ms}" + for fov_num in range(4, 7) + for ms in mask_suffixes + ] + all_mask_files: List[pathlib.Path] = mask_files_root + mask_files_subdir + + for mf in all_mask_files: + mask_data: np.ndarray = np.random.randint(0, 3, (32, 32)) + image_utils.save_image(fname=mf, data=mask_data) + + yield nested_mask_dir_name + + +def test_find_and_copy_files( + tmpdir_factory: pytest.TempPathFactory, nested_mask_dir: pathlib.Path +): + combined_mask_dir: pathlib.Path = tmpdir_factory.mktemp("mask_dest_dir") + mask_suffix_names: List[str] = ["type1", "type2"] + + ez_seg_utils.find_and_copy_files(mask_suffix_names, nested_mask_dir, combined_mask_dir) + + files_copied = [ + combined_mask_dir / f"fov{fov_num}_{ms}.tiff" + for fov_num in range(7) + for ms in mask_suffix_names + ] + assert all([os.path.exists(fc) for fc in files_copied]) + + +def test_renumber_masks(mask_dir: pathlib.Path): + ez_seg_utils.renumber_masks(mask_dir) + mask_files = io_utils.list_files(mask_dir, substrs=".tiff") + + cluster_start: int = len(mask_files) * 2 + 1 + max_cluster: int = cluster_start + len(mask_files) * 2 - 1 + all_clusters_seen: np.ndarray = np.array([]) + + for i, mf in 
enumerate(mask_files): + mask_data: np.ndarray = io.imread(str(mask_dir / mf)) + mask_clusters: np.ndarray = np.unique(mask_data[mask_data > 0]) + all_clusters_seen = np.concatenate([all_clusters_seen, mask_clusters]) + + assert np.all(np.sort(all_clusters_seen) == np.arange(cluster_start, max_cluster + 1)) + + +def test_create_mantis_project(tiff_dir: pathlib.Path, seg_dir: pathlib.Path, + mantis_dir: pathlib.Path): + fovs: List[str] = [f"fov{f}" for f in range(3)] + ez_seg_utils.create_mantis_project( + fovs, + tiff_dir, + seg_dir, + mantis_dir + ) + + for fov in fovs: + expected_files: List[str] = [] + expected_files.extend(io_utils.list_files(tiff_dir / fov, substrs=".tiff")) + for seg_subdir in io_utils.list_folders(seg_dir): + expected_files.extend(io_utils.list_files(seg_dir / seg_subdir, substrs=fov)) + + mantis_files: List[str] = io_utils.list_files(mantis_dir / fov) + assert set(expected_files) == set(mantis_files) + + +def test_log_creator(): + with tempfile.TemporaryDirectory() as td: + log_dir: Union[str, pathlib.Path] = os.path.join(td, "log_dir") + os.mkdir(log_dir) + + variables_to_log: dict[str, any] = {"var1": "val1", "var2": 2} + ez_seg_utils.log_creator(variables_to_log, log_dir) + + with open(os.path.join(log_dir, "config_values.txt"), "r") as infile: + log_lines: List[str] = infile.readlines() + + assert log_lines[0] == "var1: val1\n" + assert log_lines[1] == "var2: 2\n" + + +def test_filter_csvs_by_mask(): + with tempfile.TemporaryDirectory() as td: + csv_dir: Union[str, pathlib.Path] = os.path.join(td, "csv_dir") + os.mkdir(csv_dir) + + table_names: List[str] = [f"table{i}" for i in np.arange(2)] + mask_names: List[str] = [f"mask{i}" for i in np.arange(2)] + + for tn in table_names: + sample_data: pd.DataFrame = pd.DataFrame(np.random.rand(6, 3)) + sample_data["mask_type"]: pd.Series = [mask_names[0]] * 3 + [mask_names[1]] * 3 + sample_data.to_csv(os.path.join(csv_dir, tn + "_replace.csv"), index=False) + + ez_seg_utils.filter_csvs_by_mask(csv_dir, "_replace") + num_total_files = len(table_names) * len(mask_names) + assert len(io_utils.list_files(csv_dir, substrs="filtered_")) == num_total_files + + for tn in table_names: + for mn in mask_names: + created_csv = f"filtered_{mn}{tn}.csv" + assert os.path.exists(os.path.join(csv_dir, created_csv)) + + csv_data = pd.read_csv(os.path.join(csv_dir, created_csv)) + assert csv_data.shape == (3, 4) + assert np.all(csv_data["mask_type"].values == mn) diff --git a/tests/segmentation/ez_seg/merge_masks_test.py b/tests/segmentation/ez_seg/merge_masks_test.py new file mode 100644 index 000000000..90c680f89 --- /dev/null +++ b/tests/segmentation/ez_seg/merge_masks_test.py @@ -0,0 +1,115 @@ +import numpy as np +import os +import pathlib +import skimage.io as io +import tempfile +import xarray as xr + +from alpineer import io_utils +from ark.segmentation.ez_seg import merge_masks +from scipy.ndimage import label +from skimage.draw import disk +from typing import List, Union + + +def test_merge_masks_seq(): + fov_list: List[str] = [f"fov{i}" for i in range(3)] + object_list: List[str] = [f"mask{i}" for i in range(2)] + + with tempfile.TemporaryDirectory() as td: + object_mask_dir: Union[str, pathlib.Path] = os.path.join(td, "ez_seg_dir") + cell_mask_dir: Union[str, pathlib.Path] = os.path.join(td, "deepcell_output") + merged_mask_dir: Union[str, pathlib.Path] = os.path.join(td, "merged_masks_dir") + log_dir: Union[str, pathlib.Path] = os.path.join(td, "log_dir") + for directory in [object_mask_dir, cell_mask_dir, merged_mask_dir, 
log_dir]: + os.mkdir(directory) + + overlap_thresh: int = 10 + + for fov in fov_list: + cell_mask_data: np.ndarray = np.random.randint(0, 16, (32, 32)) + cell_mask_fov_file: Union[str, pathlib.Path] = os.path.join( + cell_mask_dir, f"{fov}_whole_cell.tiff" + ) + io.imsave(cell_mask_fov_file, cell_mask_data) + + for obj in object_list: + object_mask_data: np.ndarray = np.random.randint(0, 8, (32, 32)) + object_mask_fov_file: Union[str, pathlib.Path] = os.path.join( + object_mask_dir, f"{fov}_{obj}.tiff" + ) + io.imsave(object_mask_fov_file, cell_mask_data) + + # we're only testing functionality, for in-depth merge testing see test_merge_masks_single + merge_masks.merge_masks_seq( + fov_list, object_list, object_mask_dir, cell_mask_dir, overlap_thresh, + merged_mask_dir, log_dir + ) + + for fov in fov_list: + merged_mask_fov_file: Union[str, pathlib.Path] = os.path.join( + merged_mask_dir, f"{fov}_final_cells_remaining.tiff" + ) + assert os.path.exists(merged_mask_fov_file) + + log_file: Union[str, pathlib.Path] = os.path.join(log_dir, "mask_merge_log.txt") + assert os.path.exists(log_file) + + with open(log_file) as infile: + log_data: List[str] = infile.readlines() + + assert log_data[0] == f"fov_list: {str(fov_list)}\n" + assert log_data[1] == f"object_list: {str(object_list)}\n" + assert log_data[2] == f"object_mask_dir: {str(object_mask_dir)}\n" + assert log_data[3] == f"cell_mask_path: {str(cell_mask_dir)}\n" + assert log_data[4] == f"overlap_percent_threshold: {str(overlap_thresh)}\n" + assert log_data[5] == f"save_path: {str(merged_mask_dir)}\n" + + +def test_merge_masks_single(): + object_mask: np.ndarray = np.zeros((32, 32)) + cell_mask: np.ndarray = np.zeros((32, 32)) + expected_merged_mask: np.ndarray = np.zeros((32, 32)) + expected_cell_mask: np.ndarray = np.zeros((32, 32)) + + overlap_thresh: int = 10 + merged_mask_name: str = "merged_mask" + + # case 1: overlap below threshold, don't merge + obj1_rows, obj1_cols = disk((7, 7), radius=5, shape=object_mask.shape) + cell1_rows, cell1_cols = disk((1, 1), radius=5, shape=cell_mask.shape) + cell2_rows, cell2_cols = disk((13, 13), radius=5, shape=cell_mask.shape) + object_mask[obj1_rows, obj1_cols] = 1 + cell_mask[cell1_rows, cell1_cols] = 1 + cell_mask[cell2_rows, cell2_cols] = 2 + + # case 2: multiple cells within threshold, only merge best one + obj2_rows, obj2_cols = disk((25, 25), radius=5, shape=object_mask.shape) + cell3_rows, cell3_cols = disk((20, 20), radius=5, shape=cell_mask.shape) + cell4_rows, cell4_cols = disk((27, 27), radius=5, shape=cell_mask.shape) + object_mask[obj2_rows, obj2_cols] = 2 + cell_mask[cell3_rows, cell3_cols] = 3 + cell_mask[cell4_rows, cell4_cols] = 4 + + expected_merged_mask[obj1_rows, obj1_cols] = 1 + expected_merged_mask[obj2_rows, obj2_cols] = 2 + expected_merged_mask[cell4_rows, cell4_cols] = 2 + + expected_cell_mask[cell1_rows, cell1_cols] = 1 + expected_cell_mask[cell2_rows, cell2_cols] = 2 + expected_cell_mask[cell3_rows, cell3_cols] = 3 + + with tempfile.TemporaryDirectory() as td: + mask_save_dir: Union[str, pathlib.Path] = os.path.join(td, "mask_save_dir") + os.mkdir(mask_save_dir) + + created_cell_mask: np.ndarray = merge_masks.merge_masks_single( + object_mask, cell_mask, overlap_thresh, merged_mask_name, mask_save_dir + ) + + created_merged_mask: np.ndarray = io.imread( + os.path.join(mask_save_dir, merged_mask_name + "_merged.tiff") + ) + + assert np.all(created_merged_mask == expected_merged_mask) + assert np.all(created_cell_mask == expected_cell_mask) diff --git 
a/tests/segmentation/marker_quantification_test.py b/tests/segmentation/marker_quantification_test.py index 2b94926a7..3c1b8cc2f 100644 --- a/tests/segmentation/marker_quantification_test.py +++ b/tests/segmentation/marker_quantification_test.py @@ -736,14 +736,14 @@ def test_generate_cell_table_tree_loading(): nuclear_counts=True) assert norm_data_nuc.shape[0] == norm_data_fov_sub.shape[0] - assert norm_data_nuc.shape[1] == norm_data_fov_sub.shape[1] * 2 + 1 + assert norm_data_nuc.shape[1] == norm_data_fov_sub.shape[1] * 2 misc_utils.verify_in_list( nuclear_col='nc_ratio', nuc_cell_table_cols=norm_data_nuc.columns.values ) assert arcsinh_data_nuc.shape[0] == arcsinh_data_fov_sub.shape[0] - assert arcsinh_data_nuc.shape[1] == norm_data_fov_sub.shape[1] * 2 + 1 + assert arcsinh_data_nuc.shape[1] == norm_data_fov_sub.shape[1] * 2 misc_utils.verify_in_list( nuclear_col='nc_ratio', nuc_cell_table_cols=norm_data_nuc.columns.values @@ -823,14 +823,14 @@ def test_generate_cell_table_mibitiff_loading(): nuclear_counts=True) assert norm_data_nuc.shape[0] == norm_data_fov_sub.shape[0] - assert norm_data_nuc.shape[1] == norm_data_fov_sub.shape[1] * 2 + 1 + assert norm_data_nuc.shape[1] == norm_data_fov_sub.shape[1] * 2 misc_utils.verify_in_list( nuclear_col='nc_ratio', nuc_cell_table_cols=norm_data_nuc.columns.values ) assert arcsinh_data_nuc.shape[0] == arcsinh_data_fov_sub.shape[0] - assert arcsinh_data_nuc.shape[1] == norm_data_fov_sub.shape[1] * 2 + 1 + assert arcsinh_data_nuc.shape[1] == norm_data_fov_sub.shape[1] * 2 misc_utils.verify_in_list( nuclear_col='nc_ratio', nuc_cell_table_cols=norm_data_nuc.columns.values @@ -876,11 +876,15 @@ def test_generate_cell_table_extractions(): img_sub_folder=img_sub_folder, is_mibitiff=False ) - # verify total intensity extraction - assert np.all( - default_norm_data.loc[default_norm_data[settings.CELL_LABEL] == 1][chans].values - == np.arange(9).reshape(3, 3) - ) + # verify total intensity extraction, same for whole_cell and nuclear mask types + for mask_type in ["whole_cell", "nuclear"]: + assert np.all( + default_norm_data.loc[ + (default_norm_data[settings.CELL_LABEL] == 1) & + (default_norm_data["mask_type"] == mask_type) + ][chans].values + == np.arange(9).reshape(3, 3) + ) # define a specific threshold for positive pixel extraction thresh_kwargs = { @@ -897,20 +901,28 @@ def test_generate_cell_table_extractions(): extraction='positive_pixel', **thresh_kwargs ) - assert np.all(positive_pixel_data.iloc[:4][['chan0', 'chan1']].values == 0) - assert np.all(positive_pixel_data.iloc[4:][chans].values == 1) + # only applies for whole_cell mask types + positive_pixel_data_wc = positive_pixel_data[ + positive_pixel_data["mask_type"] == "whole_cell" + ] + assert np.all(positive_pixel_data_wc.iloc[:4][['chan0', 'chan1']].values == 0) + assert np.all(positive_pixel_data_wc.iloc[4:][chans].values == 1) # verify thresh kwarg passes through and nuclear counts True - positive_pixel_data_nuc, _ = marker_quantification.generate_cell_table( + positive_pixel_data, _ = marker_quantification.generate_cell_table( segmentation_dir=temp_dir, tiff_dir=tiff_dir, img_sub_folder=img_sub_folder, is_mibitiff=False, extraction='positive_pixel', nuclear_counts=True, **thresh_kwargs ) + # check explicitly for nuclear mask types + positive_pixel_data_nuc = positive_pixel_data[ + positive_pixel_data["mask_type"] == "nuclear" + ] assert np.all(positive_pixel_data_nuc.iloc[:4][['chan0', 'chan1']].values == 0) assert np.all(positive_pixel_data_nuc.iloc[4:][chans].values == 1) - assert 
positive_pixel_data_nuc.shape[0] == positive_pixel_data.shape[0] - assert positive_pixel_data_nuc.shape[1] == positive_pixel_data.shape[1] * 2 + 1 + assert positive_pixel_data_nuc.shape[0] == positive_pixel_data.shape[0] / 2 + assert positive_pixel_data_nuc.shape[1] == positive_pixel_data.shape[1] misc_utils.verify_in_list( nuclear_col='nc_ratio', nuc_cell_table_cols=positive_pixel_data_nuc.columns.values diff --git a/tests/utils/data_utils_test.py b/tests/utils/data_utils_test.py index 68920ace2..900a16331 100644 --- a/tests/utils/data_utils_test.py +++ b/tests/utils/data_utils_test.py @@ -126,8 +126,6 @@ def test_fov_mapping(self, _fov: str): # And each FOV should have some background pixels assert fov_mapping_df["label"].min() == 0 - assert set(fov_mapping_df["label"]) == set(self.cmd.mapping["label"]) - def test_cluster_ids(self): assert set(self.cmd.cluster_names) == {"A", "B"} diff --git a/tests/utils/example_dataset_test.py b/tests/utils/example_dataset_test.py index 56babd5fb..9c06372ff 100644 --- a/tests/utils/example_dataset_test.py +++ b/tests/utils/example_dataset_test.py @@ -18,7 +18,8 @@ "LDA_training_inference", "neighborhood_analysis", "pairwise_spatial_enrichment", - "ome_tiff"]) + "ome_tiff", + "ez_seg_data"]) def dataset_download(request, dataset_cache_dir) -> Iterator[ExampleDataset]: """ A Fixture which instantiates and downloads the dataset with respect to each @@ -103,6 +104,52 @@ def _setup(self): self._ome_tiff_files: List[str] = ["fov1.ome"] + self._ez_seg_files = { + "fov_names": [f"fov{i}" for i in range(10)], + "channel_names": ["Ca40", "GFAP", "Synaptophysin", "PanAmyloidbeta1724", + "Na23", "Reelin", "Presenilin1NTF", "Iba1", "CD105", + "C12", "EEA1", "VGLUT1", "PolyubiK63", "Ta181", "Au197", + "Si28", "PanGAD6567", "CD33Lyo", "MAP2", "Calretinin", + "PolyubiK48", "MAG", "TotalTau", "Amyloidbeta140", + "Background", "CD45", "8OHGuano", "pTDP43", "ApoE4", + "PSD95", "TH", "HistoneH3Lyo", "CD47", "Parvalbumin", + "Amyloidbeta142", "Calbindin", "PanApoE2E3E4", "empty139", + "CD31", "MCT1", "MBP", "SERT", "PHF1Tau", "VGAT", + "VGLUT2", "CD56Lyo", "MFN2"], + "composite_names": [ + "amyloid", + "astrocyte", + "microglia", + ], + "ez_mask_suffixes": [ + "amyloid-plaques", + "astrocyte-arms", + "microglia-arms" + ], + "cell_table_names": [ + "filtered_amyloid-plaques_table_arcsinh_transformed", + "filtered_amyloid-plaques_table_size_normalized", + "filtered_astrocyte-arms_merged_table_arcsinh_transformed", + "filtered_astrocyte-arms_merged_table_size_normalized", + "filtered_microglia-arms_merged_table_size_normalized", + "filtered_microglia-arms_merged_table_arcsinh_transformed", + "filtered_whole_cell_table_size_normalized", + "filtered_whole_cell_table_arcsinh_transformed", + "cell_and_objects_table_size_normalized", + "cell_and_objects_table_arcsinh_transformed" + ], + "log_names": [ + "amyloid-composite_log", + "amyloid-plaques_segmentation_log", + "astrocyte_composite_log", + "astrocyte-arms_segmentation_log", + "mask_merge_log", + "microglia_composite_log", + "microglia-arms_segmentation_log", + "test_composite_log" + ] + } + self.dataset_test_fns: dict[str, Callable] = { "image_data": self._image_data_check, "cell_table": self._cell_table_check, @@ -111,7 +158,8 @@ def _setup(self): "example_cell_output_dir": self._example_cell_output_dir_check, "spatial_lda": self._spatial_lda_output_dir_check, "post_clustering": self._post_clustering_output_dir_check, - "ome_tiff": self._ome_tiff_check + "ome_tiff": self._ome_tiff_check, + "ez_seg_data": 
self._ez_seg_data_check } # Mapping the datasets to their respective test functions. @@ -125,6 +173,7 @@ def _setup(self): "spatial_lda": "spatial_analysis/spatial_lda", "post_clustering": "post_clustering", "ome_tiff": "ome_tiff", + "ez_seg_data": "ez_seg_data", } def test_download_example_dataset(self, dataset_download: ExampleDataset): @@ -429,6 +478,79 @@ def _ome_tiff_check(self, dir_p: pathlib.Path): downloaded_ome_tiff_names = [f.stem for f in downloaded_ome_tiff] assert set(self._ome_tiff_files) == set(downloaded_ome_tiff_names) + def _ez_seg_data_check(self, dir_p: pathlib.Path): + """ + Checks to make sure that the correct files exist w.r.t the 'ez_seg_data' output dir + + Args: + dir_p (pathlib.Path): The directory to check. + """ + image_data = dir_p / "image_data" + composites = dir_p / "composites" + cell_tables = dir_p / "cell_table" + deepcell_output = dir_p / "segmentation" / "deepcell_output" + ez_masks = dir_p / "segmentation" / "ez_masks" + merged_masks = dir_p / "segmentation" / "merged_masks_dir" + mantis_visualization = dir_p / "mantis_visualization" + logs = dir_p / "logs" + + # image_data check + downloaded_fovs = list(image_data.glob("*")) + downloaded_fov_names = [f.stem for f in downloaded_fovs] + assert set(self._ez_seg_files["fov_names"]) == set(downloaded_fov_names) + + for fov in downloaded_fovs: + c_names = [c.stem for c in fov.rglob("*")] + assert set(self._ez_seg_files["channel_names"]) == set(c_names) + + # composites check + downloaded_fovs = list(composites.glob("*")) + downloaded_fov_names = [f.stem for f in downloaded_fovs] + assert set(self._ez_seg_files["fov_names"]) == set(downloaded_fov_names) + + for fov in downloaded_fovs: + c_names = [c.stem for c in fov.rglob("*")] + assert set(self._ez_seg_files["composite_names"]) == set(c_names) + + # cell tables check + downloaded_cell_tables = list(cell_tables.glob("*.csv")) + downloaded_cell_table_names = [f.stem for f in downloaded_cell_tables] + assert set(self._ez_seg_files["cell_table_names"]) == set(downloaded_cell_table_names) + + # deepcell output check + downloaded_whole_cell_seg = list(deepcell_output.glob("*.tiff")) + downloaded_whole_cell_names = [f.stem for f in downloaded_whole_cell_seg] + actual_whole_cell_names = [f"{fov}_whole_cell" for fov in self._ez_seg_files["fov_names"]] + assert set(actual_whole_cell_names) == set(downloaded_whole_cell_names) + + # ezSegmenter masks check + downloaded_ez = list(ez_masks.glob("*.tiff")) + downloaded_ez_names = [f.stem for f in downloaded_ez] + actual_ez_names = [ + f"{fov}_{ez_suffix}" + for fov in self._ez_seg_files["fov_names"] + for ez_suffix in self._ez_seg_files["ez_mask_suffixes"] + ] + assert set(actual_ez_names) == set(downloaded_ez_names) + + # merged masks check + downloaded_merged = list(merged_masks.glob("*.tiff")) + downloaded_merged_names = [f.stem for f in downloaded_merged] + actual_merged_names = [ + f"{fov}_{ez_suffix}_merged" + for fov in self._ez_seg_files["fov_names"] + for ez_suffix in self._ez_seg_files["ez_mask_suffixes"] + if ez_suffix != "amyloid-plaques" + ] + [ + f"{fov}_final_cells_remaining" for fov in self._ez_seg_files["fov_names"] + ] + assert set(actual_merged_names) == set(downloaded_merged_names) + + # logs check + downloaded_logs = list(logs.glob("*.txt")) + downloaded_log_names = [f.stem for f in downloaded_logs] + assert set(self._ez_seg_files["log_names"]) == set(downloaded_log_names) + def _suffix_paths(self, dataset_download: ExampleDataset, parent_dir: pathlib.Path) -> Generator: """ diff --git 
a/tests/utils/notebooks_test.py b/tests/utils/notebooks_test.py index 044277fea..4a2946387 100644 --- a/tests/utils/notebooks_test.py +++ b/tests/utils/notebooks_test.py @@ -167,6 +167,30 @@ def nb4_context( shutil.rmtree(base_dir_generator) +@pytest.fixture(scope="class") +def ez_seg_context( + templates_dir, base_dir_generator +) -> Tuple[Iterator[TestbookNotebookClient], pathlib.Path]: + """ + Creates a testbook context manager for the ezSegmenter notebook. + + Args: + templates_dir (pytest.Fixture): The fixture which yields the directory of the notebook + templates + base_dir_generator (pytest.Fixture): The fixture which yields the temporary directory + to store all notebook input / output. + + Yields: + Iterator[Tuple[Iterator[TestbookNotebookClient], pathlib.Path]]: + The testbook notebook client context manager and the temporary directory where the + notebook input / output is stored. + """ + POST_CLUSTERING: pathlib.Path = templates_dir / "ez_segmenter.ipynb" + with testbook(POST_CLUSTERING, timeout=6000, execute=False) as nb_context_manager: + yield nb_context_manager, base_dir_generator / "ez_seg" + shutil.rmtree(base_dir_generator) + + @pytest.fixture(scope="class") def nbfib_seg_context( templates_dir, base_dir_generator @@ -716,6 +740,110 @@ def test_cell_table_threshold(self): self.tb.execute_cell("cell_table_threshold") +class Test_EZSegmenter: + """ + Tests ezSegmenter notebook for completion. + NOTE: When modifying the tests, make sure the test are in the + same order as the tagged cells in the notebook. + """ + + @pytest.fixture(autouse=True, scope="function") + def _setup(self, ez_seg_context, dataset_cache_dir: Union[str, None]): + """ + Sets up necessary data and paths to run the notebooks. + """ + self.tb: testbook = ez_seg_context[0] + self.dataset: str = "ez_seg_data" + self.base_dir: str = ez_seg_context[1].as_posix() + self.cache_dir = dataset_cache_dir + + def test_imports(self): + self.tb.execute_cell("import") + + def test_base_dir(self): + base_dir_inject = f""" + base_dir = r"{self.base_dir}" + """ + self.tb.inject(base_dir_inject, "base_dir") + + def test_ex_data_download(self): + notebooks_test_utils._ex_dataset_download(dataset=self.dataset, save_dir=self.base_dir, + cache_dir=self.cache_dir) + base_dir_subpath_inject = f""" + base_dir = os.path.join(base_dir, "ez_seg_data") + """ + self.tb.inject(base_dir_subpath_inject, "ex_data_download") + + def test_file_path(self): + self.tb.execute_cell("file_path") + + def test_create_dirs(self): + self.tb.execute_cell("create_dirs") + + def test_validate_path(self): + self.tb.execute_cell("validate_path") + + def test_load_fovs(self): + load_fovs_inject = """ + fovs = ["fov0", "fov1"] + """ + self.tb.inject(load_fovs_inject, "load_fovs") + + def test_composite_set(self): + self.tb.execute_cell("composite_set") + + def test_composite_build(self): + self.tb.execute_cell("composite_build") + + def test_display_composite(self): + self.tb.execute_cell("display_composite") + + def test_display_channel(self): + self.tb.execute_cell("display_channel") + + def test_set_ez_seg_params(self): + self.tb.execute_cell("set_ez_seg_params") + + def test_gen_obj_masks(self): + self.tb.execute_cell("gen_obj_masks") + + def test_view_obj_mask(self): + self.tb.execute_cell("view_obj_mask") + + def test_set_mask_dirs(self): + self.tb.execute_cell("set_mask_dirs") + + def test_validate_mask_dirs(self): + self.tb.execute_cell("validate_mask_dirs") + + def test_merge_seg_cell(self): + self.tb.execute_cell("merge_seg_cell") + + def 
test_view_merged_mask(self): + self.tb.execute_cell("view_merged_mask") + + def test_consolidate_mask(self): + self.tb.execute_cell("consolidate_mask") + + def test_relabel_mask(self): + self.tb.execute_cell("relabel_mask") + + def test_nuc_props_set(self): + self.tb.execute_cell("nuc_props_set") + + def test_create_exp_mat(self): + self.tb.execute_cell("create_exp_mat") + + def test_save_exp_mat(self): + self.tb.execute_cell("save_exp_mat") + + def test_save_mat_by_mask(self): + self.tb.execute_cell("save_mat_by_mask") + + def test_mantis_mask(self): + self.tb.execute_cell("mantis_mask") + + class Test_Fiber_Segmentation: """ Tests Example Fiber Segmentation for completion. diff --git a/tests/utils/notebooks_test_utils.py b/tests/utils/notebooks_test_utils.py index 3cfc220d9..64fb78843 100644 --- a/tests/utils/notebooks_test_utils.py +++ b/tests/utils/notebooks_test_utils.py @@ -74,7 +74,13 @@ def generate_sample_feature_tifs(fovs, deepcell_output_dir, img_shape=(50, 50)): def _ex_dataset_download(dataset: str, save_dir: str, cache_dir: Union[str, None]): + """Downloads the example dataset and moves it to the save_dir. + Args: + dataset (str): The name of the dataset to download. + save_dir (str): The directory to save the dataset to. + cache_dir (Union[str, None]): The directory to cache the dataset to. + """ overwrite_existing = True ex_dataset = example_dataset.ExampleDataset(dataset=dataset, overwrite_existing=overwrite_existing,