diff --git a/tests/engines/test_patch_classifier.py b/tests/engines/test_patch_classifier.py
new file mode 100644
index 000000000..f4ad87c28
--- /dev/null
+++ b/tests/engines/test_patch_classifier.py
@@ -0,0 +1,108 @@
+"""Test for Patch Classifier."""
+
+from __future__ import annotations
+
+import shutil
+from pathlib import Path
+
+import numpy as np
+import zarr
+
+from tiatoolbox.models.engine.patch_classifier import PatchClassifier
+from tiatoolbox.utils import env_detection as toolbox_env
+
+device = "cuda" if toolbox_env.has_gpu() else "cpu"
+
+
+def _test_classifier_output(
+    inputs: list,
+    model: str,
+    probabilities_check: list | None = None,
+    classification_check: list | None = None,
+    tmp_path: Path | None = None,
+) -> None:
+    """Test the predictions of multiple models included in tiatoolbox."""
+    cache_mode = None if tmp_path is None else True
+    save_dir = None if tmp_path is None else tmp_path / "output"
+    classifier = PatchClassifier(
+        model=model,
+        batch_size=32,
+        verbose=False,
+    )
+    # don't run test on GPU
+    output = classifier.run(
+        inputs,
+        return_labels=False,
+        device=device,
+        cache_mode=cache_mode,
+        save_dir=save_dir,
+    )
+
+    if tmp_path is not None:
+        output = zarr.open(output, mode="r")
+
+    probabilities = output["probabilities"]
+    classification = output["predictions"]
+    for idx, probabilities_ in enumerate(probabilities):
+        probabilities_max = max(probabilities_)
+        assert np.abs(probabilities_max - probabilities_check[idx]) <= 1e-3, (
+            model,
+            probabilities_max,
+            probabilities_check[idx],
+            probabilities_,
+            classification_check[idx],
+        )
+        assert classification[idx] == classification_check[idx], (
+            model,
+            probabilities_max,
+            probabilities_check[idx],
+            probabilities_,
+            classification_check[idx],
+        )
+    if save_dir:
+        shutil.rmtree(save_dir)
+
+
+def test_patch_predictor_kather100k_output(
+    sample_patch1: Path,
+    sample_patch2: Path,
+    tmp_path: Path,
+) -> None:
+    """Test the output of patch classification models on Kather100K dataset."""
+    inputs = [Path(sample_patch1), Path(sample_patch2)]
+    pretrained_info = {
+        "alexnet-kather100k": [1.0, 0.9999735355377197],
+        "resnet18-kather100k": [1.0, 0.9999911785125732],
+        "resnet34-kather100k": [1.0, 0.9979840517044067],
+        "resnet50-kather100k": [1.0, 0.9999986886978149],
+        "resnet101-kather100k": [1.0, 0.9999932050704956],
+        "resnext50_32x4d-kather100k": [1.0, 0.9910059571266174],
+        "resnext101_32x8d-kather100k": [1.0, 0.9999971389770508],
+        "wide_resnet50_2-kather100k": [1.0, 0.9953408241271973],
+        "wide_resnet101_2-kather100k": [1.0, 0.9999831914901733],
+        "densenet121-kather100k": [1.0, 1.0],
+        "densenet161-kather100k": [1.0, 0.9999959468841553],
+        "densenet169-kather100k": [1.0, 0.9999934434890747],
+        "densenet201-kather100k": [1.0, 0.9999983310699463],
+        "mobilenet_v2-kather100k": [0.9999998807907104, 0.9999126195907593],
+        "mobilenet_v3_large-kather100k": [0.9999996423721313, 0.9999878406524658],
+        "mobilenet_v3_small-kather100k": [0.9999998807907104, 0.9999997615814209],
+        "googlenet-kather100k": [1.0, 0.9999639987945557],
+    }
+    for model, expected_prob in pretrained_info.items():
+        _test_classifier_output(
+            inputs,
+            model,
+            probabilities_check=expected_prob,
+            classification_check=[6, 3],
+        )
+
+    # cache mode
+    for model, expected_prob in pretrained_info.items():
+        _test_classifier_output(
+            inputs,
+            model,
+            probabilities_check=expected_prob,
+            classification_check=[6, 3],
+            tmp_path=tmp_path,
+        )
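As context for reviewers, the new test drives both output paths of the classifier: an in-memory dictionary when `cache_mode` is off, and a zarr group on disk when it is on. The sketch below mirrors that flow outside pytest; it is illustrative only, and the patch file names are placeholders.

```python
# Illustrative sketch (not part of the patch): mirrors what the test above asserts.
from pathlib import Path

import zarr

from tiatoolbox.models.engine.patch_classifier import PatchClassifier

patches = [Path("patch1.png"), Path("patch2.png")]  # hypothetical input patches
classifier = PatchClassifier(model="alexnet-kather100k", batch_size=32, verbose=False)

# In-memory run: a dict holding per-patch "probabilities" and "predictions".
output = classifier.run(patches, return_labels=False, device="cpu")
print(output["predictions"])

# Cached run: the same keys, but stored in a zarr group under save_dir.
out_path = classifier.run(
    patches,
    return_labels=False,
    device="cpu",
    cache_mode=True,
    save_dir=Path("output"),
)
group = zarr.open(out_path, mode="r")
print(group["probabilities"].shape, group["predictions"][:])
```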
diff --git a/tests/engines/test_patch_predictor.py b/tests/engines/test_patch_predictor.py
index 8f62f5037..4ec409e06 100644
--- a/tests/engines/test_patch_predictor.py
+++ b/tests/engines/test_patch_predictor.py
@@ -239,79 +239,19 @@ def test_wsi_predictor_api(
     shutil.rmtree(_kwargs["save_dir"], ignore_errors=True)
 
 
-def _test_predictor_output(
-    inputs: list,
-    model: str,
-    probabilities_check: list | None = None,
-    predictions_check: list | None = None,
-) -> None:
-    """Test the predictions of multiple models included in tiatoolbox."""
-    predictor = PatchPredictor(
-        model=model,
-        batch_size=32,
-        verbose=False,
-    )
-    # don't run test on GPU
-    output = predictor.run(
-        inputs,
-        return_probabilities=True,
-        return_labels=False,
-        device=device,
-    )
-    predictions = output["probabilities"]
-    for idx, probabilities_ in enumerate(predictions):
-        probabilities_max = max(probabilities_)
-        assert np.abs(probabilities_max - probabilities_check[idx]) <= 1e-3, (
-            model,
-            probabilities_max,
-            probabilities_check[idx],
-            probabilities_,
-            predictions_check[idx],
-        )
-        assert np.argmax(probabilities_) == predictions_check[idx], (
-            model,
-            probabilities_max,
-            probabilities_check[idx],
-            probabilities_,
-            predictions_check[idx],
-        )
-
-
-def test_patch_predictor_kather100k_output(
-    sample_patch1: Path,
-    sample_patch2: Path,
-) -> None:
-    """Test the output of patch prediction models on Kather100K dataset."""
-    inputs = [Path(sample_patch1), Path(sample_patch2)]
-    pretrained_info = {
-        "alexnet-kather100k": [1.0, 0.9999735355377197],
-        "resnet18-kather100k": [1.0, 0.9999911785125732],
-        "resnet34-kather100k": [1.0, 0.9979840517044067],
-        "resnet50-kather100k": [1.0, 0.9999986886978149],
-        "resnet101-kather100k": [1.0, 0.9999932050704956],
-        "resnext50_32x4d-kather100k": [1.0, 0.9910059571266174],
-        "resnext101_32x8d-kather100k": [1.0, 0.9999971389770508],
-        "wide_resnet50_2-kather100k": [1.0, 0.9953408241271973],
-        "wide_resnet101_2-kather100k": [1.0, 0.9999831914901733],
-        "densenet121-kather100k": [1.0, 1.0],
-        "densenet161-kather100k": [1.0, 0.9999959468841553],
-        "densenet169-kather100k": [1.0, 0.9999934434890747],
-        "densenet201-kather100k": [1.0, 0.9999983310699463],
-        "mobilenet_v2-kather100k": [0.9999998807907104, 0.9999126195907593],
-        "mobilenet_v3_large-kather100k": [0.9999996423721313, 0.9999878406524658],
-        "mobilenet_v3_small-kather100k": [0.9999998807907104, 0.9999997615814209],
-        "googlenet-kather100k": [1.0, 0.9999639987945557],
-    }
-    for model, expected_prob in pretrained_info.items():
-        _test_predictor_output(
-            inputs,
-            model,
-            probabilities_check=expected_prob,
-            predictions_check=[6, 3],
-        )
-        # only test 1 on travis to limit runtime
-        if toolbox_env.running_on_ci():
-            break
+def _extract_probabilities_from_annotation_store(dbfile: str) -> dict:
+    """Helper function to extract probabilities from Annotation Store."""
+    probs_dict = {}
+    con = sqlite3.connect(dbfile)
+    cur = con.cursor()
+    annotations_properties = list(cur.execute("SELECT properties FROM annotations"))
+
+    for item in annotations_properties:
+        for json_str in item:
+            probs_dict = json.loads(json_str)
+            probs_dict.pop("prob_0")
+
+    return probs_dict
 
 
 def _validate_probabilities(predictions: list | dict) -> bool:
@@ -330,13 +270,13 @@ def test_wsi_predictor_zarr(sample_wsi_dict: dict, tmp_path: Path) -> None:
     """Test normal run of patch predictor for WSIs."""
     mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"])
 
-    predictor = PatchPredictor(
+    classifier = PatchPredictor(
         model="alexnet-kather100k",
         batch_size=32,
         verbose=False,
     )
     # don't run test on GPU
-    output = predictor.run(
+    output = classifier.run(
         images=[mini_wsi_svs],
         return_probabilities=True,
         return_labels=False,
@@ -357,54 +297,6 @@ def test_wsi_predictor_zarr(sample_wsi_dict: dict, tmp_path: Path) -> None:
     assert _validate_probabilities(predictions=output_["probabilities"])
 
 
-def test_wsi_predictor_zarr_baseline(sample_wsi_dict: dict, tmp_path: Path) -> None:
-    """Test normal run of patch predictor for WSIs."""
-    mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"])
-
-    predictor = PatchPredictor(
-        model="alexnet-kather100k",
-        batch_size=32,
-        verbose=False,
-    )
-    # don't run test on GPU
-    output = predictor.run(
-        images=[mini_wsi_svs],
-        return_probabilities=True,
-        return_labels=False,
-        device=device,
-        patch_mode=False,
-        save_dir=tmp_path / "wsi_out_check",
-        units="baseline",
-        resolution=1.0,
-    )
-
-    assert output[mini_wsi_svs].exists()
-
-    output_ = zarr.open(output[mini_wsi_svs])
-
-    assert output_["probabilities"].shape == (244, 9)  # number of patches x classes
-    assert output_["probabilities"].ndim == 2
-    # number of patches x [start_x, start_y, end_x, end_y]
-    assert output_["coordinates"].shape == (244, 4)
-    assert output_["coordinates"].ndim == 2
-    assert _validate_probabilities(predictions=output_["probabilities"])
-
-
-def _extract_probabilities_from_annotation_store(dbfile: str) -> dict:
-    """Helper function to extract probabilities from Annotation Store."""
-    probs_dict = {}
-    con = sqlite3.connect(dbfile)
-    cur = con.cursor()
-    annotations_properties = list(cur.execute("SELECT properties FROM annotations"))
-
-    for item in annotations_properties:
-        for json_str in item:
-            probs_dict = json.loads(json_str)
-            probs_dict.pop("prob_0")
-
-    return probs_dict
-
-
 def test_engine_run_wsi_annotation_store(
     sample_wsi_dict: dict,
     tmp_path: Path,
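The relocated `_extract_probabilities_from_annotation_store` helper reads prediction properties straight out of the SQLite file that backs an `AnnotationStore`. Below is a standalone sketch of the same query, assuming a hypothetical `wsi_output.db` produced by a previous engine run.

```python
# Standalone sketch of the relocated helper (the .db path is a placeholder).
import json
import sqlite3

con = sqlite3.connect("wsi_output.db")  # hypothetical AnnotationStore output
cur = con.cursor()
for (properties,) in cur.execute("SELECT properties FROM annotations"):
    props = json.loads(properties)  # per-annotation properties, e.g. "prob_0", "prob_1", ...
    print({key: value for key, value in props.items() if key.startswith("prob_")})
con.close()
```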
diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py
index 465230116..2e9d03d94 100644
--- a/tiatoolbox/models/engine/engine_abc.py
+++ b/tiatoolbox/models/engine/engine_abc.py
@@ -120,10 +120,6 @@ class EngineABCRunParams(TypedDict, total=False):
             Input IO configuration (:class:`ModelIOConfigABC`) to run the Engine.
         return_labels (bool):
             Whether to return the labels with the predictions.
-        merge_predictions (bool):
-            Whether to merge the predictions to form a 2-dimensional
-            map into a single file from a WSI.
-            This is only applicable if `patch_mode` is False in inference.
         num_loader_workers (int):
             Number of workers used in :class:`torch.utils.data.DataLoader`.
         num_post_proc_workers (int):
             Number of workers to postprocess the results of the model.
@@ -165,7 +161,6 @@ class EngineABCRunParams(TypedDict, total=False):
     class_dict: dict
     device: str
     ioconfig: ModelIOConfigABC
-    merge_predictions: bool
     num_loader_workers: int
     num_post_proc_workers: int
     output_file: str
@@ -248,10 +243,6 @@ class EngineABC(ABC):
             Runtime ioconfig.
         return_labels (bool):
             Whether to return the labels with the predictions.
-        merge_predictions (bool):
-            Whether to merge the predictions to form a 2-dimensional
-            map. This is only applicable if `patch_mode` is False in inference.
-            Default is False.
         resolution (Resolution):
             Resolution used for reading the image. Please see
             :obj:`WSIReader` for details.
@@ -293,8 +284,6 @@ class EngineABC(ABC):
             Number of workers to postprocess the results of the model.
         return_labels (bool):
             Whether to return the output labels. Default value is False.
-        merge_predictions (bool):
-            Whether to merge WSI predictions into a single file. Default value is False.
         resolution (Resolution):
             Resolution used for reading the image. Please see
             :class:`WSIReader` for details.
@@ -374,7 +363,6 @@ def __init__(
         self.cache_mode: bool = False
         self.cache_size: int = self.batch_size if self.batch_size else 10000
         self.labels: list | None = None
-        self.merge_predictions: bool = False
        self.num_loader_workers = num_loader_workers
        self.num_post_proc_workers = num_post_proc_workers
        self.patch_input_shape: IntPair | None = None
@@ -1194,8 +1182,6 @@ def run(
                 - img_path: path of the input image.
                 - raw: path to save location for raw prediction,
                   saved in .json.
-                - merged: path to .npy contain merged
-                  predictions if `merge_predictions` is `True`.
 
         Examples:
             >>> wsis = ['wsi1.svs', 'wsi2.svs']
diff --git a/tiatoolbox/models/engine/patch_classifier.py b/tiatoolbox/models/engine/patch_classifier.py
new file mode 100644
index 000000000..e68111a00
--- /dev/null
+++ b/tiatoolbox/models/engine/patch_classifier.py
@@ -0,0 +1,525 @@
+"""Defines PatchClassifier Engine."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import zarr
+from typing_extensions import Unpack
+
+from .engine_abc import EngineABCRunParams
+from .patch_predictor import PatchPredictor
+
+if TYPE_CHECKING:  # pragma: no cover
+    import os
+    from pathlib import Path
+
+    import numpy as np
+
+    from tiatoolbox.annotation import AnnotationStore
+    from tiatoolbox.models.engine.io_config import ModelIOConfigABC
+    from tiatoolbox.models.models_abc import ModelABC
+    from tiatoolbox.wsicore import WSIReader
+
+
+class ClassifierRunParams(EngineABCRunParams):
+    """Class describing the input parameters for the :func:`PatchClassifier.run()` method.
+
+    Attributes:
+        batch_size (int):
+            Number of image patches to feed to the model in a forward pass.
+        cache_mode (bool):
+            Whether to run the Engine in cache_mode. For large datasets,
+            we recommend setting this to True to avoid out of memory errors.
+            For smaller datasets, the cache_mode is set to False as
+            the results can be saved in memory.
+        cache_size (int):
+            Specifies how many image patches to process in a batch when
+            cache_mode is set to True. If cache_size is less than the batch_size,
+            the batch_size is set to cache_size.
+        class_dict (dict):
+            Optional dictionary mapping classification outputs to class names.
+        device (str):
+            Select the device to run the model. Please see
+            https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
+            for more details on input parameters for device.
+        ioconfig (ModelIOConfigABC):
+            Input IO configuration (:class:`ModelIOConfigABC`) to run the Engine.
+        return_labels (bool):
+            Whether to return the labels with the predictions.
+        num_loader_workers (int):
+            Number of workers used in :class:`torch.utils.data.DataLoader`.
+        num_post_proc_workers (int):
+            Number of workers to postprocess the results of the model.
+        output_file (str):
+            Output file name to save "zarr" or "db". If None, path to output is
+            returned by the engine.
+        patch_input_shape (tuple):
+            Shape of patches input to the model as tuple of height and width (HW).
+            Patches are requested at read resolution, not with respect to level 0,
+            and must be positive.
+        resolution (Resolution):
+            Resolution used for reading the image. Please see
+            :class:`WSIReader` for details.
+        return_labels (bool):
+            Whether to return the output labels.
+        return_probabilities (bool):
+            Whether to return per-class probabilities.
+        scale_factor (tuple[float, float]):
+            The scale factor to use when loading the
+            annotations. All coordinates will be multiplied by this factor to allow
+            conversion of annotations saved at non-baseline resolution to baseline.
+            Should be model_mpp/slide_mpp.
+        stride_shape (tuple):
+            Stride used during WSI processing. Stride is
+            at requested read resolution, not with respect to
+            level 0, and must be positive. If not provided,
+            `stride_shape=patch_input_shape`.
+        units (Units):
+            Units of resolution used for reading the image. Choose
+            from either `level`, `power` or `mpp`. Please see
+            :class:`WSIReader` for details.
+        verbose (bool):
+            Whether to output logging information.
+
+    """
+
+    return_probabilities: bool
+
+
+class PatchClassifier(PatchPredictor):
+    r"""Patch level classifier for digital histology images.
+
+    The models provided by TIAToolbox should give the following results:
+
+    .. list-table:: PatchClassifier performance on the Kather100K dataset [1]
+       :widths: 15 15
+       :header-rows: 1
+
+       * - Model name
+         - F\ :sub:`1`\ score
+       * - alexnet-kather100k
+         - 0.965
+       * - resnet18-kather100k
+         - 0.990
+       * - resnet34-kather100k
+         - 0.991
+       * - resnet50-kather100k
+         - 0.989
+       * - resnet101-kather100k
+         - 0.989
+       * - resnext50_32x4d-kather100k
+         - 0.992
+       * - resnext101_32x8d-kather100k
+         - 0.991
+       * - wide_resnet50_2-kather100k
+         - 0.989
+       * - wide_resnet101_2-kather100k
+         - 0.990
+       * - densenet121-kather100k
+         - 0.993
+       * - densenet161-kather100k
+         - 0.992
+       * - densenet169-kather100k
+         - 0.992
+       * - densenet201-kather100k
+         - 0.991
+       * - mobilenet_v2-kather100k
+         - 0.990
+       * - mobilenet_v3_large-kather100k
+         - 0.991
+       * - mobilenet_v3_small-kather100k
+         - 0.992
+       * - googlenet-kather100k
+         - 0.992
+
+    .. list-table:: PatchClassifier performance on the PCam dataset [2]
+       :widths: 15 15
+       :header-rows: 1
+
+       * - Model name
+         - F\ :sub:`1`\ score
+       * - alexnet-pcam
+         - 0.840
+       * - resnet18-pcam
+         - 0.888
+       * - resnet34-pcam
+         - 0.889
+       * - resnet50-pcam
+         - 0.892
+       * - resnet101-pcam
+         - 0.888
+       * - resnext50_32x4d-pcam
+         - 0.900
+       * - resnext101_32x8d-pcam
+         - 0.892
+       * - wide_resnet50_2-pcam
+         - 0.901
+       * - wide_resnet101_2-pcam
+         - 0.898
+       * - densenet121-pcam
+         - 0.897
+       * - densenet161-pcam
+         - 0.893
+       * - densenet169-pcam
+         - 0.895
+       * - densenet201-pcam
+         - 0.891
+       * - mobilenet_v2-pcam
+         - 0.899
+       * - mobilenet_v3_large-pcam
+         - 0.895
+       * - mobilenet_v3_small-pcam
+         - 0.890
+       * - googlenet-pcam
+         - 0.867
+
+    Args:
+        model (str | ModelABC):
+            A PyTorch model or name of pretrained model.
+            The user can request pretrained models from the toolbox model zoo using
+            the list of pretrained models available at this `link `_
+            By default, the corresponding pretrained weights will also
+            be downloaded. However, you can override with your own set
+            of weights using the `weights` parameter. Default is `None`.
+        batch_size (int):
+            Number of image patches fed into the model each time in a
+            forward/backward pass. Default value is 8.
+        num_loader_workers (int):
+            Number of workers to load the data using :class:`torch.utils.data.Dataset`.
+            Please note that they will also perform preprocessing. Default value is 0.
+        num_post_proc_workers (int):
+            Number of workers to postprocess the results of the model.
+            Default value is 0.
+        weights (str or Path):
+            Path to the weight of the corresponding `model`.
+
+            >>> engine = PatchClassifier(
+            ...     model="pretrained-model",
+            ...     weights="/path/to/pretrained-local-weights.pth"
+            ... )
+
+        device (str):
+            Select the device to run the model. Please see
+            https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
+            for more details on input parameters for device. Default is "cpu".
+        verbose (bool):
+            Whether to output logging information. Default value is True.
+
+    Attributes:
+        images (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`):
+            A list of image patches in NHWC format as a numpy array
+            or a list of str/paths to WSIs.
+        masks (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`):
+            A list of tissue masks or binary masks corresponding to processing area of
+            input images. These can be a list of numpy arrays or paths to
+            the saved image masks. These are only utilized when patch_mode is False.
+            Patches are only generated within a masked area.
+            If not provided, then a tissue mask will be automatically
+            generated for whole slide images.
+        patch_mode (str):
+            Whether to treat input images as a set of image patches. TIAToolbox defines
+            an image as a patch if HWC of the input image matches with the HWC expected
+            by the model. If HWC of the input image does not match with the HWC expected
+            by the model, then the patch_mode must be set to False which will allow the
+            engine to extract patches from the input image.
+            When patch_mode is False, the input images are treated
+            as WSIs. Default value is True.
+        model (str | ModelABC):
+            A PyTorch model or a name of an existing model from the TIAToolbox model zoo
+            for processing the data. For a full list of pretrained models,
+            refer to the `docs `_
+            By default, the corresponding pretrained weights will also
+            be downloaded. However, you can override with your own set
+            of weights via the `weights` argument. Argument
+            is case-insensitive.
+        ioconfig (ModelIOConfigABC):
+            Input IO configuration of type :class:`ModelIOConfigABC` to run the Engine.
+        _ioconfig (ModelIOConfigABC):
+            Runtime ioconfig.
+        return_labels (bool):
+            Whether to return the labels with the predictions.
+        resolution (Resolution):
+            Resolution used for reading the image. Please see
+            :obj:`WSIReader` for details.
+        units (Units):
+            Units of resolution used for reading the image. Choose
+            from either `level`, `power` or `mpp`. Please see
+            :obj:`WSIReader` for details.
+        patch_input_shape (tuple):
+            Shape of patches input to the model as a tuple of HW. Patches are at
+            requested read resolution, not with respect to level 0,
+            and must be positive.
+        stride_shape (tuple):
+            Stride used during WSI processing. Stride is
+            at requested read resolution, not with respect to
+            level 0, and must be positive. If not provided,
+            `stride_shape=patch_input_shape`.
+        batch_size (int):
+            Number of images fed into the model each time.
+        cache_mode (bool):
+            Whether to run the Engine in cache_mode. For large datasets,
+            we recommend setting this to True to avoid out of memory errors.
+            For smaller datasets, the cache_mode is set to False as
+            the results can be saved in memory. cache_mode is always True when
+            processing WSIs i.e., when `patch_mode` is False. Default value is False.
+        cache_size (int):
+            Specifies how many image patches to process in a batch when
+            cache_mode is set to True. If cache_size is less than the batch_size,
+            the batch_size is set to cache_size. Default value is 10,000.
+        labels (list | None):
+            List of labels. Only a single label per image is supported.
+        device (str):
+            :class:`torch.device` to run the model.
+            Select the device to run the model. Please see
+            https://pytorch.org/docs/stable/tensor_attributes.html#torch.device
+            for more details on input parameters for device. Default value is "cpu".
+        num_loader_workers (int):
+            Number of workers used in :class:`torch.utils.data.DataLoader`.
+        num_post_proc_workers (int):
+            Number of workers to postprocess the results of the model.
+        return_labels (bool):
+            Whether to return the output labels. Default value is False.
+        resolution (Resolution):
+            Resolution used for reading the image. Please see
+            :class:`WSIReader` for details.
+            When `patch_mode` is True, the input image patches are expected to be at
+            the correct resolution and units. When `patch_mode` is False, the patches
+            are extracted at the requested resolution and units. Default value is 1.0.
+        units (Units):
+            Units of resolution used for reading the image. Choose
+            from either `baseline`, `level`, `power` or `mpp`. Please see
+            :class:`WSIReader` for details.
+            When `patch_mode` is True, the input image patches are expected to be at
+            the correct resolution and units. When `patch_mode` is False, the patches
+            are extracted at the requested resolution and units.
+            Default value is `baseline`.
+        verbose (bool):
+            Whether to output logging information. Default value is True.
+
+    Examples:
+        >>> # list of 2 image patches as input
+        >>> data = ['path/img.svs', 'path/img.svs']
+        >>> classifier = PatchClassifier(model="resnet18-kather100k")
+        >>> output = classifier.run(data, patch_mode=True)
+
+        >>> # array of list of 2 image patches as input
+        >>> data = np.array([img1, img2])
+        >>> classifier = PatchClassifier(model="resnet18-kather100k")
+        >>> output = classifier.run(data, patch_mode=True)
+
+        >>> # list of 2 image patch files as input
+        >>> data = ['path/img.png', 'path/img.png']
+        >>> classifier = PatchClassifier(model="resnet18-kather100k")
+        >>> output = classifier.run(data, patch_mode=True)
+
+        >>> # list of 2 image tile files as input
+        >>> tile_file = ['path/tile1.png', 'path/tile2.png']
+        >>> classifier = PatchClassifier(model="resnet18-kather100k")
+        >>> output = classifier.run(tile_file, patch_mode=False)
+
+        >>> # list of 2 wsi files as input
+        >>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs']
+        >>> classifier = PatchClassifier(model="resnet18-kather100k")
+        >>> output = classifier.run(wsi_file, patch_mode=False)
+
+    References:
+        [1] Kather, Jakob Nikolas, et al. "Predicting survival from colorectal cancer
+        histology slides using deep learning: A retrospective multicenter study."
+        PLoS medicine 16.1 (2019): e1002730.
+
+        [2] Veeling, Bastiaan S., et al. "Rotation equivariant CNNs for digital
+        pathology." International Conference on Medical image computing and
+        computer-assisted intervention. Springer, Cham, 2018.
+ + """ + + def __init__( + self: PatchClassifier, + model: str | ModelABC, + batch_size: int = 8, + num_loader_workers: int = 0, + num_post_proc_workers: int = 0, + weights: str | Path | None = None, + *, + device: str = "cpu", + verbose: bool = True, + ) -> None: + """Initialize :class:`PatchClassifier`.""" + super().__init__( + model=model, + batch_size=batch_size, + num_loader_workers=num_loader_workers, + num_post_proc_workers=num_post_proc_workers, + weights=weights, + device=device, + verbose=verbose, + ) + + def post_process_cache_mode( + self: PatchClassifier, + raw_predictions: Path, + ) -> Path: + """Returns an array from raw predictions.""" + zarr_group = zarr.open(raw_predictions, mode="r+") + # Probabilities for post-processing + probabilities = zarr_group["probabilities"][:] + predictions = self.model.postproc_func( + probabilities, + ) + if "predictions" in zarr_group: + zarr_group["predictions"].append(predictions) + return raw_predictions + + zarr_dataset = zarr_group.create_dataset( + name="predictions", + shape=predictions.shape, + compressor=zarr_group["probabilities"].compressor, + ) + zarr_dataset[:] = predictions + + return raw_predictions + + def post_process_patches( + self: PatchClassifier, + raw_predictions: dict | Path, + **kwargs: Unpack[ClassifierRunParams], + ) -> dict | Path: + """Post-process raw patch predictions from inference. + + The output of :func:`infer_patches()` with patch prediction information will be + post-processed using this function. The processed output will be saved in the + respective input format. If `cache_mode` is True, the function processes the + input using zarr group with size specified by `cache_size`. + + Args: + raw_predictions (dict | Path): + A dictionary or path to zarr with patch prediction information. + **kwargs (ClassifierRunParams): + Keyword Args to update setup_patch_dataset() method attributes. See + :class:`ClassifierRunParams` for accepted keyword arguments. + + Returns: + dict or Path: + Returns patch based output after post-processing. Returns path to + saved zarr file if `cache_mode` is True. + + """ + _ = kwargs.get("return_probabilities") + if self.cache_mode: + return self.post_process_cache_mode(raw_predictions) + + probabilities = raw_predictions.get("probabilities") + + predictions = self.model.postproc_func( + probabilities, + ) + + if "predictions" in raw_predictions: + raw_predictions["predictions"].append(predictions) + else: + raw_predictions["predictions"] = predictions + + return raw_predictions + + def run( + self: PatchClassifier, + images: list[os | Path | WSIReader] | np.ndarray, + masks: list[os | Path] | np.ndarray | None = None, + labels: list | None = None, + ioconfig: ModelIOConfigABC | None = None, + *, + patch_mode: bool = True, + save_dir: os | Path | None = None, # None will not save output + overwrite: bool = False, + output_type: str = "dict", + **kwargs: Unpack[ClassifierRunParams], + ) -> AnnotationStore | Path | str | dict: + """Run the engine on input images. + + Args: + images (list, ndarray): + List of inputs to process. when using `patch` mode, the + input must be either a list of images, a list of image + file paths or a numpy array of an image list. + masks (list | None): + List of masks. Only utilised when patch_mode is False. + Patches are only generated within a masked area. + If not provided, then a tissue mask will be automatically + generated for whole slide images. + labels (list | None): + List of labels. Only a single label per image is supported. 
+            patch_mode (bool):
+                Whether to treat input image as a patch or WSI.
+                Default is True.
+            ioconfig (IOPatchPredictorConfig):
+                IO configuration.
+            save_dir (str or pathlib.Path):
+                Output directory to save the results.
+                If save_dir is not provided when patch_mode is False,
+                then for a single image the output is created in the current directory.
+                If there are multiple WSIs as input then the user must provide
+                path to save directory otherwise an OSError will be raised.
+            overwrite (bool):
+                Whether to overwrite the results. Default = False.
+            output_type (str):
+                The format of the output type. "output_type" can be
+                "dict", "zarr" or "AnnotationStore". Default value is "dict".
+                When saving in the zarr format the output is saved using the
+                `python zarr library `__
+                as a zarr group. If the required output type is an "AnnotationStore"
+                then the output will be intermediately saved as zarr but converted
+                to :class:`AnnotationStore` and saved as a `.db` file
+                at the end of the loop.
+            **kwargs (ClassifierRunParams):
+                Keyword Args to update :class:`EngineABC` attributes during runtime.
+
+        Returns:
+            (:class:`numpy.ndarray`, dict):
+                Model predictions of the input dataset. If multiple
+                whole slide images are provided as input,
+                or save_dir is provided, then results are saved to
+                `save_dir` and a dictionary indicating save location for
+                each input is returned.
+
+                The dict has the following format:
+
+                - img_path: path of the input image.
+                - raw: path to save location for raw prediction,
+                  saved in .json.
+
+        Examples:
+            >>> wsis = ['wsi1.svs', 'wsi2.svs']
+            >>> classifier = PatchClassifier(model="resnet18-kather100k")
+            >>> output = classifier.run(image_patches, patch_mode=True)
+            >>> output
+            ... "/path/to/Output.db"
+            >>> output = classifier.run(
+            >>>     image_patches,
+            >>>     patch_mode=True,
+            >>>     output_type="zarr")
+            >>> output
+            ... "/path/to/Output.zarr"
+            >>> output = classifier.run(wsis, patch_mode=False)
+            >>> output.keys()
+            ... ['wsi1.svs', 'wsi2.svs']
+            >>> output['wsi1.svs']
+            ... {'/path/to/wsi1.db'}
+
+        """
+        return super().run(
+            images=images,
+            masks=masks,
+            labels=labels,
+            ioconfig=ioconfig,
+            patch_mode=patch_mode,
+            save_dir=save_dir,
+            overwrite=overwrite,
+            output_type=output_type,
+            **kwargs,
+        )
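The main behavioural difference from `PatchPredictor` is the post-processing step: `post_process_patches` and `post_process_cache_mode` feed the stored probabilities through the model's `postproc_func` and add a `predictions` entry. The rough sketch below shows the in-memory case with placeholder patch paths; for the Kather100K models the post-processing amounts to a per-patch argmax over the nine tissue classes.

```python
# Rough sketch (not part of the patch): the classifier returns "predictions"
# derived from "probabilities" via the model's postproc_func.
from pathlib import Path

from tiatoolbox.models.engine.patch_classifier import PatchClassifier

classifier = PatchClassifier(model="resnet18-kather100k", batch_size=8)
output = classifier.run(
    [Path("patch1.png"), Path("patch2.png")],  # hypothetical image patches
    patch_mode=True,
    device="cpu",
)
for probs, pred in zip(output["probabilities"], output["predictions"]):
    print(f"top probability {max(probs):.3f} -> class index {pred}")
```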
diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py
index b98c6676d..3d4484b1b 100644
--- a/tiatoolbox/models/engine/patch_predictor.py
+++ b/tiatoolbox/models/engine/patch_predictor.py
@@ -1,4 +1,4 @@
-"""Defines Abstract Base Class for TIAToolbox Model Engines."""
+"""Defines PatchPredictor Engine."""
 
 from __future__ import annotations
 
@@ -25,7 +25,7 @@ class PatchPredictor(EngineABC):
     r"""Patch level predictor for digital histology images.
 
-    The models provided by tiatoolbox should give the following results:
+    The models provided by TIAToolbox should give the following results:
 
     .. list-table:: PatchPredictor performance on the Kather100K dataset [1]
        :widths: 15 15
@@ -176,10 +176,6 @@ class PatchPredictor(EngineABC):
             Runtime ioconfig.
         return_labels (bool):
             Whether to return the labels with the predictions.
-        merge_predictions (bool):
-            Whether to merge the predictions to form a 2-dimensional
-            map. This is only applicable if `patch_mode` is False in inference.
-            Default is False.
         resolution (Resolution):
             Resolution used for reading the image. Please see
             :obj:`WSIReader` for details.
@@ -221,8 +217,6 @@ class PatchPredictor(EngineABC):
             Number of workers to postprocess the results of the model.
         return_labels (bool):
             Whether to return the output labels. Default value is False.
-        merge_predictions (bool):
-            Whether to merge WSI predictions into a single file. Default value is False.
         resolution (Resolution):
             Resolution used for reading the image. Please see
             :class:`WSIReader` for details.
@@ -249,7 +243,7 @@ class PatchPredictor(EngineABC):
         >>> # array of list of 2 image patches as input
         >>> data = np.array([img1, img2])
         >>> predictor = PatchPredictor(model="resnet18-kather100k")
-        >>> output = predictor.predict(data, mode='patch')
+        >>> output = predictor.run(data, mode='patch')
 
         >>> # list of 2 image patch files as input
         >>> data = ['path/img.png', 'path/img.png']
@@ -418,7 +412,7 @@ def save_wsi_output(
         )
 
     def run(
-        self: EngineABC,
+        self: PatchPredictor,
         images: list[os | Path | WSIReader] | np.ndarray,
         masks: list[os | Path] | np.ndarray | None = None,
         labels: list | None = None,
@@ -482,8 +476,6 @@ def run(
                 - img_path: path of the input image.
                 - raw: path to save location for raw prediction,
                   saved in .json.
-                - merged: path to .npy contain merged
-                  predictions if `merge_predictions` is `True`.
 
         Examples:
             >>> wsis = ['wsi1.svs', 'wsi2.svs']
diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py
index 5a3c8c7fe..4c662184f 100644
--- a/tiatoolbox/utils/misc.py
+++ b/tiatoolbox/utils/misc.py
@@ -1225,7 +1225,7 @@ def patch_predictions_as_annotations(
     return annotations
 
 
-def _get_zarr_array(zarr_array: zarr.core.Array | np.ndarray) -> np.ndarray:
+def get_zarr_array(zarr_array: zarr.core.Array | np.ndarray) -> np.ndarray:
     """Converts a zarr array into a numpy array."""
     if isinstance(zarr_array, zarr.core.Array):
         return zarr_array[:]
@@ -1269,8 +1269,8 @@ def dict_to_store(
         raise ValueError(msg)
 
     # get relevant keys
-    class_probs = _get_zarr_array(patch_output.get("probabilities", []))
-    preds = _get_zarr_array(patch_output.get("predictions", []))
+    class_probs = get_zarr_array(patch_output.get("probabilities", []))
+    preds = get_zarr_array(patch_output.get("predictions", []))
     patch_coords = np.array(patch_output.get("coordinates", []))
 
     if not np.all(np.array(scale_factor) == 1):
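Finally, the rename of `_get_zarr_array` to the public `get_zarr_array` lets downstream code normalise zarr or numpy inputs through one call. A minimal sketch, assuming the rename lands as shown above:

```python
# Minimal sketch assuming the public rename above.
import numpy as np
import zarr

from tiatoolbox.utils.misc import get_zarr_array

zarr_probs = zarr.array(np.random.rand(4, 9))  # zarr input -> numpy copy
numpy_probs = np.random.rand(4, 9)             # numpy input -> returned unchanged
assert isinstance(get_zarr_array(zarr_probs), np.ndarray)
assert isinstance(get_zarr_array(numpy_probs), np.ndarray)
```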