diff --git a/tests/test_wsi_registration.py b/tests/test_wsi_registration.py index c1062ad61..d84bc67a7 100644 --- a/tests/test_wsi_registration.py +++ b/tests/test_wsi_registration.py @@ -5,6 +5,7 @@ import pytest from tiatoolbox.tools.registration.wsi_registration import ( + AffineWSITransformer, DFBRegister, apply_bspline_transform, estimate_bspline_transform, @@ -13,6 +14,7 @@ ) from tiatoolbox.utils.metrics import dice from tiatoolbox.utils.misc import imread +from tiatoolbox.wsicore.wsireader import WSIReader def test_extract_features(dfbr_features): @@ -338,7 +340,7 @@ def test_register_output_with_initializer( pre_transform = np.array([[-1, 0, 337.8], [0, -1, 767.7], [0, 0, 1]]) expected = np.array( - [[-0.99683, -0.00333, 338.69983], [-0.03201, -0.98420, 770.22941], [0, 0, 1]] + [[-0.98454, -0.00708, 397.95628], [-0.01024, -0.99752, 684.81131], [0, 0, 1]] ) output = df.register( @@ -363,7 +365,7 @@ def test_register_output_without_initializer( df = DFBRegister() expected = np.array( - [[-0.99683, -0.00189, 336.79039], [0.00691, -0.99810, 765.98081], [0, 0, 1]] + [[-0.99863, 0.00189, 389.79039], [0.00691, -0.99810, 874.98081], [0, 0, 1]] ) output = df.register( @@ -462,3 +464,29 @@ def test_bspline_transform(fixed_image, moving_image, fixed_mask, moving_mask): registered_msk = apply_bspline_transform(fixed_msk, moving_msk, transform) mask_overlap = dice(fixed_msk, registered_msk) assert mask_overlap > 0.75 + + +def test_affine_wsi_transformer(sample_ome_tiff): + test_locations = [(1001, 600), (1000, 500), (800, 701)] # at base level 0 + resolution = 0 + size = (100, 100) + + for location in test_locations: + wsi_reader = WSIReader.open(input_img=sample_ome_tiff) + expected = wsi_reader.read_rect( + location, size, resolution=resolution, units="level" + ) + + transform_level0 = np.array( + [ + [0, -1, location[0] + location[1] + size[1]], + [1, 0, location[1] - location[0]], + [0, 0, 1], + ] + ) + tfm = AffineWSITransformer(wsi_reader, transform_level0) + output = tfm.read_rect(location, size, resolution=resolution, units="level") + + expected = cv2.rotate(expected, cv2.ROTATE_90_CLOCKWISE) + + assert np.array_equal(expected, output) diff --git a/tiatoolbox/annotation/storage.py b/tiatoolbox/annotation/storage.py index d47e0c622..6109da823 100644 --- a/tiatoolbox/annotation/storage.py +++ b/tiatoolbox/annotation/storage.py @@ -1129,6 +1129,7 @@ def from_geojson( origin: Tuple[float, float] = (0, 0), ) -> "AnnotationStore": """Create a new database with annotations loaded from a geoJSON file. + Args: fp (Union[IO, str, Path]): The file path or handle to load from. @@ -1461,10 +1462,10 @@ class SQLiteStore(AnnotationStore): Uses an rtree index for fast spatial queries. Version History: - 1.0.0: - Initial version. - 1.0.1 (07/10/2022): - Added optional "area" column and queries sorted/filtered by area. + 1.0.0: + Initial version. + 1.0.1 (07/10/2022): + Added optional "area" column and queries sorted/filtered by area.
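+ + Example: + A minimal sketch of typical usage; the geometry and the "class" + property below are illustrative values only, not part of the schema: + + >>> from shapely.geometry import Polygon + >>> from tiatoolbox.annotation.storage import Annotation, SQLiteStore + >>> store = SQLiteStore() # an in-memory store by default + >>> key = store.append( + ... Annotation(Polygon([(0, 0), (1, 0), (1, 1)]), {"class": 1}) + ... )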
""" diff --git a/tiatoolbox/tools/registration/wsi_registration.py b/tiatoolbox/tools/registration/wsi_registration.py index 90cab377e..a9977e262 100644 --- a/tiatoolbox/tools/registration/wsi_registration.py +++ b/tiatoolbox/tools/registration/wsi_registration.py @@ -1,18 +1,25 @@ +import itertools import warnings -from typing import Dict, Tuple +from numbers import Number +from typing import Dict, Tuple, Union import cv2 import numpy as np import SimpleITK as sitk # noqa: N813 import torch import torchvision +from numpy.linalg import inv from skimage import exposure, filters +from skimage.registration import phase_cross_correlation from skimage.util import img_as_float from tiatoolbox.tools.patchextraction import PatchExtractor from tiatoolbox.utils.metrics import dice from tiatoolbox.utils.transforms import imresize -from tiatoolbox.wsicore.wsireader import VirtualWSIReader +from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader + +Resolution = Union[Number, Tuple[Number, Number], np.ndarray] +IntBounds = Tuple[int, int, int, int] def _check_dims( @@ -38,8 +45,8 @@ def _check_dims( Returns: tuple: - - :class:`numpy.ndarray`: A grayscale fixed image. - - :class:`numpy.ndarray`: A grayscale moving image. + - :class:`numpy.ndarray` - A grayscale fixed image. + - :class:`numpy.ndarray` - A grayscale moving image. """ if len(np.unique(fixed_mask)) == 1 or len(np.unique(moving_mask)) == 1: @@ -60,7 +67,7 @@ def _check_dims( return fixed_img, moving_img -def compute_center_of_mass(mask: np.ndarray) -> list: +def compute_center_of_mass(mask: np.ndarray) -> tuple: """Compute center of mass. Args: @@ -68,14 +75,15 @@ def compute_center_of_mass(mask: np.ndarray) -> list: A binary mask. Returns: - list: - x- and y- coordinates representing center of mass. + :py:obj:`tuple` - x- and y- coordinates representing center of mass. + - :py:obj:`int` - X coordinate. + - :py:obj:`int` - Y coordinate. """ moments = cv2.moments(mask) x_coord_center = moments["m10"] / moments["m00"] y_coord_center = moments["m01"] / moments["m00"] - return [x_coord_center, y_coord_center] + return (x_coord_center, y_coord_center) def prealignment( @@ -108,10 +116,16 @@ def prealignment( Returns: tuple: - - class:`numpy.ndarray`: A rigid transform matrix. - - class:`numpy.ndarray`: Transformed moving image. - - class:`numpy.ndarray`: Transformed moving mask. - - float: Dice overlap + - :class:`numpy.ndarray` - A rigid transform matrix. + - :class:`numpy.ndarray` - Transformed moving image. + - :class:`numpy.ndarray` - Transformed moving mask. + - :py:obj:`float` - Dice overlap. + + Examples: + >>> from tiatoolbox.tools.registration.wsi_registration import prealignment + >>> transform, transformed_image, transformed_mask, dice_overlap = prealignment( + ... fixed_thumbnail, moving_thumbnail, fixed_mask, moving_mask + ... ) """ orig_fixed_img, orig_moving_img = fixed_img, moving_img @@ -234,6 +248,10 @@ def match_histograms( - :class:`numpy.ndarray` - A normalized grayscale image. - :class:`numpy.ndarray` - A normalized grayscale image. + Examples: + >>> from tiatoolbox.tools.registration.wsi_registration import match_histograms + >>> norm_image_a, norm_image_b = match_histograms(gray_image_a, gray_image_b) + """ image_a, image_b = np.squeeze(image_a), np.squeeze(image_b) if len(image_a.shape) == 3 or len(image_b.shape) == 3: @@ -346,6 +364,17 @@ class DFBRegister: sensing image registration using deep convolutional features. Ieee Access, 6, pp.38544-38555. 
+ Examples: + >>> from tiatoolbox.tools.registration.wsi_registration import DFBRegister + >>> import cv2 + >>> df = DFBRegister() + >>> fixed_image = np.repeat(np.expand_dims(fixed_gray, axis=2), 3, axis=2) + >>> moving_image = np.repeat(np.expand_dims(moving_gray, axis=2), 3, axis=2) + >>> transform = df.register(fixed_image, moving_image, fixed_mask, moving_mask) + >>> registered = cv2.warpAffine( + ... moving_gray, transform[0:-1], fixed_gray.shape[:2][::-1] + ... ) + """ def __init__(self, patch_size: Tuple[int, int] = (224, 224)): @@ -599,7 +628,7 @@ def get_tissue_regions( fixed_mask: np.ndarray, moving_image: np.ndarray, moving_mask: np.ndarray, - ) -> Tuple[np.array, np.array, np.array, np.array, tuple]: + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, IntBounds]: """Extract tissue region. This function uses a binary mask for extracting tissue @@ -617,11 +646,19 @@ Returns: tuple: - - np.ndarray - A cropped image containing tissue region. - - np.ndarray - A cropped image containing tissue mask. - - np.ndarray - A cropped image containing tissue region. - - np.ndarray - A cropped image containing tissue mask. - - tuple - Bounding box (min_row, min_col, max_row, max_col). + - :class:`numpy.ndarray` - A cropped image containing tissue region + from the fixed image. + - :class:`numpy.ndarray` - A cropped image containing tissue mask + from the fixed image. + - :class:`numpy.ndarray` - A cropped image containing tissue region + from the moving image. + - :class:`numpy.ndarray` - A cropped image containing tissue mask + from the moving image. + - :py:obj:`tuple` - Bounds of the tissue region. + - :py:obj:`int` - Top (start y value). + - :py:obj:`int` - Left (start x value). + - :py:obj:`int` - Bottom (end y value). + - :py:obj:`int` - Right (end x value). """ fixed_minc, fixed_min_r, width, height = cv2.boundingRect(fixed_mask) @@ -654,7 +691,7 @@ ) @staticmethod - def find_points_inside_boundary(mask: np.ndarray, points: np.ndarray): + def find_points_inside_boundary(mask: np.ndarray, points: np.ndarray) -> np.ndarray: """Find indices of points lying inside the boundary.
This function returns indices of points which are @@ -1011,7 +1048,7 @@ def register( fixed_tissue_img, moving_tissue_img, fixed_tissue_mask, moving_tissue_mask ) - # Use tissue transform if it improves DICE overlap + # Use the estimated transform only if it improves DICE overlap after_dice = dice(fixed_tissue_mask, transform_tissue_mask) if after_dice > before_dice: moving_tissue_img, moving_tissue_mask = ( @@ -1022,7 +1059,7 @@ else: tissue_transform = np.eye(3, 3) - # Perform transform using tissue regions in a block wise manner + # Perform transform using tissue regions in a block-wise manner ( block_transform, transform_tissue_img, @@ -1031,13 +1068,25 @@ fixed_tissue_img, moving_tissue_img, fixed_tissue_mask, moving_tissue_mask ) - # Use block-wise tissue transform if it improves DICE overlap + # Use the estimated tissue transform only if it improves DICE overlap after_dice = dice(fixed_tissue_mask, transform_tissue_mask) - if after_dice <= before_dice: + if after_dice > before_dice: + moving_tissue_img, moving_tissue_mask = ( + transform_tissue_img, + transform_tissue_mask, + ) + before_dice = after_dice + else: block_transform = np.eye(3, 3) + # Correct any residual translation offset using phase cross-correlation + shift, _error, _diff_phase = phase_cross_correlation( + fixed_tissue_img, moving_tissue_img + ) + translation_offset = np.array([[1, 0, shift[1]], [0, 1, shift[0]], [0, 0, 1]]) + # Combine the translation offset with the block and tissue transforms - tissue_transform = block_transform @ tissue_transform + tissue_transform = translation_offset @ block_transform @ tissue_transform # tissue_transform is computed for cropped images (tissue region only). # It is converted using the tissue crop coordinates, so that it can be @@ -1112,6 +1161,18 @@ Returns: 2D deformation transformation represented by a grid of control points. + Examples: + >>> from tiatoolbox.tools.registration.wsi_registration import ( + ... estimate_bspline_transform, apply_bspline_transform + ... ) + >>> bspline_transform = estimate_bspline_transform( + ... fixed_gray_thumbnail, moving_gray_thumbnail, fixed_mask, moving_mask, + ... grid_space=50.0, sampling_percent=0.1, + ... ) + >>> bspline_registered_image = apply_bspline_transform( + ... fixed_thumbnail, moving_thumbnail, bspline_transform + ... ) + """ bspline_params = { "grid_space": 50.0, @@ -1233,3 +1294,223 @@ sitk_registered_image_sitk = resampler.Execute(moving_image_sitk) return sitk.GetArrayFromImage(sitk_registered_image_sitk) + + +class AffineWSITransformer: + """Resampling regions from a whole slide image. + + This class is used to resample tiles/patches from a whole slide image + using a given transformation. + + Example: + >>> from tiatoolbox.tools.registration.wsi_registration import ( + ... AffineWSITransformer + ... ) + >>> from tiatoolbox.wsicore.wsireader import WSIReader + >>> wsi_reader = WSIReader.open(input_img=sample_ome_tiff) + >>> transform_level0 = np.eye(3) + >>> tfm = AffineWSITransformer(wsi_reader, transform_level0) + >>> output = tfm.read_rect(location, size, resolution=resolution, units="level") + + """ + + def __init__(self, reader: WSIReader, transform: np.ndarray) -> None: + """Initialize object. + + Args: + reader (WSIReader): + An object of a class derived from WSIReader. + transform (:class:`numpy.ndarray`): + A 3x3 transformation matrix. The inverse transformation will be applied.
+ + """ + self.wsi_reader = reader + self.transform_level0 = transform + + @staticmethod + def transform_points(points: np.ndarray, transform: np.ndarray) -> np.ndarray: + """Transform points using the given transformation matrix. + + Args: + points (:class:`numpy.ndarray`): + A set of points of shape (N, 2). + transform (:class:`numpy.ndarray`): + Transformation matrix of shape (3, 3). + + Returns: + :class:`numpy.ndarray`: + Warped points of shape (N, 2). + + """ + points = np.array(points) + # Pad the data with ones, so that our transformation can do translations + points_pad = np.hstack([points, np.ones((points.shape[0], 1))]) + points_warp = np.dot(points_pad, transform.T) + return points_warp[:, :-1] + + def get_patch_dimensions( + self, size: Tuple[int, int], transform: np.ndarray + ) -> Tuple[int, int]: + """Compute patch size needed for transformation. + + Args: + size (tuple(int)): + (width, height) tuple giving the desired output image size. + transform (:class:`numpy.ndarray`): + Transformation matrix of shape (3, 3). + + Returns: + :py:obj:`tuple` - Maximum size of the patch needed for transformation. + - :py:obj:`int` - Width + - :py:obj:`int` - Height + + """ + width, height = size[0], size[1] + + x = [ + np.linspace(1, width, width, endpoint=True), + np.ones(height) * width, + np.linspace(1, width, width, endpoint=True), + np.ones(height), + ] + x = np.array(list(itertools.chain.from_iterable(x))) + + y = [ + np.ones(width), + np.linspace(1, height, height, endpoint=True), + np.ones(width) * height, + np.linspace(1, height, height, endpoint=True), + ] + y = np.array(list(itertools.chain.from_iterable(y))) + + points = np.array([x, y]).transpose() + transform_points = self.transform_points(points, transform) + + width = np.max(transform_points[:, 0]) - np.min(transform_points[:, 0]) + 1 + height = np.max(transform_points[:, 1]) - np.min(transform_points[:, 1]) + 1 + width, height = np.ceil(width).astype(int), np.ceil(height).astype(int) + + return (width, height) + + def get_transformed_location( + self, location: Tuple[int, int], size: Tuple[int, int], level: int + ) -> Tuple[int, int]: + """Get corresponding location on unregistered image and the required patch size. + + This function applies inverse transformation to the centre point of the region. + The transformed centre point is used to obtain the transformed top left pixel + of the region. + + Args: + location (tuple(int)): + (x, y) tuple giving the top left pixel in the baseline (level 0) + reference frame. + size (tuple(int)): + (width, height) tuple giving the desired output image size. + level (int): + Pyramid level/resolution layer. + + Returns: + tuple: + - :py:obj:`tuple` - Transformed location (top left pixel). + - :py:obj:`int` - X coordinate + - :py:obj:`int` - Y coordinate + - :py:obj:`tuple` - Maximum size suitable for transformation. 
+ - :py:obj:`int` - Width + - :py:obj:`int` - Height + + """ + inv_transform = inv(self.transform_level0) + size_level0 = [x * (2**level) for x in size] + center_level0 = [x + size_level0[i] / 2 for i, x in enumerate(location)] + center_level0 = np.expand_dims(np.array(center_level0), axis=0) + center_level0 = self.transform_points(center_level0, inv_transform)[0] + + transformed_size = self.get_patch_dimensions(size, inv_transform) + transformed_location = [ + center_level0[0] - (transformed_size[0] * (2**level)) / 2, + center_level0[1] - (transformed_size[1] * (2**level)) / 2, + ] + transformed_location = tuple( + np.round(x).astype(int) for x in transformed_location + ) + return transformed_location, transformed_size + + def transform_patch(self, patch: np.ndarray, size: Tuple[int, int]) -> np.ndarray: + """Apply transformation to the given patch. + + This function applies the transformation matrix after removing the translation. + + Args: + patch (:class:`numpy.ndarray`): + A region of the whole slide image. + size (tuple(int)): + (width, height) tuple giving the desired output image size. + + Returns: + :class:`numpy.ndarray`: + A transformed region/patch. + + """ + transform = self.transform_level0 * [[1, 1, 0], [1, 1, 0], [1, 1, 1]] + translation = (-size[0] / 2 + 0.5, -size[1] / 2 + 0.5) + forward_translation = np.array( + [[1, 0, translation[0]], [0, 1, translation[1]], [0, 0, 1]] + ) + inverse_translation = np.linalg.inv(forward_translation) + transform = inverse_translation @ transform @ forward_translation + return cv2.warpAffine(patch, transform[0:-1, :], patch.shape[:2][::-1]) + + def read_rect( + self, + location: Tuple[int, int], + size: Tuple[int, int], + resolution: Resolution, + units: str, + ) -> np.ndarray: + """Read a region of the transformed whole slide image. + + Location is in terms of the baseline image (level 0 / maximum resolution), + and size is the output image size. + + Args: + location (tuple(int)): + (x, y) tuple giving the top left pixel in the baseline (level 0) + reference frame. + size (tuple(int)): + (width, height) tuple giving the desired output image size. + resolution (float or tuple(float)): + Pyramid level/resolution layer. + units (str): + Units of the scale. + + Returns: + :class:`numpy.ndarray`: + A transformed region/patch. + + """ + ( + read_level, + _, + _, + _post_read_scale, + _baseline_read_size, + ) = self.wsi_reader.find_read_rect_params( + location=location, + size=size, + resolution=resolution, + units=units, + ) + transformed_location, max_size = self.get_transformed_location( + location, size, read_level + ) + patch = self.wsi_reader.read_rect( + transformed_location, max_size, resolution=resolution, units=units + ) + transformed_patch = self.transform_patch(patch, max_size) + + start_row = int(max_size[1] / 2) - int(size[1] / 2) + end_row = int(max_size[1] / 2) + int(size[1] / 2) + start_col = int(max_size[0] / 2) - int(size[0] / 2) + end_col = int(max_size[0] / 2) + int(size[0] / 2) + return transformed_patch[start_row:end_row, start_col:end_col, :] diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 280b9cfc8..b5cb4b706 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -880,7 +880,7 @@ def store_from_dat( eg {1: 'Epithelial Cell', 2: 'Lymphocyte', 3: ...}. For multi-head output, should be a dict of dicts, eg: {'head1': {1: 'Epithelial Cell', 2: 'Lymphocyte', 3: ...}, - 'head2': {1: 'Gland', 2: 'Lumen', 3: ...}, ...}. + 'head2': {1: 'Gland', 2: 'Lumen', 3: ...}, ...}.
origin (Tuple[float, float]): The x and y coordinates to use as the origin for the annotations. cls (AnnotationStore): @@ -998,7 +998,7 @@ ) -> None: """Add annotations from a .dat file to an existing store. - Make a best effort to create valid shapely geometries from provided contours. + Make a best-effort attempt to create valid shapely geometries from the provided contours. Args: fp (Union[IO, str, Path]): @@ -1013,9 +1013,9 @@ replaced by the corresponding value. Useful for providing descriptive names to non-descriptive types, eg {1: 'Epithelial Cell', 2: 'Lymphocyte', 3: ...}. - For multi-head output, should be a dict of dicts, eg: + For multi-head output, should be a dict of dicts, e.g.: {'head1': {1: 'Epithelial Cell', 2: 'Lymphocyte', 3: ...}, - 'head2': {1: 'Gland', 2: 'Lumen', 3: ...}, ...}. + 'head2': {1: 'Gland', 2: 'Lumen', 3: ...}, ...}. origin [float, float]: The x and y coordinates to use as the origin for the annotations. diff --git a/tiatoolbox/utils/visualization.py b/tiatoolbox/utils/visualization.py index c82003615..a5fbb1408 100644 --- a/tiatoolbox/utils/visualization.py +++ b/tiatoolbox/utils/visualization.py @@ -496,48 +496,48 @@ class AnnotationRenderer: from an AnnotationStore to a tile. Args: - score_prop (str): - A key that is present in the properties of annotations - to be rendered that will be used to color rendered annotations. - mapper (str, Dict or List): - A dictionary or colormap used to color annotations according - to the value of properties[score_prop] of an annotation. Should - be either a matplotlib colormap, a string which is a name of a - matplotlib colormap, a dict of possible property {value: color} - pairs, or a list of categorical property values (in which case a - dict will be created with a random color generated for each - category) - where (str or Callable): - a callable or predicate which will be passed on to - AnnotationStore.query() when fetching annotations to be rendered - (see AnnotationStore for more details) - score_fn (Callable): - an optional callable which will be called on the value of - the property that will be used to generate the color before giving - it to colormap. Use it for example to normalise property - values if they do not fall into the range [0,1], as matplotlib - colormap expects values in this range. i.e roughly speaking - annotation_color=mapper(score_fn(ann.properties[score_prop])) - max_scale (int): - downsample level above which Polygon geometries on crowded - tiles will be rendered as a bounding box instead - zoomed_out_strat (int, str): - strategy to use when rendering zoomed out tiles at - a level above max_scale. Can be one of 'decimate', 'scale', or a number - which defines the minimum area an abject has to cover to be rendered - while zoomed out above max_scale. - thickness (int): - line thickness of rendered contours. -1 will render filled - contours. - edge_thickness (int): - line thickness of rendered edges. - secondary_cmap (dict [str, str, cmap])): - a dictionary of the form {"type": some_type, - "score_prop": a property name, "mapper": a matplotlib cmap object}. - For annotations of the specified type, the given secondary colormap - will override the primary colormap. - blur_radius (int): - radius of gaussian blur to apply to rendered annotations. + score_prop (str): + A key present in the properties of the annotations + to be rendered, used to color the rendered annotations.
+ mapper (str, Dict or List): + A dictionary or colormap used to color annotations according + to the value of properties[score_prop] of an annotation. Should + be either a matplotlib colormap, a string giving the name of a + matplotlib colormap, a dict of possible property {value: color} + pairs, or a list of categorical property values (in which case a + dict will be created with a random color generated for each + category). + where (str or Callable): + A callable or predicate which will be passed on to + AnnotationStore.query() when fetching annotations to be rendered + (see AnnotationStore for more details). + score_fn (Callable): + An optional callable which will be applied to the value of + the scoring property to generate the color before giving + it to the colormap. Use it, for example, to normalise property + values that do not fall into the range [0, 1], as matplotlib + colormaps expect values in this range. Roughly speaking, + annotation_color = mapper(score_fn(ann.properties[score_prop])). + max_scale (int): + Downsample level above which Polygon geometries on crowded + tiles will be rendered as a bounding box instead. + zoomed_out_strat (int, str): + Strategy to use when rendering zoomed-out tiles at + a level above max_scale. Can be one of 'decimate', 'scale', or a number + which defines the minimum area an object has to cover to be rendered + while zoomed out above max_scale. + thickness (int): + Line thickness of rendered contours. -1 will render filled + contours. + edge_thickness (int): + Line thickness of rendered edges. + secondary_cmap (dict): + A dictionary of the form {"type": some_type, + "score_prop": a property name, "mapper": a matplotlib cmap object}. + For annotations of the specified type, the given secondary colormap + will override the primary colormap. + blur_radius (int): + Radius of the Gaussian blur to apply to rendered annotations.
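+ + Example: + A minimal sketch of constructing a renderer; the "type" property + name and the color mapping below are illustrative values only: + + >>> from tiatoolbox.utils.visualization import AnnotationRenderer + >>> renderer = AnnotationRenderer( + ... score_prop="type", + ... mapper={"cell": (1, 0, 0, 1)}, + ... thickness=-1, + ... ) """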