diff --git a/tests/test_patch_extraction.py b/tests/test_patch_extraction.py index 1f00bcf92..f72216f12 100644 --- a/tests/test_patch_extraction.py +++ b/tests/test_patch_extraction.py @@ -83,6 +83,7 @@ def test_get_patch_extractor(source_image, patch_extr_csv): ) assert isinstance(points, patchextraction.PointsPatchExtractor) + assert len(points) == 1860 sliding_window = patchextraction.get_patch_extractor( input_img=input_img, @@ -424,97 +425,62 @@ def test_filter_coordinates(): mask = np.zeros([9, 6]) mask[0:4, 3:8] = 1 # will flag first 2 mask_reader = VirtualWSIReader(mask) - - # Tests for (original) filter_coordinates method - # functionality test - flag_list = PatchExtractor.filter_coordinates( - mask_reader, bbox_list, resolution=1.0, units="baseline" - ) - assert np.sum(flag_list - np.array([1, 1, 0, 0, 0, 0])) == 0 - - # Test for bad mask input - with pytest.raises( - ValueError, match="`mask_reader` should be wsireader.VirtualWSIReader." - ): - PatchExtractor.filter_coordinates( - mask, bbox_list, resolution=1.0, units="baseline" - ) - - # Test for bad bbox coordinate list in the input - with pytest.raises(ValueError, match=r".*should be ndarray of integer type.*"): - PatchExtractor.filter_coordinates( - mask_reader, bbox_list.tolist(), resolution=1.0, units="baseline" - ) - - # Test for incomplete coordinate list - with pytest.raises(ValueError, match=r".*`coordinates_list` of shape.*"): - PatchExtractor.filter_coordinates( - mask_reader, bbox_list[:, :2], resolution=1.0, units="baseline" - ) + slide_shape = [6, 9] # slide shape (w, h) at requested resolution ############################################### - # Tests for filter_coordinates_fast (new) method + # Tests for filter_coordinates (new) method _info = mask_reader.info _info.mpp = 1.0 mask_reader._m_info = _info # functionality test - flag_list = PatchExtractor.filter_coordinates_fast( + flag_list = PatchExtractor.filter_coordinates( mask_reader, bbox_list, - coordinate_resolution=1.0, - coordinate_units="mpp", - mask_resolution=1, + slide_shape, ) assert np.sum(flag_list - np.array([1, 1, 0, 0, 0, 0])) == 0 - flag_list = PatchExtractor.filter_coordinates_fast( - mask_reader, bbox_list, coordinate_resolution=(1.0, 1.0), coordinate_units="mpp" - ) + flag_list = PatchExtractor.filter_coordinates(mask_reader, bbox_list, slide_shape) # Test for bad mask input with pytest.raises( ValueError, match="`mask_reader` should be wsireader.VirtualWSIReader." ): - PatchExtractor.filter_coordinates_fast( + PatchExtractor.filter_coordinates( mask, bbox_list, - coordinate_resolution=1.0, - coordinate_units="mpp", + slide_shape, ) # Test for bad bbox coordinate list in the input with pytest.raises(ValueError, match=r".*should be ndarray of integer type.*"): - PatchExtractor.filter_coordinates_fast( + PatchExtractor.filter_coordinates( mask_reader, bbox_list.tolist(), - coordinate_resolution=1, - coordinate_units="mpp", + slide_shape, ) # Test for incomplete coordinate list with pytest.raises(ValueError, match=r".*`coordinates_list` must be of shape.*"): - PatchExtractor.filter_coordinates_fast( + PatchExtractor.filter_coordinates( mask_reader, bbox_list[:, :2], - coordinate_resolution=1, - coordinate_units="mpp", + slide_shape, ) # Test for put of range min_mask_ratio with pytest.raises(ValueError, match="`min_mask_ratio` must be between 0 and 1."): - PatchExtractor.filter_coordinates_fast( + PatchExtractor.filter_coordinates( mask_reader, bbox_list, - coordinate_resolution=1.0, - coordinate_units="mpp", + slide_shape, min_mask_ratio=-0.5, ) with pytest.raises(ValueError, match="`min_mask_ratio` must be between 0 and 1."): - PatchExtractor.filter_coordinates_fast( + PatchExtractor.filter_coordinates( mask_reader, bbox_list, - coordinate_resolution=1.0, - coordinate_units="mpp", + slide_shape, min_mask_ratio=1.1, ) @@ -529,7 +495,7 @@ def test_mask_based_patch_extractor_ndpi(sample_ndpi): # Generating a test mask to read patches from mask_dim = (int(slide_dimensions[0] / 10), int(slide_dimensions[1] / 10)) - wsi_mask = np.zeros(mask_dim, dtype=np.uint8) + wsi_mask = np.zeros(mask_dim[::-1], dtype=np.uint8) # reverse as dims are (w, h) # masking two column to extract patch from wsi_mask[:, :2] = 255 diff --git a/tiatoolbox/models/dataset/classification.py b/tiatoolbox/models/dataset/classification.py index 9b3e804bb..17ecdac3f 100644 --- a/tiatoolbox/models/dataset/classification.py +++ b/tiatoolbox/models/dataset/classification.py @@ -142,13 +142,8 @@ class WSIPatchDataset(dataset_abc.PatchDatasetABC): units: See (:class:`.WSIReader`) for details. preproc_func: - Preprocessing function used to transform the input data. If - supplied, then torch.Compose will be used on the input - preprocs. preprocs is a list of torchvision transforms for - preprocessing the image. The transforms will be applied in - the order that they are given in the list. For more - information, visit the following link: - https://pytorch.org/vision/stable/transforms.html. + Preprocessing function used to transform the input data. It will + be called on each patch before returning it. """ @@ -162,6 +157,8 @@ def __init__( resolution=None, units=None, auto_get_mask=True, + min_mask_ratio=0, + preproc_func=None, ): """Create a WSI-level patch dataset. @@ -187,12 +184,22 @@ def __init__( `units`. Expected to be positive and of (height, width). Note, this is not at level 0. resolution: - Check (:class:`.WSIReader`) for details. When - `mode='tile'`, value is fixed to be `resolution=1.0` and - `units='baseline'` units: check (:class:`.WSIReader`) for - details. + Check (:class:`.WSIReader`) for details. When + `mode='tile'`, value is fixed to be `resolution=1.0` and + `units='baseline'` units: check (:class:`.WSIReader`) for + details. + units: + Units in which `resolution` is defined. + auto_get_mask: + If `True`, then automatically get simple threshold mask using + WSIReader.tissue_mask() function. + min_mask_ratio: + Only patches with positive area percentage above this value are + included. Defaults to 0. preproc_func: - Preprocessing function used to transform the input data. + Preprocessing function used to transform the input data. If + supplied, the function will be called on each patch before + returning it. Examples: >>> # A user defined preproc func and expected behavior @@ -233,6 +240,7 @@ def __init__( ): raise ValueError(f"Invalid `stride_shape` value {stride_shape}.") + self.preproc_func = preproc_func img_path = pathlib.Path(img_path) if mode == "wsi": self.reader = WSIReader.open(img_path) @@ -299,8 +307,8 @@ def __init__( selected = PatchExtractor.filter_coordinates( mask_reader, # must be at the same resolution self.inputs, # must already be at requested resolution - resolution=resolution, - units=units, + wsi_shape=wsi_shape, + min_mask_ratio=min_mask_ratio, ) self.inputs = self.inputs[selected] diff --git a/tiatoolbox/tools/patchextraction.py b/tiatoolbox/tools/patchextraction.py index 1bcf2f044..540d56c65 100644 --- a/tiatoolbox/tools/patchextraction.py +++ b/tiatoolbox/tools/patchextraction.py @@ -158,6 +158,9 @@ def __iter__(self): self.n = 0 return self + def __len__(self): + return self.locations_df.shape[0] if self.locations_df is not None else 0 + def __next__(self): n = self.n @@ -202,21 +205,10 @@ def _generate_location_df(self): ) if self.mask is not None: - # convert the coordinate_list resolution unit to acceptable units - converted_units = self.wsi.convert_resolution_units( - input_res=self.resolution, - input_unit=self.units, - ) - # find the first unit which is not None - converted_units = { - k: v for k, v in converted_units.items() if v is not None - } - converted_units_keys = list(converted_units.keys()) - selected_coord_indices = self.filter_coordinates_fast( + selected_coord_indices = self.filter_coordinates( self.mask, self.coordinate_list, - coordinate_resolution=converted_units[converted_units_keys[0]], - coordinate_units=converted_units_keys[0], + wsi_shape=slide_dimension, min_mask_ratio=self.min_mask_ratio, ) self.coordinate_list = self.coordinate_list[selected_coord_indices] @@ -234,13 +226,12 @@ def _generate_location_df(self): return self @staticmethod - def filter_coordinates_fast( + def filter_coordinates( mask_reader: wsireader.VirtualWSIReader, coordinates_list: np.ndarray, - coordinate_resolution: float, - coordinate_units: str, - mask_resolution: float = None, + wsi_shape: Tuple[int, int], min_mask_ratio: float = 0, + func: Callable = None, ): """Validate patch extraction coordinates based on the input mask. @@ -260,19 +251,19 @@ def filter_coordinates_fast( default `func=None`, K should be 4, as we expect the `coordinates_list` to be bounding boxes in `[start_x, start_y, end_x, end_y]` format. - coordinate_resolution (float): - Resolution value at which `coordinates_list` is - generated. - coordinate_units (str): - Resolution unit at which `coordinates_list` is generated. - mask_resolution (float): - Resolution at which mask array is extracted. It is - supposed to be in the same units as `coord_resolution` - i.e., `coordinate_units`. If not provided, a default - value will be selected based on `coordinate_units`. + wsi_shape (tuple(int, int)): + Shape of the WSI in the requested `resolution` and `units`. min_mask_ratio (float): Only patches with positive area percentage above this value are - included. Defaults to 0. + included. Defaults to 0. Has no effect if `func` is not `None`. + func (callable): + Function to be used to validate the coordinates. The function + must take a `numpy.ndarray` of the mask and a `numpy.ndarray` + of the coordinates as input and return a bool indicating + whether the coordinate is valid or not. If `None`, a default + function that accepts patches with positive area proportion above + `min_mask_ratio` is used. + Returns: :class:`numpy.ndarray`: @@ -287,113 +278,46 @@ def filter_coordinates_fast( raise ValueError("`coordinates_list` should be ndarray of integer type.") if coordinates_list.shape[-1] != 4: raise ValueError("`coordinates_list` must be of shape [N, 4].") - if isinstance(coordinate_resolution, (int, float)): - coordinate_resolution = [coordinate_resolution, coordinate_resolution] if not 0 <= min_mask_ratio <= 1: raise ValueError("`min_mask_ratio` must be between 0 and 1.") - # define default mask_resolution based on the input `coordinate_units` - if mask_resolution is None: - mask_res_dict = {"mpp": 8, "power": 1.25, "baseline": 0.03125} - mask_resolution = mask_res_dict[coordinate_units] - - tissue_mask = mask_reader.slide_thumbnail( - resolution=mask_resolution, units=coordinate_units - ) + # the tissue mask exists in the reader already, no need to generate it + tissue_mask = mask_reader.img # Scaling the coordinates_list to the `tissue_mask` array resolution + scale_factors = np.array(tissue_mask.shape[::-1]) / np.array(wsi_shape) scaled_coords = coordinates_list.copy().astype(np.float32) - scaled_coords[:, [0, 2]] *= coordinate_resolution[0] / mask_resolution + scaled_coords[:, [0, 2]] *= scale_factors[0] scaled_coords[:, [0, 2]] = np.clip( scaled_coords[:, [0, 2]], 0, tissue_mask.shape[1] ) - scaled_coords[:, [1, 3]] *= coordinate_resolution[1] / mask_resolution + scaled_coords[:, [1, 3]] *= scale_factors[1] scaled_coords[:, [1, 3]] = np.clip( scaled_coords[:, [1, 3]], 0, tissue_mask.shape[0] ) scaled_coords = list(np.int32(scaled_coords)) - flag_list = [] - for coord in scaled_coords: + def default_sel_func(tissue_mask, coord): + """Default selection function to filter coordinates. + + This function selects a coordinate if the proportion of + positive mask in the corresponding patch is greater than + `min_mask_ratio`. + + """ this_part = tissue_mask[coord[1] : coord[3], coord[0] : coord[2]] patch_area = np.prod(this_part.shape) pos_area = np.count_nonzero(this_part) - - if ( + return ( (pos_area == patch_area) or (pos_area > patch_area * min_mask_ratio) - ) and (pos_area > 0 and patch_area > 0): - flag_list.append(True) - else: - flag_list.append(False) - return np.array(flag_list) - - @staticmethod - def filter_coordinates( - mask_reader: wsireader.VirtualWSIReader, - coordinates_list: np.ndarray, - func: Callable = None, - resolution: float = None, - units: str = None, - ): - """Indicates which coordinate is valid for mask-based patch extraction. - - Locations are validated by a custom or default filter `func`. - - Args: - mask_reader (:class:`.VirtualReader`): - A virtual pyramidal reader of the mask related to the - WSI from which we want to extract the patches. - coordinates_list (ndarray and np.int32): - Coordinates to be checked via the `func`. They must be - in the same resolution as requested `resolution` and - `units`. The shape of `coordinates_list` is (N, K) where - N is the number of coordinate sets and K is either 2 for - centroids or 4 for bounding boxes. When using the - default `func=None`, K should be 4, as we expect the - `coordinates_list` to refer to bounding boxes in - `[start_x, start_y, end_x, end_y]` format. - func: - The coordinate validator function. A function that takes - `reader` and `coordinate` as arguments and return True - or False as indication of coordinate validity. - resolution (float): - The resolution value at which coordinates_list are - generated. - units (str): - The resolution unit at which coordinates_list are - generated. - - Returns: - :class:`numpy.ndarray`: - List of flags to indicate which coordinates are valid. - - """ + ) and (pos_area > 0 and patch_area > 0) - def default_sel_func(reader: wsireader.VirtualWSIReader, coord: np.ndarray): - """Accept coord as long as its box contains bits of mask.""" - roi = reader.read_bounds( - coord, - resolution=reader.info.mpp if resolution is None else resolution, - units="mpp" if units is None else units, - interpolation="nearest", - coord_space="resolution", - ) - return np.sum(roi > 0) > 0 - - if not isinstance(mask_reader, wsireader.VirtualWSIReader): - raise ValueError("`mask_reader` should be wsireader.VirtualWSIReader.") - if not isinstance(coordinates_list, np.ndarray) or not np.issubdtype( - coordinates_list.dtype, np.integer - ): - raise ValueError("`coordinates_list` should be ndarray of integer type.") - if func is None and coordinates_list.shape[-1] != 4: - raise ValueError( - f"Default `func` does not support " - f"`coordinates_list` of shape {coordinates_list.shape}." - ) func = default_sel_func if func is None else func - flag_list = [func(mask_reader, coord) for coord in coordinates_list] + flag_list = [] + for coord in scaled_coords: + flag_list.append(func(tissue_mask, coord)) + return np.array(flag_list) @staticmethod diff --git a/tiatoolbox/tools/registration/wsi_registration.py b/tiatoolbox/tools/registration/wsi_registration.py index 4eb80d5c0..54fb48f1a 100644 --- a/tiatoolbox/tools/registration/wsi_registration.py +++ b/tiatoolbox/tools/registration/wsi_registration.py @@ -716,8 +716,8 @@ def find_points_inside_boundary(mask: np.ndarray, points: np.ndarray) -> np.ndar # convert coordinates of shape [N, 2] to [N, 4] end_x_y = points[:, 0:2] + 1 bbox_coord = np.c_[points, end_x_y].astype(int) - return PatchExtractor.filter_coordinates_fast( - mask_reader, bbox_coord, 1.0, "baseline", 1.0 + return PatchExtractor.filter_coordinates( + mask_reader, bbox_coord, mask.shape[::-1] ) def filtering_matching_points( diff --git a/tiatoolbox/wsicore/wsireader.py b/tiatoolbox/wsicore/wsireader.py index fc1a2e527..62d671e45 100644 --- a/tiatoolbox/wsicore/wsireader.py +++ b/tiatoolbox/wsicore/wsireader.py @@ -3881,7 +3881,7 @@ def _info(self) -> WSIMeta: ) for level in self.wsi.levels ] - dataset = self.wsi.base_level.datasets[0] + dataset = self.wsi.levels.base_level.datasets[0] # Get pixel spacing in mm from DICOM file and convert to um/px (mpp) mm_per_pixel = dataset.pixel_spacing mpp = (mm_per_pixel.width * 1e3, mm_per_pixel.height * 1e3)