diff --git a/tests/engines/test_engine_abc.py b/tests/engines/test_engine_abc.py index d0d6f0bd5..a5ce07d30 100644 --- a/tests/engines/test_engine_abc.py +++ b/tests/engines/test_engine_abc.py @@ -3,6 +3,7 @@ from __future__ import annotations import copy +import logging import shutil from pathlib import Path from typing import TYPE_CHECKING, NoReturn @@ -11,6 +12,7 @@ import pytest import torchvision.models as torch_models import zarr +from typing_extensions import Unpack from tiatoolbox.models.architecture import ( fetch_pretrained_weights, @@ -18,8 +20,13 @@ ) from tiatoolbox.models.architecture.vanilla import CNNModel from tiatoolbox.models.dataset import PatchDataset, WSIPatchDataset -from tiatoolbox.models.engine.engine_abc import EngineABC, prepare_engines_save_dir +from tiatoolbox.models.engine.engine_abc import ( + EngineABC, + EngineABCRunParams, + prepare_engines_save_dir, +) from tiatoolbox.models.engine.io_config import ModelIOConfigABC +from tiatoolbox.utils.misc import write_to_zarr_in_cache_mode if TYPE_CHECKING: import torch.nn @@ -57,31 +64,38 @@ def get_dataloader( def save_wsi_output( self: EngineABC, - raw_output: dict, + processed_output: dict, save_dir: Path, **kwargs: dict, ) -> Path: """Test post_process_wsi.""" return super().save_wsi_output( - raw_output, + processed_output, save_dir=save_dir, **kwargs, ) + def post_process_wsi( + self: EngineABC, + raw_predictions: dict | Path, + **kwargs: Unpack[EngineABCRunParams], + ) -> dict | Path: + """Post process WSI output.""" + return super().post_process_wsi( + raw_predictions=raw_predictions, + **kwargs, + ) + def infer_wsi( self: EngineABC, dataloader: torch.utils.data.DataLoader, - img_label: str, - highest_input_resolution: list[dict], - save_dir: Path, + save_path: Path, **kwargs: dict, ) -> dict | np.ndarray: """Test infer_wsi.""" return super().infer_wsi( dataloader, - img_label, - highest_input_resolution, - save_dir, + save_path, **kwargs, ) @@ -115,13 +129,34 @@ def test_incorrect_ioconfig() -> NoReturn: """Test EngineABC initialization with incorrect ioconfig.""" model = torch_models.resnet18() engine = TestEngineABC(model=model) + with pytest.raises( ValueError, - match=r".*provide a valid ModelIOConfigABC.*", + match=r".*Must provide.*`ioconfig`.*", ): engine.run(images=[], masks=[], ioconfig=None) +def test_incorrect_output_type() -> NoReturn: + """Test EngineABC for incorrect output type.""" + pretrained_model = "alexnet-kather100k" + + # Test engine run without ioconfig + eng = TestEngineABC(model=pretrained_model) + + with pytest.raises( + TypeError, + match=r".*output_type must be 'dict' or 'zarr' or 'annotationstore*", + ): + _ = eng.run( + images=np.zeros((10, 224, 224, 3), dtype=np.uint8), + on_gpu=False, + patch_mode=True, + ioconfig=None, + output_type="random", + ) + + def test_pretrained_ioconfig() -> NoReturn: """Test EngineABC initialization with pretrained model name in the toolbox.""" pretrained_model = "alexnet-kather100k" @@ -134,7 +169,7 @@ def test_pretrained_ioconfig() -> NoReturn: patch_mode=True, ioconfig=None, ) - assert "predictions" in out + assert "probabilities" in out assert "labels" not in out @@ -153,7 +188,7 @@ def test_ioconfig() -> NoReturn: ioconfig=ioconfig, ) - assert "predictions" in out + assert "probabilities" in out assert "labels" not in out @@ -260,7 +295,7 @@ def test_engine_initalization() -> NoReturn: assert isinstance(eng, EngineABC) -def test_engine_run(tmp_path: Path) -> NoReturn: +def test_engine_run(tmp_path: Path, sample_svs: Path) -> NoReturn: """Test 
engine run.""" eng = TestEngineABC(model="alexnet-kather100k") assert isinstance(eng, EngineABC) @@ -316,7 +351,7 @@ def test_engine_run(tmp_path: Path) -> NoReturn: on_gpu=False, patch_mode=True, ) - assert "predictions" in out + assert "probabilities" in out assert "labels" not in out eng = TestEngineABC(model="alexnet-kather100k") @@ -325,7 +360,7 @@ def test_engine_run(tmp_path: Path) -> NoReturn: on_gpu=False, verbose=False, ) - assert "predictions" in out + assert "probabilities" in out assert "labels" not in out eng = TestEngineABC(model="alexnet-kather100k") @@ -334,14 +369,14 @@ def test_engine_run(tmp_path: Path) -> NoReturn: labels=list(range(10)), on_gpu=False, ) - assert "predictions" in out + assert "probabilities" in out assert "labels" in out eng = TestEngineABC(model="alexnet-kather100k") with pytest.raises(NotImplementedError): eng.run( - images=np.zeros(shape=(10, 224, 224, 3)), + images=[sample_svs], save_dir=tmp_path / "output", patch_mode=False, ) @@ -358,7 +393,7 @@ def test_engine_run_with_verbose() -> NoReturn: on_gpu=False, ) - assert "predictions" in out + assert "probabilities" in out assert "labels" in out @@ -513,7 +548,10 @@ def test_eng_save_output(tmp_path: pytest.TempPathFactory) -> None: save_path = tmp_path / "output.zarr" _ = zarr.open(save_path, mode="w") out = eng.save_wsi_output( - raw_output=save_path, save_path=save_path, output_type="zarr", save_dir=tmp_path + processed_output=save_path, + save_path=save_path, + output_type="zarr", + save_dir=tmp_path, ) assert out.exists() @@ -521,13 +559,17 @@ def test_eng_save_output(tmp_path: pytest.TempPathFactory) -> None: # Test AnnotationStore patch_output = { - "predictions": [1, 0, 1], - "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], - "other": "other", + "predictions": np.array([1, 0, 1]), + "coordinates": np.array([(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)]), } class_dict = {0: "class0", 1: "class1"} + save_path = tmp_path / "output_db.zarr" + zarr_group = zarr.open(save_path, mode="w") + _ = write_to_zarr_in_cache_mode( + zarr_group=zarr_group, output_data_to_save=patch_output + ) out = eng.save_wsi_output( - raw_output=patch_output, + processed_output=save_path, scale_factor=(1.0, 1.0), class_dict=class_dict, save_dir=tmp_path, @@ -542,28 +584,35 @@ def test_eng_save_output(tmp_path: pytest.TempPathFactory) -> None: match=r".*supports zarr and AnnotationStore as output_type.", ): eng.save_wsi_output( - raw_output=save_path, + processed_output=save_path, save_path=save_path, output_type="dict", save_dir=tmp_path, ) -def test_io_config_delegation(tmp_path: Path) -> None: +def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None: """Test for delegating args to io config.""" # test not providing config / full input info for not pretrained models model = CNNModel("resnet50") eng = TestEngineABC(model=model) - with pytest.raises(ValueError, match=r".*Please provide a valid ModelIOConfigABC*"): - eng.run( - np.zeros((10, 224, 224, 3)), patch_mode=True, save_dir=tmp_path / "dump" - ) kwargs = { "patch_input_shape": [512, 512], "resolution": 1.75, "units": "mpp", } + with caplog.at_level(logging.WARNING): + eng.run( + np.zeros((10, 224, 224, 3)), + patch_mode=True, + save_dir=tmp_path / "dump", + patch_input_shape=kwargs["patch_input_shape"], + resolution=kwargs["resolution"], + units=kwargs["units"], + ) + assert "provide a valid ModelIOConfigABC" in caplog.text + shutil.rmtree(tmp_path / "dump", ignore_errors=True) # test providing config / full input info for non 
pretrained models ioconfig = ModelIOConfigABC( diff --git a/tests/engines/test_patch_predictor.py b/tests/engines/test_patch_predictor.py index ab59efc53..8f62f5037 100644 --- a/tests/engines/test_patch_predictor.py +++ b/tests/engines/test_patch_predictor.py @@ -3,628 +3,124 @@ from __future__ import annotations import copy +import json import shutil +import sqlite3 from pathlib import Path from typing import Callable -import cv2 import numpy as np -import pytest -import torch +import zarr from click.testing import CliRunner from tiatoolbox import cli -from tiatoolbox.models import IOPatchPredictorConfig, PatchPredictor +from tiatoolbox.models import IOPatchPredictorConfig from tiatoolbox.models.architecture.vanilla import CNNModel -from tiatoolbox.models.dataset import ( - PatchDataset, - PatchDatasetABC, - WSIPatchDataset, - predefined_preproc_func, -) -from tiatoolbox.utils import download_data, imread, imwrite +from tiatoolbox.models.engine.patch_predictor import PatchPredictor +from tiatoolbox.utils import download_data, imwrite from tiatoolbox.utils import env_detection as toolbox_env -from tiatoolbox.wsicore.wsireader import WSIReader +device = "cuda" if toolbox_env.has_gpu() else "cpu" ON_GPU = toolbox_env.has_gpu() RNG = np.random.default_rng() # Numpy Random Generator -# ------------------------------------------------------------------------------------- -# Dataloader -# ------------------------------------------------------------------------------------- - - -def test_patch_dataset_path_imgs( - sample_patch1: str | Path, - sample_patch2: str | Path, -) -> None: - """Test for patch dataset with a list of file paths as input.""" - size = (224, 224, 3) - - dataset = PatchDataset([Path(sample_patch1), Path(sample_patch2)]) - - for _, sample_data in enumerate(dataset): - sampled_img_shape = sample_data["image"].shape - assert sampled_img_shape[0] == size[0] - assert sampled_img_shape[1] == size[1] - assert sampled_img_shape[2] == size[2] - - -def test_patch_dataset_list_imgs(tmp_path: Path) -> None: - """Test for patch dataset with a list of images as input.""" - save_dir_path = tmp_path - - size = (5, 5, 3) - img = RNG.integers(low=0, high=255, size=size) - list_imgs = [img, img, img] - dataset = PatchDataset(list_imgs) - - dataset.preproc_func = lambda x: x - - for _, sample_data in enumerate(dataset): - sampled_img_shape = sample_data["image"].shape - assert sampled_img_shape[0] == size[0] - assert sampled_img_shape[1] == size[1] - assert sampled_img_shape[2] == size[2] - - # test for changing to another preproc - dataset.preproc_func = lambda x: x - 10 - item = dataset[0] - assert np.sum(item["image"] - (list_imgs[0] - 10)) == 0 - - # * test for loading npy - # remove previously generated data - if Path.exists(save_dir_path): - shutil.rmtree(save_dir_path, ignore_errors=True) - Path.mkdir(save_dir_path, parents=True) - np.save( - str(save_dir_path / "sample2.npy"), - RNG.integers(0, 255, (4, 4, 3)), - ) - imgs = [ - save_dir_path / "sample2.npy", - ] - _ = PatchDataset(imgs) - assert imgs[0] is not None - # test for path object - imgs = [ - save_dir_path / "sample2.npy", - ] - _ = PatchDataset(imgs) - - -def test_patch_datasetarray_imgs() -> None: - """Test for patch dataset with a numpy array of a list of images.""" - size = (5, 5, 3) - img = RNG.integers(0, 255, size=size) - list_imgs = [img, img, img] - labels = [1, 2, 3] - array_imgs = np.array(list_imgs) - - # test different setter for label - dataset = PatchDataset(array_imgs, labels=labels) - an_item = dataset[2] - 
assert an_item["label"] == 3 - dataset = PatchDataset(array_imgs, labels=None) - an_item = dataset[2] - assert "label" not in an_item - - dataset = PatchDataset(array_imgs) - for _, sample_data in enumerate(dataset): - sampled_img_shape = sample_data["image"].shape - assert sampled_img_shape[0] == size[0] - assert sampled_img_shape[1] == size[1] - assert sampled_img_shape[2] == size[2] - - -def test_patch_dataset_crash(tmp_path: Path) -> None: - """Test to make sure patch dataset crashes with incorrect input.""" - # all below examples should fail when input to PatchDataset - save_dir_path = tmp_path - - # not supported input type - imgs = {"a": RNG.integers(0, 255, (4, 4, 4))} - with pytest.raises( - ValueError, - match=r".*Input must be either a list/array of images.*", - ): - _ = PatchDataset(imgs) - - # ndarray of mixed dtype - imgs = np.array( - # string array of the same shape - [ - RNG.integers(0, 255, (4, 5, 3)), - np.array( # skipcq: PYL-E1121 - ["you_should_crash_here" for _ in range(4 * 5 * 3)], - ).reshape( - 4, - 5, - 3, - ), - ], - dtype=object, - ) - with pytest.raises(ValueError, match="Provided input array is non-numerical."): - _ = PatchDataset(imgs) - - # ndarray(s) of NHW images - imgs = RNG.integers(0, 255, (4, 4, 4)) - with pytest.raises(ValueError, match=r".*array of the form HWC*"): - _ = PatchDataset(imgs) - - # list of ndarray(s) with different sizes - imgs = [ - RNG.integers(0, 255, (4, 4, 3)), - RNG.integers(0, 255, (4, 5, 3)), - ] - with pytest.raises(ValueError, match="Images must have the same dimensions."): - _ = PatchDataset(imgs) - - # list of ndarray(s) with HW and HWC mixed up - imgs = [ - RNG.integers(0, 255, (4, 4, 3)), - RNG.integers(0, 255, (4, 4)), - ] - with pytest.raises( - ValueError, - match="Each sample must be an array of the form HWC.", - ): - _ = PatchDataset(imgs) - - # list of mixed dtype - imgs = [RNG.integers(0, 255, (4, 4, 3)), "you_should_crash_here", 123, 456] - with pytest.raises( - ValueError, - match="Input must be either a list/array of images or a list of " - "valid image paths.", - ): - _ = PatchDataset(imgs) - - # list of mixed dtype - imgs = ["you_should_crash_here", 123, 456] - with pytest.raises( - ValueError, - match="Input must be either a list/array of images or a list of " - "valid image paths.", - ): - _ = PatchDataset(imgs) - - # list not exist paths - with pytest.raises( - ValueError, - match=r".*valid image paths.*", - ): - _ = PatchDataset(["img.npy"]) - - # ** test different extension parser - # save dummy data to temporary location - # remove prev generated data - shutil.rmtree(save_dir_path, ignore_errors=True) - save_dir_path.mkdir(parents=True) - - torch.save({"a": "a"}, save_dir_path / "sample1.tar") - np.save( - str(save_dir_path / "sample2.npy"), - RNG.integers(0, 255, (4, 4, 3)), - ) - - imgs = [ - save_dir_path / "sample1.tar", - save_dir_path / "sample2.npy", - ] - with pytest.raises( - ValueError, - match="Cannot load image data from", - ): - _ = PatchDataset(imgs) - - # preproc func for not defined dataset - with pytest.raises( - ValueError, - match=r".* preprocessing .* does not exist.", - ): - predefined_preproc_func("secret-dataset") - - -def test_wsi_patch_dataset( # noqa: PLR0915 - sample_wsi_dict: dict, - tmp_path: Path, -) -> None: - """A test for creation and bare output.""" - # convert to pathlib Path to prevent wsireader complaint - mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) - mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) - mini_wsi_msk = 
Path(sample_wsi_dict["wsi2_4k_4k_msk"]) - - def reuse_init(img_path: Path = mini_wsi_svs, **kwargs: dict) -> WSIPatchDataset: - """Testing function.""" - return WSIPatchDataset(img_path=img_path, **kwargs) - - def reuse_init_wsi(**kwargs: dict) -> WSIPatchDataset: - """Testing function.""" - return reuse_init(mode="wsi", **kwargs) - - # test for ABC validate - # intentionally created to check error - # skipcq - class Proto(PatchDatasetABC): - def __init__(self: Proto) -> None: - super().__init__() - self.inputs = "CRASH" - self._check_input_integrity("wsi") - - # skipcq - def __getitem__(self: Proto, idx: int) -> object: - """Get an item from the dataset.""" - - with pytest.raises( - ValueError, - match=r".*`inputs` should be a list of patch coordinates.*", - ): - Proto() # skipcq - - # invalid path input - with pytest.raises(ValueError, match=r".*`img_path` must be a valid file path.*"): - WSIPatchDataset( - img_path="aaaa", - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], - auto_get_mask=False, - ) - - # invalid mask path input - with pytest.raises(ValueError, match=r".*`mask_path` must be a valid file path.*"): - WSIPatchDataset( - img_path=mini_wsi_svs, - mask_path="aaaa", - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], - resolution=1.0, - units="mpp", - auto_get_mask=False, - ) - - # invalid mode - with pytest.raises(ValueError, match="`X` is not supported."): - reuse_init(mode="X") - - # invalid patch - with pytest.raises(ValueError, match="Invalid `patch_input_shape` value None."): - reuse_init() - with pytest.raises( - ValueError, - match=r"Invalid `patch_input_shape` value \[512 512 512\].", - ): - reuse_init_wsi(patch_input_shape=[512, 512, 512]) - with pytest.raises( - ValueError, - match=r"Invalid `patch_input_shape` value \['512' 'a'\].", - ): - reuse_init_wsi(patch_input_shape=[512, "a"]) - with pytest.raises(ValueError, match="Invalid `stride_shape` value None."): - reuse_init_wsi(patch_input_shape=512) - # invalid stride - with pytest.raises( - ValueError, - match=r"Invalid `stride_shape` value \['512' 'a'\].", - ): - reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, "a"]) - with pytest.raises( - ValueError, - match=r"Invalid `stride_shape` value \[512 512 512\].", - ): - reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, 512, 512]) - # negative - with pytest.raises( - ValueError, - match=r"Invalid `patch_input_shape` value \[ 512 -512\].", - ): - reuse_init_wsi(patch_input_shape=[512, -512], stride_shape=[512, 512]) - with pytest.raises( - ValueError, - match=r"Invalid `stride_shape` value \[ 512 -512\].", - ): - reuse_init_wsi(patch_input_shape=[512, 512], stride_shape=[512, -512]) - - # * for wsi - # dummy test for analysing the output - # stride and patch size should be as expected - patch_size = [512, 512] - stride_size = [256, 256] - ds = reuse_init_wsi( - patch_input_shape=patch_size, - stride_shape=stride_size, - resolution=1.0, - units="mpp", - auto_get_mask=False, - ) - reader = WSIReader.open(mini_wsi_svs) - # tiling top to bottom, left to right - ds_roi = ds[2]["image"] - step_idx = 2 # manually calibrate - start = (step_idx * stride_size[1], 0) - end = (start[0] + patch_size[0], start[1] + patch_size[1]) - rd_roi = reader.read_bounds( - start + end, - resolution=1.0, - units="mpp", - coord_space="resolution", - ) - correlation = np.corrcoef( - cv2.cvtColor(ds_roi, cv2.COLOR_RGB2GRAY).flatten(), - cv2.cvtColor(rd_roi, cv2.COLOR_RGB2GRAY).flatten(), - ) - assert ds_roi.shape[0] == 
rd_roi.shape[0] - assert ds_roi.shape[1] == rd_roi.shape[1] - assert np.min(correlation) > 0.9, correlation - - # test creation with auto mask gen and input mask - ds = reuse_init_wsi( - patch_input_shape=patch_size, - stride_shape=stride_size, - resolution=1.0, - units="mpp", - auto_get_mask=True, - ) - assert len(ds) > 0 - ds = WSIPatchDataset( - img_path=mini_wsi_svs, - mask_path=mini_wsi_msk, - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], - auto_get_mask=False, - resolution=1.0, - units="mpp", - ) - negative_mask = imread(mini_wsi_msk) - negative_mask = np.zeros_like(negative_mask) - negative_mask_path = tmp_path / "negative_mask.png" - imwrite(negative_mask_path, negative_mask) - with pytest.raises(ValueError, match="No patch coordinates remain after filtering"): - ds = WSIPatchDataset( - img_path=mini_wsi_svs, - mask_path=negative_mask_path, - mode="wsi", - patch_input_shape=[512, 512], - stride_shape=[256, 256], - auto_get_mask=False, - resolution=1.0, - units="mpp", - ) - - # * for tile - reader = WSIReader.open(mini_wsi_jpg) - tile_ds = WSIPatchDataset( - img_path=mini_wsi_jpg, - mode="tile", - patch_input_shape=patch_size, - stride_shape=stride_size, - auto_get_mask=False, - ) - step_idx = 3 # manually calibrate - start = (step_idx * stride_size[1], 0) - end = (start[0] + patch_size[0], start[1] + patch_size[1]) - roi2 = reader.read_bounds( - start + end, - resolution=1.0, - units="baseline", - coord_space="resolution", - ) - roi1 = tile_ds[3]["image"] # match with step_index - correlation = np.corrcoef( - cv2.cvtColor(roi1, cv2.COLOR_RGB2GRAY).flatten(), - cv2.cvtColor(roi2, cv2.COLOR_RGB2GRAY).flatten(), - ) - assert roi1.shape[0] == roi2.shape[0] - assert roi1.shape[1] == roi2.shape[1] - assert np.min(correlation) > 0.9, correlation - - -def test_patch_dataset_abc() -> None: - """Test for ABC methods. - - Test missing definition for abstract intentionally created to check error. - - """ - - # skipcq - class Proto(PatchDatasetABC): - # skipcq - def __init__(self: Proto) -> None: - super().__init__() - - # crash due to undefined __getitem__ - with pytest.raises(TypeError): - Proto() # skipcq - - # skipcq - class Proto(PatchDatasetABC): - # skipcq - def __init__(self: Proto) -> None: - super().__init__() - - # skipcq - def __getitem__(self: Proto, idx: int) -> None: - """Get an item from the dataset.""" - - ds = Proto() # skipcq - - # test setter and getter - assert ds.preproc_func(1) == 1 - ds.preproc_func = lambda x: x - 1 # skipcq: PYL-W0201 - assert ds.preproc_func(1) == 0 - assert ds.preproc(1) == 1, "Must be unchanged!" 
- ds.preproc_func = None # skipcq: PYL-W0201 - assert ds.preproc_func(2) == 2 - - # test assign uncallable to preproc_func/postproc_func - with pytest.raises(ValueError, match=r".*callable*"): - ds.preproc_func = 1 # skipcq: PYL-W0201 - - -# ------------------------------------------------------------------------------------- -# Dataloader -# ------------------------------------------------------------------------------------- - - -def test_io_patch_predictor_config() -> None: - """Test for IOConfig.""" - # test for creating - cfg = IOPatchPredictorConfig( - patch_input_shape=[224, 224], - stride_shape=[224, 224], - input_resolutions=[{"resolution": 0.5, "units": "mpp"}], - # test adding random kwarg and they should be accessible as kwargs - crop_from_source=True, - ) - assert cfg.crop_from_source - # ------------------------------------------------------------------------------------- # Engine # ------------------------------------------------------------------------------------- -def test_predictor_crash(tmp_path: Path) -> None: - """Test for crash when making predictor.""" - # without providing any model - with pytest.raises(ValueError, match=r"Must provide.*"): - PatchPredictor() - - # provide wrong unknown pretrained model - with pytest.raises(ValueError, match=r"Pretrained .* does not exist"): - PatchPredictor(pretrained_model="secret_model-kather100k") - - # provide wrong model of unknown type, deprecated later with type hint - with pytest.raises(TypeError, match=r".*must be a string.*"): - PatchPredictor(pretrained_model=123) - - # test predict crash - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=32) - - with pytest.raises(ValueError, match=r".*not a valid mode.*"): - predictor.predict("aaa", mode="random", save_dir=tmp_path) - # remove previously generated data - shutil.rmtree(tmp_path / "output", ignore_errors=True) - with pytest.raises(TypeError, match=r".*must be a list of file paths.*"): - predictor.predict("aaa", mode="wsi", save_dir=tmp_path) - # remove previously generated data - shutil.rmtree(tmp_path / "output", ignore_errors=True) - with pytest.raises(ValueError, match=r".*masks.*!=.*imgs.*"): - predictor.predict([1, 2, 3], masks=[1, 2], mode="wsi", save_dir=tmp_path) - with pytest.raises(ValueError, match=r".*labels.*!=.*imgs.*"): - predictor.predict([1, 2, 3], labels=[1, 2], mode="patch", save_dir=tmp_path) - # remove previously generated data - shutil.rmtree(tmp_path / "output", ignore_errors=True) - - def test_io_config_delegation(remote_sample: Callable, tmp_path: Path) -> None: """Test for delegating args to io config.""" mini_wsi_svs = Path(remote_sample("wsi2_4k_4k_svs")) - - # test not providing config / full input info for not pretrained models model = CNNModel("resnet50") - predictor = PatchPredictor(model=model) - with pytest.raises(ValueError, match=r".*Must provide.*`ioconfig`.*"): - predictor.predict([mini_wsi_svs], mode="wsi", save_dir=tmp_path / "dump") - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - + predictor = PatchPredictor(model=model, weights=None) kwargs = { "patch_input_shape": [512, 512], "resolution": 1.75, "units": "mpp", } - for key in kwargs: - _kwargs = copy.deepcopy(kwargs) - _kwargs.pop(key) - with pytest.raises(ValueError, match=r".*Must provide.*`ioconfig`.*"): - predictor.predict( - [mini_wsi_svs], - mode="wsi", - save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, - **_kwargs, - ) - shutil.rmtree(tmp_path / "dump", ignore_errors=True) - - # test providing config / full input info for not pretrained 
models + + # test providing config / full input info for default models without weights ioconfig = IOPatchPredictorConfig( patch_input_shape=(512, 512), stride_shape=(256, 256), input_resolutions=[{"resolution": 1.35, "units": "mpp"}], ) - predictor.predict( - [mini_wsi_svs], + predictor.run( + images=[mini_wsi_svs], ioconfig=ioconfig, - mode="wsi", + patch_mode=False, save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, ) shutil.rmtree(tmp_path / "dump", ignore_errors=True) - predictor.predict( - [mini_wsi_svs], - mode="wsi", + predictor.run( + images=[mini_wsi_svs], + patch_mode=False, save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, **kwargs, ) shutil.rmtree(tmp_path / "dump", ignore_errors=True) # test overwriting pretrained ioconfig - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) - predictor.predict( - [mini_wsi_svs], + predictor = PatchPredictor(model="resnet18-kather100k", batch_size=1) + predictor.run( + images=[mini_wsi_svs], patch_input_shape=(300, 300), - mode="wsi", - on_gpu=ON_GPU, + patch_mode=False, save_dir=f"{tmp_path}/dump", ) assert predictor._ioconfig.patch_input_shape == (300, 300) shutil.rmtree(tmp_path / "dump", ignore_errors=True) - predictor.predict( - [mini_wsi_svs], + predictor.run( + images=[mini_wsi_svs], stride_shape=(300, 300), - mode="wsi", - on_gpu=ON_GPU, + patch_mode=False, save_dir=f"{tmp_path}/dump", ) assert predictor._ioconfig.stride_shape == (300, 300) shutil.rmtree(tmp_path / "dump", ignore_errors=True) - predictor.predict( - [mini_wsi_svs], + predictor.run( + images=[mini_wsi_svs], resolution=1.99, - mode="wsi", - on_gpu=ON_GPU, + patch_mode=False, save_dir=f"{tmp_path}/dump", ) assert predictor._ioconfig.input_resolutions[0]["resolution"] == 1.99 shutil.rmtree(tmp_path / "dump", ignore_errors=True) - predictor.predict( - [mini_wsi_svs], + predictor.run( + images=[mini_wsi_svs], units="baseline", - mode="wsi", - on_gpu=ON_GPU, + patch_mode=False, save_dir=f"{tmp_path}/dump", ) assert predictor._ioconfig.input_resolutions[0]["units"] == "baseline" shutil.rmtree(tmp_path / "dump", ignore_errors=True) - predictor = PatchPredictor(pretrained_model="resnet18-kather100k") - predictor.predict( - [mini_wsi_svs], - mode="wsi", - merge_predictions=True, + predictor.run( + images=[mini_wsi_svs], + units="level", + resolution=0, + patch_mode=False, + save_dir=f"{tmp_path}/dump", + ) + assert predictor._ioconfig.input_resolutions[0]["units"] == "level" + assert predictor._ioconfig.input_resolutions[0]["resolution"] == 0 + shutil.rmtree(tmp_path / "dump", ignore_errors=True) + + predictor.run( + images=[mini_wsi_svs], + units="power", + resolution=20, + patch_mode=False, save_dir=f"{tmp_path}/dump", - on_gpu=ON_GPU, ) + assert predictor._ioconfig.input_resolutions[0]["units"] == "power" + assert predictor._ioconfig.input_resolutions[0]["resolution"] == 20 shutil.rmtree(tmp_path / "dump", ignore_errors=True) @@ -638,59 +134,28 @@ def test_patch_predictor_api( # convert to pathlib Path to prevent reader complaint inputs = [Path(sample_patch1), Path(sample_patch2)] - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) + predictor = PatchPredictor(model="resnet18-kather100k", batch_size=1) # don't run test on GPU - output = predictor.predict( + # Default run + output = predictor.run( inputs, - on_gpu=ON_GPU, - save_dir=save_dir_path, + device="cpu", ) - assert sorted(output.keys()) == ["predictions"] - assert len(output["predictions"]) == 2 + assert sorted(output.keys()) == ["probabilities"] + assert 
len(output["probabilities"]) == 2 shutil.rmtree(save_dir_path, ignore_errors=True) - output = predictor.predict( + # whether to return labels + output = predictor.run( inputs, - labels=[1, "a"], + labels=["1", "a"], return_labels=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, ) - assert sorted(output.keys()) == sorted(["labels", "predictions"]) - assert len(output["predictions"]) == len(output["labels"]) - assert output["labels"] == [1, "a"] + assert sorted(output.keys()) == sorted(["labels", "probabilities"]) + assert len(output["probabilities"]) == len(output["labels"]) + assert output["labels"].tolist() == ["1", "a"] shutil.rmtree(save_dir_path, ignore_errors=True) - output = predictor.predict( - inputs, - return_probabilities=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["predictions", "probabilities"]) - assert len(output["predictions"]) == len(output["probabilities"]) - shutil.rmtree(save_dir_path, ignore_errors=True) - - output = predictor.predict( - inputs, - return_probabilities=True, - labels=[1, "a"], - return_labels=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, - ) - assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) - assert len(output["predictions"]) == len(output["labels"]) - assert len(output["predictions"]) == len(output["probabilities"]) - - # test saving output, should have no effect - _ = predictor.predict( - inputs, - on_gpu=ON_GPU, - save_dir="special_dir_not_exist", - ) - assert not Path.is_dir(Path("special_dir_not_exist")) - # test loading user weight pretrained_weights_url = ( "https://tiatoolbox.dcs.warwick.ac.uk/models/pc/resnet18-kather100k.pth" @@ -705,33 +170,31 @@ def test_patch_predictor_api( download_data(pretrained_weights_url, pretrained_weights) - _ = PatchPredictor( - pretrained_model="resnet18-kather100k", - pretrained_weights=pretrained_weights, + predictor = PatchPredictor( + model="resnet18-kather100k", + weights=pretrained_weights, batch_size=1, ) + ioconfig = predictor.ioconfig # --- test different using user model model = CNNModel(backbone="resnet18", num_classes=9) # test prediction predictor = PatchPredictor(model=model, batch_size=1, verbose=False) - output = predictor.predict( + output = predictor.run( inputs, - return_probabilities=True, - labels=[1, "a"], + labels=[1, 2], return_labels=True, - on_gpu=ON_GPU, - save_dir=save_dir_path, + ioconfig=ioconfig, ) - assert sorted(output.keys()) == sorted(["labels", "predictions", "probabilities"]) - assert len(output["predictions"]) == len(output["labels"]) - assert len(output["predictions"]) == len(output["probabilities"]) + assert sorted(output.keys()) == sorted(["labels", "probabilities"]) + assert len(output["probabilities"]) == len(output["labels"]) + assert output["labels"].tolist() == [1, 2] def test_wsi_predictor_api( sample_wsi_dict: dict, tmp_path: Path, - chdir: Callable, ) -> None: """Test normal run of wsi predictor.""" save_dir_path = tmp_path @@ -742,15 +205,12 @@ def test_wsi_predictor_api( mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) patch_size = np.array([224, 224]) - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=32) + predictor = PatchPredictor(model="resnet18-kather100k", batch_size=32) save_dir = f"{save_dir_path}/model_wsi_output" # wrapper to make this more clean kwargs = { - "return_probabilities": True, - "return_labels": True, - "on_gpu": ON_GPU, "patch_input_shape": patch_size, "stride_shape": patch_size, "resolution": 1.0, @@ -759,236 +219,60 
@@ def test_wsi_predictor_api( } # ! add this test back once the read at `baseline` is fixed # sanity check, both output should be the same with same resolution read args - wsi_output = predictor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - **kwargs, - ) - - shutil.rmtree(save_dir, ignore_errors=True) - - tile_output = predictor.predict( - [mini_wsi_jpg], - masks=[mini_wsi_msk], - mode="tile", - **kwargs, - ) - - wpred = np.array(wsi_output[0]["predictions"]) - tpred = np.array(tile_output[0]["predictions"]) - diff = tpred == wpred - accuracy = np.sum(diff) / np.size(wpred) - assert accuracy > 0.9, np.nonzero(~diff) - # remove previously generated data - shutil.rmtree(save_dir, ignore_errors=True) - - kwargs = { - "return_probabilities": True, - "return_labels": True, - "on_gpu": ON_GPU, - "patch_input_shape": patch_size, - "stride_shape": patch_size, - "resolution": 0.5, - "save_dir": save_dir, - "merge_predictions": True, # to test the api coverage - "units": "mpp", - } - - _kwargs = copy.deepcopy(kwargs) - _kwargs["merge_predictions"] = False - # test reading of multiple whole-slide images - output = predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - for output_info in output.values(): - assert Path(output_info["raw"]).exists() - assert "merged" not in output_info - shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) - # coverage test _kwargs = copy.deepcopy(kwargs) - _kwargs["merge_predictions"] = True # test reading of multiple whole-slide images - predictor.predict( - [mini_wsi_svs, mini_wsi_svs], + output = predictor.run( + images=[mini_wsi_svs, mini_wsi_jpg], masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", + patch_mode=False, **_kwargs, ) - _kwargs = copy.deepcopy(kwargs) - with pytest.raises(FileExistsError): - predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - # remove previously generated data - shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) - - with chdir(save_dir_path): - # test reading of multiple whole-slide images - _kwargs = copy.deepcopy(kwargs) - _kwargs["save_dir"] = None # default coverage - _kwargs["return_probabilities"] = False - output = predictor.predict( - [mini_wsi_svs, mini_wsi_svs], - masks=[mini_wsi_msk, mini_wsi_msk], - mode="wsi", - **_kwargs, - ) - assert Path.exists(Path("output")) - for output_info in output.values(): - assert Path(output_info["raw"]).exists() - assert "merged" in output_info - assert Path(output_info["merged"]).exists() - - # remove previously generated data - shutil.rmtree("output", ignore_errors=True) - - -def test_wsi_predictor_merge_predictions(sample_wsi_dict: dict) -> None: - """Test normal run of wsi predictor with merge predictions option.""" - # convert to pathlib Path to prevent reader complaint - mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) - mini_wsi_jpg = Path(sample_wsi_dict["wsi2_4k_4k_jpg"]) - mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) - - # blind test - # pseudo output dict from model with 2 patches - output = { - "resolution": 1.0, - "units": "baseline", - "probabilities": [[0.45, 0.55], [0.90, 0.10]], - "predictions": [1, 0], - "coordinates": [[0, 0, 2, 2], [2, 2, 4, 4]], - } - merged = PatchPredictor.merge_predictions( - np.zeros([4, 4]), - output, - resolution=1.0, - units="baseline", - ) - _merged = np.array([[2, 2, 0, 0], [2, 2, 0, 0], [0, 0, 1, 1], [0, 0, 1, 1]]) - assert np.sum(merged - _merged) == 0 - # blind test 
for merging probabilities - merged = PatchPredictor.merge_predictions( - np.zeros([4, 4]), - output, - resolution=1.0, - units="baseline", - return_raw=True, - ) - _merged = np.array( - [ - [0.45, 0.45, 0, 0], - [0.45, 0.45, 0, 0], - [0, 0, 0.90, 0.90], - [0, 0, 0.90, 0.90], - ], - ) - assert merged.shape == (4, 4, 2) - assert np.mean(np.abs(merged[..., 0] - _merged)) < 1.0e-6 + wsi_pred = zarr.open(str(output[mini_wsi_svs]), mode="r") + tile_pred = zarr.open(str(output[mini_wsi_jpg]), mode="r") + diff = tile_pred["probabilities"][:] == wsi_pred["probabilities"][:] + accuracy = np.sum(diff) / np.size(wsi_pred["probabilities"][:]) + assert accuracy > 0.99, np.nonzero(~diff) - # integration test - predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=1) - - kwargs = { - "return_probabilities": True, - "return_labels": True, - "on_gpu": ON_GPU, - "patch_input_shape": np.array([224, 224]), - "stride_shape": np.array([224, 224]), - "resolution": 1.0, - "units": "baseline", - "merge_predictions": True, - } - # sanity check, both output should be the same with same resolution read args - wsi_output = predictor.predict( - [mini_wsi_svs], - masks=[mini_wsi_msk], - mode="wsi", - **kwargs, - ) - - # mock up to change the preproc func and - # force to use the default in merge function - # still should have the same results - kwargs["merge_predictions"] = False - tile_output = predictor.predict( - [mini_wsi_jpg], - masks=[mini_wsi_msk], - mode="tile", - **kwargs, - ) - merged_tile_output = predictor.merge_predictions( - mini_wsi_jpg, - tile_output[0], - resolution=kwargs["resolution"], - units=kwargs["units"], - ) - tile_output.append(merged_tile_output) - - # first make sure nothing breaks with predictions - wpred = np.array(wsi_output[0]["predictions"]) - tpred = np.array(tile_output[0]["predictions"]) - diff = tpred == wpred - accuracy = np.sum(diff) / np.size(wpred) - assert accuracy > 0.9, np.nonzero(~diff) - - merged_wsi = wsi_output[1] - merged_tile = tile_output[1] - # ensure shape of merged predictions of tile and wsi input are the same - assert merged_wsi.shape == merged_tile.shape - # ensure consistent predictions between tile and wsi mode - diff = merged_tile == merged_wsi - accuracy = np.sum(diff) / np.size(merged_wsi) - assert accuracy > 0.9, np.nonzero(~diff) + shutil.rmtree(_kwargs["save_dir"], ignore_errors=True) def _test_predictor_output( inputs: list, - pretrained_model: str, + model: str, probabilities_check: list | None = None, predictions_check: list | None = None, - *, - on_gpu: bool = ON_GPU, ) -> None: """Test the predictions of multiple models included in tiatoolbox.""" predictor = PatchPredictor( - pretrained_model=pretrained_model, + model=model, batch_size=32, verbose=False, ) # don't run test on GPU - output = predictor.predict( + output = predictor.run( inputs, return_probabilities=True, return_labels=False, - on_gpu=on_gpu, + device=device, ) - predictions = output["predictions"] - probabilities = output["probabilities"] - for idx, probabilities_ in enumerate(probabilities): + predictions = output["probabilities"] + for idx, probabilities_ in enumerate(predictions): probabilities_max = max(probabilities_) assert np.abs(probabilities_max - probabilities_check[idx]) <= 1e-3, ( - pretrained_model, + model, probabilities_max, probabilities_check[idx], - predictions[idx], + probabilities_, predictions_check[idx], ) - assert predictions[idx] == predictions_check[idx], ( - pretrained_model, + assert np.argmax(probabilities_) == predictions_check[idx], ( 
+ model, probabilities_max, probabilities_check[idx], - predictions[idx], + probabilities_, predictions_check[idx], ) @@ -1018,52 +302,188 @@ def test_patch_predictor_kather100k_output( "mobilenet_v3_small-kather100k": [0.9999998807907104, 0.9999997615814209], "googlenet-kather100k": [1.0, 0.9999639987945557], } - for pretrained_model, expected_prob in pretrained_info.items(): + for model, expected_prob in pretrained_info.items(): _test_predictor_output( inputs, - pretrained_model, + model, probabilities_check=expected_prob, predictions_check=[6, 3], - on_gpu=ON_GPU, ) # only test 1 on travis to limit runtime if toolbox_env.running_on_ci(): break -def test_patch_predictor_pcam_output(sample_patch3: Path, sample_patch4: Path) -> None: - """Test the output of patch prediction models on PCam dataset.""" - inputs = [Path(sample_patch3), Path(sample_patch4)] - pretrained_info = { - "alexnet-pcam": [0.999980092048645, 0.9769067168235779], - "resnet18-pcam": [0.999992847442627, 0.9466130137443542], - "resnet34-pcam": [1.0, 0.9976525902748108], - "resnet50-pcam": [0.9999270439147949, 0.9999996423721313], - "resnet101-pcam": [1.0, 0.9997289776802063], - "resnext50_32x4d-pcam": [0.9999996423721313, 0.9984435439109802], - "resnext101_32x8d-pcam": [0.9997072815895081, 0.9969086050987244], - "wide_resnet50_2-pcam": [0.9999837875366211, 0.9959040284156799], - "wide_resnet101_2-pcam": [1.0, 0.9945427179336548], - "densenet121-pcam": [0.9999251365661621, 0.9997479319572449], - "densenet161-pcam": [0.9999969005584717, 0.9662821292877197], - "densenet169-pcam": [0.9999998807907104, 0.9993504881858826], - "densenet201-pcam": [0.9999942779541016, 0.9950824975967407], - "mobilenet_v2-pcam": [0.9999876022338867, 0.9942564368247986], - "mobilenet_v3_large-pcam": [0.9999922513961792, 0.9719613790512085], - "mobilenet_v3_small-pcam": [0.9999963045120239, 0.9747149348258972], - "googlenet-pcam": [0.9999929666519165, 0.8701475858688354], +def _validate_probabilities(predictions: list | dict) -> bool: + """Helper function to test if the probabilities value are valid.""" + if isinstance(predictions, dict): + return all(0 <= probability <= 1 for _, probability in predictions.items()) + + for row in predictions: + for element in row: + if not (0 <= element <= 1): + return False + return True + + +def test_wsi_predictor_zarr(sample_wsi_dict: dict, tmp_path: Path) -> None: + """Test normal run of patch predictor for WSIs.""" + mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) + + predictor = PatchPredictor( + model="alexnet-kather100k", + batch_size=32, + verbose=False, + ) + # don't run test on GPU + output = predictor.run( + images=[mini_wsi_svs], + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=False, + save_dir=tmp_path / "wsi_out_check", + ) + + assert output[mini_wsi_svs].exists() + + output_ = zarr.open(output[mini_wsi_svs]) + + assert output_["probabilities"].shape == (70, 9) # number of patches x classes + assert output_["probabilities"].ndim == 2 + # number of patches x [start_x, start_y, end_x, end_y] + assert output_["coordinates"].shape == (70, 4) + assert output_["coordinates"].ndim == 2 + assert _validate_probabilities(predictions=output_["probabilities"]) + + +def test_wsi_predictor_zarr_baseline(sample_wsi_dict: dict, tmp_path: Path) -> None: + """Test normal run of patch predictor for WSIs.""" + mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) + + predictor = PatchPredictor( + model="alexnet-kather100k", + batch_size=32, + verbose=False, + ) + # don't run test 
on GPU + output = predictor.run( + images=[mini_wsi_svs], + return_probabilities=True, + return_labels=False, + device=device, + patch_mode=False, + save_dir=tmp_path / "wsi_out_check", + units="baseline", + resolution=1.0, + ) + + assert output[mini_wsi_svs].exists() + + output_ = zarr.open(output[mini_wsi_svs]) + + assert output_["probabilities"].shape == (244, 9) # number of patches x classes + assert output_["probabilities"].ndim == 2 + # number of patches x [start_x, start_y, end_x, end_y] + assert output_["coordinates"].shape == (244, 4) + assert output_["coordinates"].ndim == 2 + assert _validate_probabilities(predictions=output_["probabilities"]) + + +def _extract_probabilities_from_annotation_store(dbfile: str) -> dict: + """Helper function to extract probabilities from Annotation Store.""" + probs_dict = {} + con = sqlite3.connect(dbfile) + cur = con.cursor() + annotations_properties = list(cur.execute("SELECT properties FROM annotations")) + + for item in annotations_properties: + for json_str in item: + probs_dict = json.loads(json_str) + probs_dict.pop("prob_0") + + return probs_dict + + +def test_engine_run_wsi_annotation_store( + sample_wsi_dict: dict, + tmp_path: Path, +) -> None: + """Test the engine run for Whole slide images.""" + # convert to pathlib Path to prevent wsireader complaint + mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) + mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) + + eng = PatchPredictor(model="alexnet-kather100k") + + patch_size = np.array([224, 224]) + save_dir = f"{tmp_path}/model_wsi_output" + + kwargs = { + "patch_input_shape": patch_size, + "stride_shape": patch_size, + "resolution": 0.5, + "save_dir": save_dir, + "units": "mpp", + "scale_factor": (2.0, 2.0), } - for pretrained_model, expected_prob in pretrained_info.items(): - _test_predictor_output( - inputs, - pretrained_model, - probabilities_check=expected_prob, - predictions_check=[1, 0], - on_gpu=ON_GPU, - ) - # only test 1 on travis to limit runtime - if toolbox_env.running_on_ci(): - break + + output = eng.run( + images=[mini_wsi_svs], + masks=[mini_wsi_msk], + patch_mode=False, + output_type="AnnotationStore", + **kwargs, + ) + + output_ = output[mini_wsi_svs] + + assert output_.exists() + assert output_.suffix == ".db" + predictions = _extract_probabilities_from_annotation_store(output_) + assert _validate_probabilities(predictions) + + shutil.rmtree(save_dir) + + +def test_engine_run_wsi_annotation_store_power( + sample_wsi_dict: dict, + tmp_path: Path, +) -> None: + """Test the engine run for Whole slide images.""" + # convert to pathlib Path to prevent wsireader complaint + mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"]) + mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"]) + + eng = PatchPredictor(model="alexnet-kather100k") + + patch_size = np.array([224, 224]) + save_dir = f"{tmp_path}/model_wsi_output" + + kwargs = { + "patch_input_shape": patch_size, + "stride_shape": patch_size, + "resolution": 20, + "save_dir": save_dir, + "units": "power", + } + + output = eng.run( + images=[mini_wsi_svs], + masks=[mini_wsi_msk], + patch_mode=False, + output_type="AnnotationStore", + **kwargs, + ) + + output_ = output[mini_wsi_svs] + + assert output_.exists() + assert output_.suffix == ".db" + predictions = _extract_probabilities_from_annotation_store(output_) + assert _validate_probabilities(predictions) + + shutil.rmtree(save_dir) # ------------------------------------------------------------------------------------- @@ -1103,14 +523,14 @@ def 
test_command_line_models_incorrect_mode(sample_svs: Path, tmp_path: Path) -> str(sample_svs), "--file-types", '"*.ndpi, *.svs"', - "--mode", + "--patch-mode", '"patch"', "--output-path", str(tmp_path.joinpath("output")), ], ) - assert "Invalid value for '--mode'" in mode_not_in_wsi_tile_result.output + assert "Invalid value for '--patch-mode'" in mode_not_in_wsi_tile_result.output assert mode_not_in_wsi_tile_result.exit_code != 0 assert isinstance(mode_not_in_wsi_tile_result.exception, SystemExit) @@ -1124,47 +544,15 @@ def test_cli_model_single_file(sample_svs: Path, tmp_path: Path) -> None: "patch-predictor", "--img-input", str(sample_svs), - "--mode", - "wsi", + "--patch-mode", + "False", "--output-path", - str(tmp_path.joinpath("output")), + str(tmp_path / "output"), ], ) assert models_wsi_result.exit_code == 0 - assert tmp_path.joinpath("output/0.merged.npy").exists() - assert tmp_path.joinpath("output/0.raw.json").exists() - assert tmp_path.joinpath("output/results.json").exists() - - -def test_cli_model_single_file_mask(remote_sample: Callable, tmp_path: Path) -> None: - """Test for models CLI single file with mask.""" - mini_wsi_svs = Path(remote_sample("svs-1-small")) - sample_wsi_msk = remote_sample("small_svs_tissue_mask") - sample_wsi_msk = np.load(sample_wsi_msk).astype(np.uint8) - imwrite(f"{tmp_path}/small_svs_tissue_mask.jpg", sample_wsi_msk) - sample_wsi_msk = f"{tmp_path}/small_svs_tissue_mask.jpg" - - runner = CliRunner() - models_tiles_result = runner.invoke( - cli.main, - [ - "patch-predictor", - "--img-input", - str(mini_wsi_svs), - "--mode", - "wsi", - "--masks", - str(sample_wsi_msk), - "--output-path", - str(tmp_path.joinpath("output")), - ], - ) - - assert models_tiles_result.exit_code == 0 - assert tmp_path.joinpath("output/0.merged.npy").exists() - assert tmp_path.joinpath("output/0.raw.json").exists() - assert tmp_path.joinpath("output/results.json").exists() + assert (tmp_path / "output" / (sample_svs.stem + ".db")).exists() def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -> None: @@ -1187,20 +575,18 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) - dir_path.joinpath("2_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) dir_path.joinpath("3_" + mini_wsi_svs.name).symlink_to(mini_wsi_svs) except OSError: - shutil.copy(mini_wsi_svs, dir_path.joinpath("1_" + mini_wsi_svs.name)) - shutil.copy(mini_wsi_svs, dir_path.joinpath("2_" + mini_wsi_svs.name)) - shutil.copy(mini_wsi_svs, dir_path.joinpath("3_" + mini_wsi_svs.name)) + shutil.copy(mini_wsi_svs, dir_path / ("1_" + mini_wsi_svs.name)) + shutil.copy(mini_wsi_svs, dir_path / ("2_" + mini_wsi_svs.name)) + shutil.copy(mini_wsi_svs, dir_path / ("3_" + mini_wsi_svs.name)) try: dir_path_masks.joinpath("1_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) dir_path_masks.joinpath("2_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) dir_path_masks.joinpath("3_" + mini_wsi_msk.name).symlink_to(mini_wsi_msk) except OSError: - shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("1_" + mini_wsi_msk.name)) - shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("2_" + mini_wsi_msk.name)) - shutil.copy(mini_wsi_msk, dir_path_masks.joinpath("3_" + mini_wsi_msk.name)) - - tmp_path = tmp_path.joinpath("output") + shutil.copy(mini_wsi_msk, dir_path_masks / ("1_" + mini_wsi_msk.name)) + shutil.copy(mini_wsi_msk, dir_path_masks / ("2_" + mini_wsi_msk.name)) + shutil.copy(mini_wsi_msk, dir_path_masks / ("3_" + mini_wsi_msk.name)) runner = CliRunner() models_tiles_result = 
runner.invoke( @@ -1209,20 +595,18 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) - "patch-predictor", "--img-input", str(dir_path), - "--mode", - "wsi", + "--patch-mode", + str(False), "--masks", str(dir_path_masks), "--output-path", - str(tmp_path), + str(tmp_path / "output"), + "--output-type", + "zarr", ], ) assert models_tiles_result.exit_code == 0 - assert tmp_path.joinpath("0.merged.npy").exists() - assert tmp_path.joinpath("0.raw.json").exists() - assert tmp_path.joinpath("1.merged.npy").exists() - assert tmp_path.joinpath("1.raw.json").exists() - assert tmp_path.joinpath("2.merged.npy").exists() - assert tmp_path.joinpath("2.raw.json").exists() - assert tmp_path.joinpath("results.json").exists() + assert (tmp_path / "output" / ("1_" + mini_wsi_svs.stem + ".zarr")).exists() + assert (tmp_path / "output" / ("2_" + mini_wsi_svs.stem + ".zarr")).exists() + assert (tmp_path / "output" / ("3_" + mini_wsi_svs.stem + ".zarr")).exists() diff --git a/tests/models/test_dataset.py b/tests/models/test_dataset.py index de5b726a7..ab9a6033f 100644 --- a/tests/models/test_dataset.py +++ b/tests/models/test_dataset.py @@ -309,7 +309,7 @@ def test_patch_dataset_crash(tmp_path: Path) -> None: save_dir_path / "sample2.npy", ] with pytest.raises( - ValueError, + TypeError, match="Cannot load image data from", ): _ = PatchDataset(imgs) diff --git a/tests/test_annotation_stores.py b/tests/test_annotation_stores.py index 01bbdac45..66c990161 100644 --- a/tests/test_annotation_stores.py +++ b/tests/test_annotation_stores.py @@ -53,14 +53,6 @@ FILLED_LEN = 2 * (GRID_SIZE[0] * GRID_SIZE[1]) RNG = np.random.default_rng(0) # Numpy Random Generator -# ---------------------------------------------------------------------- -# Resets -# ---------------------------------------------------------------------- - -# Reset filters in logger. -for filter_ in logger.filters: - logger.removeFilter(filter_) - # ---------------------------------------------------------------------- # Helper Functions # ---------------------------------------------------------------------- @@ -546,6 +538,9 @@ def test_sqlite_store_compile_options_missing_math( caplog: pytest.LogCaptureFixture, ) -> None: """Test that a warning is shown if the sqlite math module is missing.""" + # Reset filters in logger. + for filter_ in logger.filters[:]: + logger.removeFilter(filter_) monkeypatch.setattr( SQLiteStore, "compile_options", diff --git a/tests/test_init.py b/tests/test_init.py index 509a9c49f..6d8ed8238 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -114,7 +114,7 @@ def test_duplicate_filter(caplog: pytest.LogCaptureFixture) -> None: logger.addFilter(duplicate_filter) # Reset filters in logger. 
- for filter_ in logger.filters: + for filter_ in logger.filters[:]: logger.removeFilter(filter_) for _ in range(2): diff --git a/tests/test_utils.py b/tests/test_utils.py index ad8e0e3da..0b1f3f484 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1646,6 +1646,7 @@ def test_patch_pred_store() -> None: """Test patch_pred_store.""" # Define a mock patch_output patch_output = { + "probabilities": [(0.99, 0.01), (0.01, 0.99), (0.99, 0.01)], "predictions": [1, 0, 1], "coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)], "other": "other", @@ -1680,7 +1681,7 @@ def test_patch_pred_store_cdict() -> None: class_dict = {0: "class0", 1: "class1"} store = misc.dict_to_store(patch_output, (1.0, 1.0), class_dict=class_dict) - # Check that its an SQLiteStore containing the expected annotations + # Check that it is an SQLiteStore containing the expected annotations assert isinstance(store, SQLiteStore) assert len(store) == 3 for annotation in store.values(): diff --git a/tiatoolbox/cli/common.py b/tiatoolbox/cli/common.py index 81ba7b5f4..6b4a23a5a 100644 --- a/tiatoolbox/cli/common.py +++ b/tiatoolbox/cli/common.py @@ -86,6 +86,24 @@ def cli_file_type( ) +def cli_output_type( + usage_help: str = "The format of the output type. " + "'output_type' can be 'zarr' or 'AnnotationStore'. " + "Default value is 'AnnotationStore'.", + default: str = "AnnotationStore", + input_type: click.Choice | None = None, +) -> callable: + """Enables --output-type option for cli.""" + if input_type is None: + input_type = click.Choice(["zarr", "AnnotationStore"], case_sensitive=False) + return click.option( + "--output-type", + help=add_default_to_usage_help(usage_help, default), + default=default, + type=input_type, + ) + + def cli_mode( usage_help: str = "Selected mode to show or save the required information.", default: str = "save", @@ -102,6 +120,20 @@ def cli_mode( ) + +def cli_patch_mode( + usage_help: str = "Whether to run the model in patch mode or WSI mode.", + *, + default: bool = False, +) -> callable: + """Enables --patch-mode option for cli.""" + return click.option( + "--patch-mode", + type=bool, + help=add_default_to_usage_help(usage_help, default), + default=default, + ) + + def cli_region( usage_help: str = "Image region in the whole slide image to read from. " "default=0 0 2000 2000", @@ -215,7 +247,7 @@ def cli_pretrained_model( ) -> callable: """Enables --pretrained-model option for cli.""" return click.option( - "--pretrained-model", + "--model", help=add_default_to_usage_help(usage_help, default), default=default, ) @@ -234,6 +266,51 @@ def cli_pretrained_weights( ) + + +def cli_model( + usage_help: str = "Name of the predefined model used to process the data. " + "The format is <model_name>-<dataset_name>. For example, " + "`resnet18-kather100K` is a resnet18 model trained on the Kather dataset. " + "Please see " + "https://tia-toolbox.readthedocs.io/en/latest/usage.html#deep-learning-models " + "for a detailed list of available pretrained models. " + "By default, the corresponding pretrained weights will also be " + "downloaded. However, you can override with your own set of weights " + "via the `weights` argument. Argument is case insensitive.", + default: str = "resnet18-kather100k", +) -> callable: + """Enables --model option for cli.""" + return click.option( + "--model", + help=add_default_to_usage_help(usage_help, default), + default=default, + ) + + +def cli_weights( + usage_help: str = "Path to the model weight file. 
If not supplied, the default " + "pretrained weight will be used.", + default: str | None = None, +) -> callable: + """Enables --weights option for cli.""" + return click.option( + "--weights", + help=add_default_to_usage_help(usage_help, default), + default=default, + ) + + +def cli_device( + usage_help: str = "Select the device (cpu/cuda/mps) to use for inference.", + default: str = "cpu", +) -> callable: + """Enables --device option for cli.""" + return click.option( + "--device", + help=add_default_to_usage_help(usage_help, default), + default=default, + ) + + def cli_return_probabilities( usage_help: str = "Whether to return raw model probabilities.", *, diff --git a/tiatoolbox/cli/patch_predictor.py b/tiatoolbox/cli/patch_predictor.py index f6cc1b397..263809146 100644 --- a/tiatoolbox/cli/patch_predictor.py +++ b/tiatoolbox/cli/patch_predictor.py @@ -2,25 +2,21 @@ from __future__ import annotations -import click - from tiatoolbox.cli.common import ( cli_batch_size, + cli_device, cli_file_type, cli_img_input, cli_masks, - cli_merge_predictions, - cli_mode, + cli_model, cli_num_loader_workers, - cli_on_gpu, cli_output_path, - cli_pretrained_model, - cli_pretrained_weights, + cli_output_type, + cli_patch_mode, cli_resolution, - cli_return_labels, - cli_return_probabilities, cli_units, cli_verbose, + cli_weights, prepare_model_cli, tiatoolbox_cli, ) @@ -35,45 +31,36 @@ @cli_file_type( default="*.png, *.jpg, *.jpeg, *.tif, *.tiff, *.svs, *.ndpi, *.jp2, *.mrxs", ) -@cli_mode( - usage_help="Type of input file to process.", - default="wsi", - input_type=click.Choice(["patch", "wsi", "tile"], case_sensitive=False), -) -@cli_pretrained_model(default="resnet18-kather100k") -@cli_pretrained_weights() -@cli_return_probabilities(default=False) -@cli_merge_predictions(default=True) -@cli_return_labels(default=True) -@cli_on_gpu(default=False) +@cli_patch_mode(default=False) +@cli_model(default="resnet18-kather100k") +@cli_weights() +@cli_device(default="cpu") @cli_batch_size(default=1) @cli_resolution(default=0.5) @cli_units(default="mpp") @cli_masks(default=None) @cli_num_loader_workers(default=0) +@cli_output_type(default="AnnotationStore") @cli_verbose(default=True) def patch_predictor( - pretrained_model: str, - pretrained_weights: str, + model: str, + weights: str, img_input: str, file_types: str, masks: str | None, - mode: str, output_path: str, batch_size: int, resolution: float, units: str, num_loader_workers: int, + device: str, + output_type: str, *, - return_probabilities: bool, - return_labels: bool, - merge_predictions: bool, - on_gpu: bool, + patch_mode: bool, verbose: bool, ) -> None: """Process an image/directory of input images with a patch classification CNN.""" - from tiatoolbox.models import PatchPredictor - from tiatoolbox.utils import save_as_json + from tiatoolbox.models.engine.patch_predictor import PatchPredictor files_all, masks_all, output_path = prepare_model_cli( img_input=img_input, @@ -83,26 +70,21 @@ def patch_predictor( ) predictor = PatchPredictor( - pretrained_model=pretrained_model, - weights=pretrained_weights, + model=model, + weights=weights, batch_size=batch_size, num_loader_workers=num_loader_workers, verbose=verbose, ) - output = predictor.predict( - imgs=files_all, + _ = predictor.run( + images=files_all, masks=masks_all, - mode=mode, - return_probabilities=return_probabilities, - merge_predictions=merge_predictions, - labels=None, - return_labels=return_labels, + patch_mode=patch_mode, resolution=resolution, units=units, - 
on_gpu=on_gpu, + device=device, save_dir=output_path, save_output=True, + output_type=output_type, ) - - save_as_json(output, str(output_path.joinpath("results.json"))) diff --git a/tiatoolbox/models/architecture/vanilla.py b/tiatoolbox/models/architecture/vanilla.py index 5c19f4c27..e7b956411 100644 --- a/tiatoolbox/models/architecture/vanilla.py +++ b/tiatoolbox/models/architecture/vanilla.py @@ -169,7 +169,7 @@ def infer_batch( with torch.inference_mode(): output = model(img_patches_device) # Output should be a single tensor or scalar - return {"predictions": output.cpu().numpy()} + return {"probabilities": output.cpu().numpy()} class CNNBackbone(ModelABC): @@ -265,5 +265,6 @@ def infer_batch( # Do not compute the gradient (not training) with torch.inference_mode(): output = model(img_patches_device) + # Output should be a single tensor or scalar - return {"predictions": output.cpu().numpy()} + return {"probabilities": output.cpu().numpy()} diff --git a/tiatoolbox/models/dataset/dataset_abc.py b/tiatoolbox/models/dataset/dataset_abc.py index d634ccbd4..045bb39b7 100644 --- a/tiatoolbox/models/dataset/dataset_abc.py +++ b/tiatoolbox/models/dataset/dataset_abc.py @@ -145,7 +145,7 @@ def load_img(path: str | Path) -> np.ndarray: if path.suffix not in (".npy", ".jpg", ".jpeg", ".tif", ".tiff", ".png"): msg = f"Cannot load image data from `{path.suffix}` files." - raise ValueError(msg) + raise TypeError(msg) return imread(path, as_uint8=False) @@ -399,10 +399,8 @@ def __init__( # skipcq: PY-R1000 # noqa: PLR0915 `units`. Expected to be positive and of (height, width). Note, this is not at level 0. resolution (Resolution): - Check (:class:`.WSIReader`) for details. When - `mode='tile'`, value is fixed to be `resolution=1.0` and - `units='baseline'` units: check (:class:`.WSIReader`) for - details. + Requested resolution corresponding to units. Check + (:class:`WSIReader`) for details. units (Units): Units in which `resolution` is defined. auto_get_mask (bool): diff --git a/tiatoolbox/models/engine/engine_abc.py b/tiatoolbox/models/engine/engine_abc.py index 8fef0c4e2..465230116 100644 --- a/tiatoolbox/models/engine/engine_abc.py +++ b/tiatoolbox/models/engine/engine_abc.py @@ -3,6 +3,7 @@ from __future__ import annotations import copy +import shutil from abc import ABC, abstractmethod from pathlib import Path from typing import TYPE_CHECKING, TypedDict @@ -14,7 +15,7 @@ from torch import nn from typing_extensions import Unpack -from tiatoolbox import logger +from tiatoolbox import DuplicateFilter, logger from tiatoolbox.models.architecture import get_pretrained_model from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset from tiatoolbox.models.models_abc import load_torch_model @@ -128,7 +129,8 @@ class EngineABCRunParams(TypedDict, total=False): num_post_proc_workers (int): Number of workers to postprocess the results of the model. output_file (str): - Output file name to save "zarr" or "db". + Output file name to save "zarr" or "db". If None, path to output is + returned by the engine. patch_input_shape (tuple): Shape of patches input to the model as tuple of height and width (HW). 
Patches are requested at read resolution, not with respect to level 0, @@ -355,8 +357,6 @@ def __init__( verbose: bool = False, ) -> None: """Initialize Engine.""" - super().__init__() - self.images = None self.masks = None self.patch_mode = None @@ -378,10 +378,10 @@ def __init__( self.num_loader_workers = num_loader_workers self.num_post_proc_workers = num_post_proc_workers self.patch_input_shape: IntPair | None = None - self.resolution: Resolution = 1.0 + self.resolution: Resolution | None = None self.return_labels: bool = False self.stride_shape: IntPair | None = None - self.units: Units = "baseline" + self.units: Units | None = None self.verbose = verbose @staticmethod @@ -440,7 +440,7 @@ def _initialize_model_ioconfig( def get_dataloader( self: EngineABC, - images: Path, + images: str | Path | list[str | Path] | np.ndarray, masks: Path | None = None, labels: list | None = None, ioconfig: ModelIOConfigABC | None = None, @@ -453,7 +453,7 @@ def get_dataloader( images (list of str or :class:`Path` or :class:`numpy.ndarray`): A list of image patches in NHWC format as a numpy array or a list of str/paths to WSIs. When `patch_mode` is False - the function expects list of str/paths to WSIs. + the function expects path to a single WSI. masks (list | None): List of masks. Only utilised when patch_mode is False. Patches are only generated within a masked area. @@ -470,7 +470,6 @@ def get_dataloader( torch.utils.data.DataLoader: :class:`torch.utils.data.DataLoader` for inference. - """ if labels: # if a labels is provided, then return with the prediction @@ -527,6 +526,8 @@ def infer_patches( self: EngineABC, dataloader: DataLoader, save_path: Path | None, + *, + return_coordinates: bool = False, ) -> dict | Path: """Runs model inference on image patches and returns output as a dictionary. @@ -535,6 +536,9 @@ def infer_patches( An :class:`torch.utils.data.DataLoader` object to run inference. save_path (Path | None): If `cache_mode` is True then path to save zarr file must be provided. + return_coordinates (bool): + Whether to save coordinates in the output. This is required when + this function is called by `infer_wsi` and `patch_mode` is False. Returns: dict or Path: @@ -553,11 +557,14 @@ def infer_patches( position=0, ) - keys = ["predictions"] + keys = ["probabilities"] if self.return_labels: keys.append("labels") + if return_coordinates: + keys.append("coordinates") + raw_predictions = {key: None for key in keys} zarr_group = None @@ -571,9 +578,14 @@ def infer_patches( batch_data["image"], device=self.device, ) + if return_coordinates: + batch_output["coordinates"] = batch_data["coords"].numpy() if self.return_labels: # be careful of `s` - batch_output["labels"] = batch_data["label"].numpy() + if isinstance(batch_data["label"], torch.Tensor): + batch_output["labels"] = batch_data["label"].numpy() + else: + batch_output["labels"] = batch_data["label"] raw_predictions = self._update_model_output( raw_predictions=raw_predictions, @@ -597,7 +609,7 @@ def infer_patches( def post_process_patches( self: EngineABC, raw_predictions: dict | Path, - **kwargs: dict, + **kwargs: Unpack[EngineABCRunParams], ) -> dict | Path: """Post-process raw patch predictions from inference. @@ -609,8 +621,9 @@ def post_process_patches( Args: raw_predictions (dict | Path): A dictionary or path to zarr with patch prediction information. - **kwargs (dict): - Keyword Args to update setup_patch_dataset() method attributes. + **kwargs (EngineABCRunParams): + Keyword Args to update setup_patch_dataset() method attributes. 
See + :class:`EngineRunParams` for accepted keyword arguments. Returns: dict or Path: @@ -618,7 +631,7 @@ def post_process_patches( saved zarr file if `cache_mode` is True. """ - _ = kwargs.get("predictions") # Key values required for post-processing + _ = kwargs.get("probabilities") # Key values required for post-processing if self.cache_mode: # cache mode _ = zarr.open(raw_predictions, mode="w") @@ -627,7 +640,7 @@ def post_process_patches( def save_predictions( self: EngineABC, - processed_predictions: dict, + processed_predictions: dict | Path, output_type: str, save_dir: Path | None = None, **kwargs: dict, @@ -656,26 +669,36 @@ def save_predictions( `.zarr` file depending on whether a save_dir Path is provided. """ - if (self.cache_mode or not save_dir) and output_type != "AnnotationStore": + if ( + self.cache_mode or not save_dir + ) and output_type.lower() != "annotationstore": return processed_predictions - output_file = Path(kwargs.get("output_file", "output.db")) - - save_path = save_dir / output_file + save_path = Path(kwargs.get("output_file", save_dir / "output.db")) - if output_type == "AnnotationStore": + if output_type.lower() == "annotationstore": # scale_factor set from kwargs - scale_factor = kwargs.get("scale_factor") + scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) # class_dict set from kwargs class_dict = kwargs.get("class_dict") + processed_predictions_path: str | Path | None = None + # Need to add support for zarr conversion. - return dict_to_store( + if self.cache_mode: + processed_predictions_path = processed_predictions + processed_predictions = zarr.open(processed_predictions, mode="r") + + out_file = dict_to_store( processed_predictions, scale_factor, class_dict, save_path, ) + if processed_predictions_path is not None: + shutil.rmtree(processed_predictions_path) + + return out_file return ( dict_to_zarr( @@ -691,11 +714,9 @@ def save_predictions( def infer_wsi( self: EngineABC, dataloader: torch.utils.data.DataLoader, - img_label: str, - highest_input_resolution: list[dict], - save_dir: Path, + save_path: Path | str, **kwargs: dict, - ) -> list: + ) -> dict | Path: """Model inference on a WSI. This function must be implemented by subclasses. @@ -704,22 +725,28 @@ def infer_wsi( # return coordinates of patches processed within a tile / whole-slide image raise NotImplementedError + @abstractmethod + def post_process_wsi( + self: EngineABC, + raw_predictions: dict | Path, + **kwargs: Unpack[EngineABCRunParams], + ) -> dict | Path: + """Post process WSI output.""" + _ = kwargs.get("probabilities") # Key values required for post-processing + return raw_predictions + @abstractmethod def save_wsi_output( self: EngineABC, - raw_output: dict | Path, - save_dir: Path, + processed_output: Path, output_type: str, **kwargs: Unpack[EngineABCRunParams], - ) -> AnnotationStore | Path: - """Post-process a WSI. + ) -> Path: + """Aggregate the output at the WSI level and save to file. Args: - raw_output (dict | Path): - A dictionary with output information or zarr file path. - save_dir (Path): - Output Path to directory to save the patch dataset output to a - `.zarr` or `.db` file + processed_output (Path): + Path to Zarr file with intermediate results. output_type (str): The desired output type for resulting patch dataset. **kwargs (EngineABCRunParams): @@ -732,23 +759,22 @@ def save_wsi_output( stored in a `.zarr` file. 
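For orientation, the following is a minimal sketch of how these run parameters surface through the reworked public API introduced in this diff, assuming the `alexnet-kather100k` pretrained model used by the tests; the input values are illustrative only.

```python
import numpy as np

from tiatoolbox.models.engine.patch_predictor import PatchPredictor

# Ten blank RGB patches stand in for real input; any NHWC uint8 array works.
patches = np.zeros((10, 224, 224, 3), dtype=np.uint8)

predictor = PatchPredictor(model="alexnet-kather100k", batch_size=4)
out = predictor.run(
    images=patches,
    labels=list(range(10)),  # optional; adds a "labels" key to the output
    patch_mode=True,
    output_type="dict",      # keep results in memory instead of zarr/.db
)
# Expected keys with this configuration: "probabilities" and "labels".
print(sorted(out.keys()))
```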
""" - if ( - output_type == "zarr" - and isinstance(raw_output, Path) - and raw_output.suffix == ".zarr" - ): - return raw_output + if output_type.lower() == "zarr": + msg = "Output file saved at %s.", processed_output + logger.info(msg=msg) + return processed_output - output_file = kwargs.get("output_file", "output") - save_path = save_dir / output_file - - if output_type == "AnnotationStore": + if output_type.lower() == "annotationstore": + save_path = Path(kwargs.get("output_file", processed_output.stem + ".db")) # scale_factor set from kwargs scale_factor = kwargs.get("scale_factor", (1.0, 1.0)) + # Read zarr file to a dict + raw_output_dict = zarr.open(str(processed_output), mode="r") + # class_dict set from kwargs class_dict = kwargs.get("class_dict") - return dict_to_store(raw_output, scale_factor, class_dict, save_path) + return dict_to_store(raw_output_dict, scale_factor, class_dict, save_path) msg = "Only supports zarr and AnnotationStore as output_type." raise ValueError(msg) @@ -778,7 +804,7 @@ def _load_ioconfig(self: EngineABC, ioconfig: ModelIOConfigABC) -> ModelIOConfig "Please provide a valid ModelIOConfigABC. " "No default ModelIOConfigABC found." ) - raise ValueError(msg) + logger.warning(msg) if ioconfig and isinstance(ioconfig, ModelIOConfigABC): self.ioconfig = ioconfig @@ -914,6 +940,7 @@ def _update_run_params( labels: list | None = None, save_dir: os | Path | None = None, ioconfig: ModelIOConfigABC | None = None, + output_type: str = "dict", *, overwrite: bool = False, patch_mode: bool, @@ -928,10 +955,17 @@ def _update_run_params( setattr(self, key, kwargs.get(key)) self.patch_mode = patch_mode + if not self.patch_mode: + self.cache_mode = True # if input is WSI run using cache mode. + if self.cache_mode and self.batch_size > self.cache_size: self.batch_size = self.cache_size self._validate_input_numbers(images=images, masks=masks, labels=labels) + if output_type.lower() not in ["dict", "zarr", "annotationstore"]: + msg = "output_type must be 'dict' or 'zarr' or 'annotationstore'." + raise TypeError(msg) + self.images = self._validate_images_masks(images=images) if masks is not None: @@ -966,11 +1000,15 @@ def _run_patch_mode( """ save_path = None if self.cache_mode: - output_file = Path(kwargs.get("output_file", "output.db")) + output_file = Path(kwargs.get("output_file", "output.zarr")) save_path = save_dir / (str(output_file.stem) + ".zarr") + duplicate_filter = DuplicateFilter() + logger.addFilter(duplicate_filter) + dataloader = self.get_dataloader( images=self.images, + masks=self.masks, labels=self.labels, patch_mode=True, ) @@ -982,6 +1020,8 @@ def _run_patch_mode( raw_predictions=raw_predictions, **kwargs, ) + logger.removeFilter(duplicate_filter) + return self.save_predictions( processed_predictions=processed_predictions, output_type=output_type, @@ -989,6 +1029,106 @@ def _run_patch_mode( **kwargs, ) + @staticmethod + def _calculate_scale_factor(dataloader: DataLoader) -> float | tuple[float, float]: + """Calculates scale factor for final output. + + Uses the dataloader resolution and the WSI resolution to calculate scale + factor for final WSI output. + + Args: + dataloader (DataLoader): + Dataloader for the current run. + + Returns: + scale_factor (float | tuple[float, float]): + Scale factor for final output. + + """ + # get units and resolution from dataloader. + dataloader_units = dataloader.dataset.units + dataloader_resolution = dataloader.dataset.resolution + + # if dataloader units is baseline slide resolution is 1.0. 
+ # in this case dataloader resolution / slide resolution will be + # equal to dataloader resolution. + + if dataloader_units in ["mpp", "level", "power"]: + wsimeta_dict = dataloader.dataset.reader.info.as_dict() + + if dataloader_units == "mpp": + slide_resolution = wsimeta_dict[dataloader_units] + scale_factor = np.divide(slide_resolution, dataloader_resolution) + return scale_factor[0], scale_factor[1] + + if dataloader_units == "level": + downsample_ratio = wsimeta_dict["level_downsamples"][dataloader_resolution] + return 1.0 / downsample_ratio, 1.0 / downsample_ratio + + if dataloader_units == "power": + slide_objective_power = wsimeta_dict["objective_power"] + return ( + dataloader_resolution / slide_objective_power, + dataloader_resolution / slide_objective_power, + ) + + return dataloader_resolution + + def _run_wsi_mode( + self: EngineABC, + output_type: str, + save_dir: Path, + **kwargs: Unpack[EngineABCRunParams], + ) -> dict | AnnotationStore | Path: + """Runs the Engine in the WSI mode (patch_mode = False). + + Input arguments are passed from :func:`EngineABC.run()`. + + """ + suffix = ".zarr" + if output_type == "AnnotationStore": + suffix = ".db" + + out = {image: save_dir / (str(image.stem) + suffix) for image in self.images} + + save_path = { + image: save_dir / (str(image.stem) + ".zarr") for image in self.images + } + + for image_num, image in enumerate(self.images): + duplicate_filter = DuplicateFilter() + logger.addFilter(duplicate_filter) + mask = self.masks[image_num] if self.masks is not None else None + dataloader = self.get_dataloader( + images=image, + masks=mask, + patch_mode=False, + ioconfig=self._ioconfig, + ) + + scale_factor = self._calculate_scale_factor(dataloader=dataloader) + + raw_predictions = self.infer_wsi( + dataloader=dataloader, + save_path=save_path[image], + **kwargs, + ) + processed_predictions = self.post_process_wsi( + raw_predictions=raw_predictions, + **kwargs, + ) + kwargs["output_file"] = out[image] + kwargs["scale_factor"] = scale_factor + out[image] = self.save_predictions( + processed_predictions=processed_predictions, + output_type=output_type, + save_dir=save_dir, + **kwargs, + ) + logger.removeFilter(duplicate_filter) + + return out + def run( self: EngineABC, images: list[os | Path | WSIReader] | np.ndarray, @@ -1031,7 +1171,7 @@ def run( Whether to overwrite the results. Default = False. output_type (str): The format of the output type. "output_type" can be - "zarr" or "AnnotationStore". Default value is "zarr". + "dict", "zarr" or "AnnotationStore". Default value is "zarr". When saving in the zarr format the output is saved using the `python zarr library `__ as a zarr group. If the required output type is an "AnnotationStore" @@ -1087,6 +1227,7 @@ def run( ioconfig=ioconfig, overwrite=overwrite, patch_mode=patch_mode, + output_type=output_type, **kwargs, ) @@ -1101,4 +1242,8 @@ def run( # highest_input_resolution, implement dataloader, # pre-processing, post-processing and save_output # for WSIs separately. 
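As a rough numeric illustration of the `mpp` branch in `_calculate_scale_factor` above, the scale factor is the slide resolution divided element-wise by the resolution the dataloader read at; the metadata values below are hypothetical.

```python
import numpy as np

# Hypothetical values: slide scanned at 0.25 mpp, patches read at 0.5 mpp.
slide_mpp = np.array([0.25, 0.25])   # wsimeta_dict["mpp"]
dataloader_mpp = 0.5                 # dataloader.dataset.resolution

scale_factor = np.divide(slide_mpp, dataloader_mpp)
print(tuple(scale_factor))  # (0.5, 0.5), later forwarded to dict_to_store()
```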
- raise NotImplementedError + return self._run_wsi_mode( + output_type=output_type, + save_dir=save_dir, + **kwargs, + ) diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py index 5837693e2..b98c6676d 100644 --- a/tiatoolbox/models/engine/patch_predictor.py +++ b/tiatoolbox/models/engine/patch_predictor.py @@ -1,33 +1,26 @@ -"""This module implements patch level prediction.""" +"""Defines Abstract Base Class for TIAToolbox Model Engines.""" from __future__ import annotations -import copy -from collections import OrderedDict -from pathlib import Path -from typing import TYPE_CHECKING, Callable, NoReturn +from typing import TYPE_CHECKING -import numpy as np -import torch -import tqdm +from typing_extensions import Unpack -import tiatoolbox.models.models_abc -from tiatoolbox import logger -from tiatoolbox.models.dataset.dataset_abc import PatchDataset, WSIPatchDataset -from tiatoolbox.utils import save_as_json -from tiatoolbox.wsicore.wsireader import VirtualWSIReader, WSIReader +from .engine_abc import EngineABC, EngineABCRunParams if TYPE_CHECKING: # pragma: no cover import os + from pathlib import Path + + import numpy as np + from torch.utils.data import DataLoader from tiatoolbox.annotation import AnnotationStore - from tiatoolbox.typing import IntPair, Resolution, Units + from tiatoolbox.models.models_abc import ModelABC + from tiatoolbox.wsicore.wsireader import WSIReader from .io_config import ModelIOConfigABC -from .engine_abc import EngineABC -from .io_config import IOPatchPredictorConfig - class PatchPredictor(EngineABC): r"""Patch level predictor for digital histology images. @@ -117,83 +110,161 @@ class PatchPredictor(EngineABC): - 0.867 Args: - model (nn.Module): - Use externally defined PyTorch model for prediction with. - weights already loaded. Default is `None`. If provided, - `pretrained_model` argument is ignored. - pretrained_model (str): - Name of the existing models support by tiatoolbox for - processing the data. For a full list of pretrained models, - refer to the `docs + model (str | ModelABC): + A PyTorch model or name of pretrained model. + The user can request pretrained models from the toolbox model zoo using + the list of pretrained models available at this `link `_ By default, the corresponding pretrained weights will also be downloaded. However, you can override with your own set - of weights via the `pretrained_weights` argument. Argument - is case-insensitive. - weights (str): - Path to the weight of the corresponding `pretrained_model`. - - >>> predictor = PatchPredictor( - ... pretrained_model="resnet18-kather100k", - ... weights="resnet18_local_weight") - + of weights using the `weights` parameter. Default is `None`. batch_size (int): - Number of images fed into the model each time. + Number of image patches fed into the model each time in a + forward/backward pass. Default value is 8. num_loader_workers (int): - Number of workers to load the data. Take note that they will - also perform preprocessing. + Number of workers to load the data using :class:`torch.utils.data.Dataset`. + Please note that they will also perform preprocessing. Default value is 0. + num_post_proc_workers (int): + Number of workers to postprocess the results of the model. + Default value is 0. + weights (str or Path): + Path to the weight of the corresponding `model`. + + >>> engine = EngineABC( + ... model="pretrained-model", + ... weights="/path/to/pretrained-local-weights.pth" + ... 
) + + device (str): + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. Default is "cpu". verbose (bool): - Whether to output logging information. + Whether to output logging information. Default value is False. Attributes: - images (str or :obj:`pathlib.Path` or :obj:`numpy.ndarray`): - A HWC image or a path to WSI. - mode (str): - Type of input to process. Choose from either `patch`, `tile` - or `wsi`. - model (nn.Module): - Defined PyTorch model. - model (str): - Name of the existing models support by tiatoolbox for - processing the data. For a full list of pretrained models, + images (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): + A list of image patches in NHWC format as a numpy array + or a list of str/paths to WSIs. + masks (list of str or list of :obj:`Path` or NHWC :obj:`numpy.ndarray`): + A list of tissue masks or binary masks corresponding to processing area of + input images. These can be a list of numpy arrays or paths to + the saved image masks. These are only utilized when patch_mode is False. + Patches are only generated within a masked area. + If not provided, then a tissue mask will be automatically + generated for whole slide images. + patch_mode (str): + Whether to treat input images as a set of image patches. TIAToolbox defines + an image as a patch if HWC of the input image matches with the HWC expected + by the model. If HWC of the input image does not match with the HWC expected + by the model, then the patch_mode must be set to False which will allow the + engine to extract patches from the input image. + In this case, when the patch_mode is False the input images are treated + as WSIs. Default value is True. + model (str | ModelABC): + A PyTorch model or a name of an existing model from the TIAToolbox model zoo + for processing the data. For a full list of pretrained models, refer to the `docs `_ By default, the corresponding pretrained weights will also be downloaded. However, you can override with your own set - of weights via the `pretrained_weights` argument. Argument - is case insensitive. + of weights via the `weights` argument. Argument + is case-insensitive. + ioconfig (ModelIOConfigABC): + Input IO configuration of type :class:`ModelIOConfigABC` to run the Engine. + _ioconfig (ModelIOConfigABC): + Runtime ioconfig. + return_labels (bool): + Whether to return the labels with the predictions. + merge_predictions (bool): + Whether to merge the predictions to form a 2-dimensional + map. This is only applicable if `patch_mode` is False in inference. + Default is False. + resolution (Resolution): + Resolution used for reading the image. Please see + :obj:`WSIReader` for details. + units (Units): + Units of resolution used for reading the image. Choose + from either `level`, `power` or `mpp`. Please see + :obj:`WSIReader` for details. + patch_input_shape (tuple): + Shape of patches input to the model as tupled of HW. Patches are at + requested read resolution, not with respect to level 0, + and must be positive. + stride_shape (tuple): + Stride used during WSI processing. Stride is + at requested read resolution, not with respect to + level 0, and must be positive. If not provided, + `stride_shape=patch_input_shape`. batch_size (int): Number of images fed into the model each time. - num_loader_worker (int): - Number of workers used in torch.utils.data.DataLoader. + cache_mode (bool): + Whether to run the Engine in cache_mode. 
For large datasets, + we recommend to set this to True to avoid out of memory errors. + For smaller datasets, the cache_mode is set to False as + the results can be saved in memory. cache_mode is always True when + processing WSIs i.e., when `patch_mode` is False. Default value is False. + cache_size (int): + Specifies how many image patches to process in a batch when + cache_mode is set to True. If cache_size is less than the batch_size + batch_size is set to cache_size. Default value is 10,000. + labels (list | None): + List of labels. Only a single label per image is supported. + device (str): + :class:`torch.device` to run the model. + Select the device to run the model. Please see + https://pytorch.org/docs/stable/tensor_attributes.html#torch.device + for more details on input parameters for device. Default value is "cpu". + num_loader_workers (int): + Number of workers used in :class:`torch.utils.data.DataLoader`. + num_post_proc_workers (int): + Number of workers to postprocess the results of the model. + return_labels (bool): + Whether to return the output labels. Default value is False. + merge_predictions (bool): + Whether to merge WSI predictions into a single file. Default value is False. + resolution (Resolution): + Resolution used for reading the image. Please see + :class:`WSIReader` for details. + When `patch_mode` is True, the input image patches are expected to be at + the correct resolution and units. When `patch_mode` is False, the patches + are extracted at the requested resolution and units. Default value is 1.0. + units (Units): + Units of resolution used for reading the image. Choose + from either `baseline`, `level`, `power` or `mpp`. Please see + :class:`WSIReader` for details. + When `patch_mode` is True, the input image patches are expected to be at + the correct resolution and units. When `patch_mode` is False, the patches + are extracted at the requested resolution and units. + Default value is `baseline`. verbose (bool): - Whether to output logging information. + Whether to output logging information. Default value is False. 
Examples: >>> # list of 2 image patches as input >>> data = ['path/img.svs', 'path/img.svs'] - >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k") - >>> output = predictor.predict(data, mode='patch') + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(data, mode='patch') >>> # array of list of 2 image patches as input >>> data = np.array([img1, img2]) - >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k") + >>> predictor = PatchPredictor(model="resnet18-kather100k") >>> output = predictor.predict(data, mode='patch') >>> # list of 2 image patch files as input >>> data = ['path/img.png', 'path/img.png'] - >>> predictor = PatchPredictor(pretrained_model="resnet18-kather100k") - >>> output = predictor.predict(data, mode='patch') + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(data, mode='patch') >>> # list of 2 image tile files as input >>> tile_file = ['path/tile1.png', 'path/tile2.png'] - >>> predictor = PatchPredictor(pretraind_model="resnet18-kather100k") - >>> output = predictor.predict(tile_file, mode='tile') + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(tile_file, mode='tile') >>> # list of 2 wsi files as input >>> wsi_file = ['path/wsi1.svs', 'path/wsi2.svs'] - >>> predictor = PatchPredictor(pretraind_model="resnet18-kather100k") - >>> output = predictor.predict(wsi_file, mode='wsi') + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(wsi_file, mode='wsi') References: [1] Kather, Jakob Nikolas, et al. "Predicting survival from colorectal cancer @@ -208,526 +279,143 @@ class PatchPredictor(EngineABC): def __init__( self: PatchPredictor, + model: str | ModelABC, batch_size: int = 8, num_loader_workers: int = 0, num_post_proc_workers: int = 0, - model: torch.nn.Module = None, - pretrained_model: str | None = None, - weights: str | None = None, + weights: str | Path | None = None, *, + device: str = "cpu", verbose: bool = True, ) -> None: """Initialize :class:`PatchPredictor`.""" super().__init__( + model=model, batch_size=batch_size, num_loader_workers=num_loader_workers, num_post_proc_workers=num_post_proc_workers, - model=model, - pretrained_model=pretrained_model, weights=weights, + device=device, verbose=verbose, ) - def pre_process_wsi(self: PatchPredictor) -> NoReturn: - """Pre-process a WSI.""" - - def infer_wsi(self: PatchPredictor) -> NoReturn: - """Model inference on a WSI.""" - - def save_predictions( + def get_dataloader( self: PatchPredictor, - raw_predictions: dict, - output_type: str, - ) -> None: - """Post-process an image patch.""" - - def save_wsi_output(self: PatchPredictor) -> NoReturn: - """Post-process a WSI.""" - - @staticmethod - def merge_predictions( - img: str | Path | np.ndarray, - output: dict, - resolution: Resolution | None = None, - units: Units | None = None, - post_proc_func: Callable | None = None, + images: Path, + masks: Path | None = None, + labels: list | None = None, + ioconfig: ModelIOConfigABC | None = None, *, - return_raw: bool = False, - ) -> np.ndarray: - """Merge patch level predictions to form a 2-dimensional prediction map. - - #! Improve how the below reads. - The prediction map will contain values from 0 to N, where N is - the number of classes. Here, 0 is the background which has not - been processed by the model and N is the number of classes - predicted by the model. 
+ patch_mode: bool = True, + ) -> DataLoader: + """Pre-process images and masks and return dataloader for inference. Args: - img (:obj:`str` or :obj:`pathlib.Path` or :class:`numpy.ndarray`): - A HWC image or a path to WSI. - output (dict): - Output generated by the model. - resolution (Resolution): - Resolution of merged predictions. - units (Units): - Units of resolution used when merging predictions. This - must be the same `units` used when processing the data. - post_proc_func (callable): - A function to post-process raw prediction from model. By - default, internal code uses the `np.argmax` function. - return_raw (bool): - Return raw result without applying the `postproc_func` - on the assembled image. + images (list of str or :class:`Path` or :class:`numpy.ndarray`): + A list of image patches in NHWC format as a numpy array + or a list of str/paths to WSIs. When `patch_mode` is False + the function expects list of str/paths to WSIs. + masks (list | None): + List of masks. Only utilised when patch_mode is False. + Patches are only generated within a masked area. + If not provided, then a tissue mask will be automatically + generated for whole slide images. + labels (list | None): + List of labels. Only a single label per image is supported. + ioconfig (ModelIOConfigABC): + A :class:`ModelIOConfigABC` object. + patch_mode (bool): + Whether to treat input image as a patch or WSI. Returns: - :class:`numpy.ndarray`: - Merged predictions as a 2D array. + DataLoader: + :class:`DataLoader` for inference. - Examples: - >>> # pseudo output dict from model with 2 patches - >>> output = { - ... 'resolution': 1.0, - ... 'units': 'baseline', - ... 'probabilities': [[0.45, 0.55], [0.90, 0.10]], - ... 'predictions': [1, 0], - ... 'coordinates': [[0, 0, 2, 2], [2, 2, 4, 4]], - ... } - >>> merged = PatchPredictor.merge_predictions( - ... np.zeros([4, 4]), - ... output, - ... resolution=1.0, - ... units='baseline' - ... ) - >>> merged - ... array([[2, 2, 0, 0], - ... [2, 2, 0, 0], - ... [0, 0, 1, 1], - ... [0, 0, 1, 1]]) """ - reader = WSIReader.open(img) - if isinstance(reader, VirtualWSIReader): - logger.warning( - "Image is not pyramidal hence read is forced to be " - "at `units='baseline'` and `resolution=1.0`.", - stacklevel=2, - ) - resolution = 1.0 - units = "baseline" - - canvas_shape = reader.slide_dimensions(resolution=resolution, units=units) - canvas_shape = canvas_shape[::-1] # XY to YX - - # may crash here, do we need to deal with this ? 
- output_shape = reader.slide_dimensions( - resolution=output["resolution"], - units=output["units"], - ) - output_shape = output_shape[::-1] # XY to YX - fx = np.array(canvas_shape) / np.array(output_shape) - - if "probabilities" not in output: - coordinates = output["coordinates"] - predictions = output["predictions"] - denominator = None - output = np.zeros(list(canvas_shape), dtype=np.float32) - else: - coordinates = output["coordinates"] - predictions = output["probabilities"] - num_class = np.array(predictions[0]).shape[0] - denominator = np.zeros(canvas_shape) - output = np.zeros([*list(canvas_shape), num_class], dtype=np.float32) - - for idx, bound in enumerate(coordinates): - prediction = predictions[idx] - # assumed to be in XY - # top-left for output placement - tl = np.ceil(np.array(bound[:2]) * fx).astype(np.int32) - # bot-right for output placement - br = np.ceil(np.array(bound[2:]) * fx).astype(np.int32) - output[tl[1] : br[1], tl[0] : br[0]] += prediction - if denominator is not None: - denominator[tl[1] : br[1], tl[0] : br[0]] += 1 - - # deal with overlapping regions - if denominator is not None: - output = output / (np.expand_dims(denominator, -1) + 1.0e-8) - if not return_raw: - # convert raw probabilities to predictions - if post_proc_func is not None: - output = post_proc_func(output) - else: - output = np.argmax(output, axis=-1) - # to make sure background is 0 while class will be 1...N - output[denominator > 0] += 1 - return output - - def _predict_engine( - self: PatchPredictor, - dataset: torch.utils.data.Dataset, - *, - return_probabilities: bool = False, - return_labels: bool = False, - return_coordinates: bool = False, - device: str = "cpu", - ) -> np.ndarray: - """Make a prediction on a dataset. The dataset may be mutated. - - Args: - dataset (torch.utils.data.Dataset): - PyTorch dataset object created using - `tiatoolbox.models.data.classification.Patch_Dataset`. - return_probabilities (bool): - Whether to return per-class probabilities. - return_labels (bool): - Whether to return labels. - return_coordinates (bool): - Whether to return patch coordinates. - device (str): - Select the device to run the model. Default is "cpu". 
- - Returns: - :class:`numpy.ndarray`: - Model predictions of the input dataset - - """ - dataset.preproc_func = self.model.preproc_func - - # preprocessing must be defined with the dataset - dataloader = torch.utils.data.DataLoader( - dataset, - num_workers=self.num_loader_workers, - batch_size=self.batch_size, - drop_last=False, - shuffle=False, + return super().get_dataloader( + images, + masks, + labels, + ioconfig, + patch_mode=patch_mode, ) - if self.verbose: - pbar = tqdm.tqdm( - total=int(len(dataloader)), - leave=True, - ncols=80, - ascii=True, - position=0, - ) - - # use external for testing - model = tiatoolbox.models.models_abc.model_to(model=self.model, device=device) - - cum_output = { - "probabilities": [], - "predictions": [], - "coordinates": [], - "labels": [], - } - for _, batch_data in enumerate(dataloader): - batch_output_probabilities = self.model.infer_batch( - model, - batch_data["image"], - device=device, - ) - # We get the index of the class with the maximum probability - batch_output_predictions = self.model.postproc_func( - batch_output_probabilities, - ) - - # tolist might be very expensive - cum_output["probabilities"].extend(batch_output_probabilities.tolist()) - cum_output["predictions"].extend(batch_output_predictions.tolist()) - if return_coordinates: - cum_output["coordinates"].extend(batch_data["coords"].tolist()) - if return_labels: # be careful of `s` - # We do not use tolist here because label may be of mixed types - # and hence collated as list by torch - cum_output["labels"].extend(list(batch_data["label"])) - - if self.verbose: - pbar.update() - if self.verbose: - pbar.close() - - if not return_probabilities: - cum_output.pop("probabilities") - if not return_labels: - cum_output.pop("labels") - if not return_coordinates: - cum_output.pop("coordinates") - - return cum_output - - def _update_ioconfig( - self: PatchPredictor, - ioconfig: IOPatchPredictorConfig, - patch_input_shape: IntPair, - stride_shape: IntPair, - resolution: Resolution, - units: Units, - ) -> IOPatchPredictorConfig: - """Update the ioconfig. + def infer_wsi( + self: EngineABC, + dataloader: DataLoader, + save_path: Path, + **kwargs: EngineABCRunParams, + ) -> Path: + """Model inference on a WSI. Args: - ioconfig (:class:`IOPatchPredictorConfig`): - Input ioconfig for PatchPredictor. - patch_input_shape (tuple): - Size of patches input to the model. Patches are at - requested read resolution, not with respect to level 0, - and must be positive. - stride_shape (tuple): - Stride using during tile and WSI processing. Stride is - at requested read resolution, not with respect to - level 0, and must be positive. If not provided, - `stride_shape=patch_input_shape`. - resolution (Resolution): - Resolution used for reading the image. Please see - :obj:`WSIReader` for details. - units (Units): - Units of resolution used for reading the image. + dataloader (DataLoader): + A torch dataloader to process WSIs. + + save_path (Path): + Path to save the intermediate output. The intermediate output is saved + in a zarr file. + **kwargs (EngineABCRunParams): + Keyword Args to update setup_patch_dataset() method attributes. See + :class:`EngineRunParams` for accepted keyword arguments. Returns: - Updated Patch Predictor IO configuration. + save_path (Path): + Path to zarr file where intermediate output is saved. 
""" - config_flag = ( - patch_input_shape is None, - resolution is None, - units is None, + _ = kwargs.get("patch_mode", False) + return self.infer_patches( + dataloader=dataloader, + save_path=save_path, + return_coordinates=True, ) - if ioconfig: - return ioconfig - - if self.ioconfig is None and any(config_flag): - msg = ( - "Must provide either " - "`ioconfig` or `patch_input_shape`, `resolution`, and `units`." - ) - raise ValueError( - msg, - ) - - if stride_shape is None: - stride_shape = patch_input_shape - - if self.ioconfig: - ioconfig = copy.deepcopy(self.ioconfig) - # ! not sure if there is a nicer way to set this - if patch_input_shape is not None: - ioconfig.patch_input_shape = patch_input_shape - if stride_shape is not None: - ioconfig.stride_shape = stride_shape - if resolution is not None: - ioconfig.input_resolutions[0]["resolution"] = resolution - if units is not None: - ioconfig.input_resolutions[0]["units"] = units - - return ioconfig - - return IOPatchPredictorConfig( - input_resolutions=[{"resolution": resolution, "units": units}], - patch_input_shape=patch_input_shape, - stride_shape=stride_shape, - output_resolutions=[], - ) - - def _predict_patch( - self: PatchPredictor, - imgs: list | np.ndarray, - labels: list, - *, - return_probabilities: bool, - return_labels: bool, - device: str, - ) -> np.ndarray: - """Process patch mode. - Args: - imgs (list, ndarray): - List of inputs to process. when using `patch` mode, the - input must be either a list of images, a list of image - file paths or a numpy array of an image list. When using - `tile` or `wsi` mode, the input must be a list of file - paths. - labels: - List of labels. If using `tile` or `wsi` mode, then only - a single label per image tile or whole-slide image is - supported. - return_probabilities (bool): - Whether to return per-class probabilities. - return_labels (bool): - Whether to return the labels with the predictions. - device (str): - Select the device to run the engine. + def post_process_wsi( + self: EngineABC, + raw_predictions: dict | Path, + **kwargs: Unpack[EngineABCRunParams], + ) -> dict | Path: + """Post process WSI output. - Returns: - :class:`numpy.ndarray`: - Model predictions of the input dataset + Takes the raw output from patch predictions and post-processes it to improve the + results e.g., using information from neighbouring patches. """ - if labels: - # if a labels is provided, then return with the prediction - return_labels = bool(labels) - - if labels and len(labels) != len(imgs): - msg = f"len(labels) != len(imgs) : {len(labels)} != {len(imgs)}" - raise ValueError( - msg, - ) - - # don't return coordinates if patches are already extracted - return_coordinates = False - dataset = PatchDataset(imgs, labels) - return self._predict_engine( - dataset, - return_probabilities=return_probabilities, - return_labels=return_labels, - return_coordinates=return_coordinates, - device=device, + return super().post_process_wsi( + raw_predictions=raw_predictions, + **kwargs, ) - def _predict_tile_wsi( # noqa: PLR0913 - self: PatchPredictor, - imgs: list, - masks: list | None, - labels: list, - mode: str, - ioconfig: IOPatchPredictorConfig, - save_dir: str | Path, - highest_input_resolution: list[dict], - *, - save_output: bool, - return_probabilities: bool, - merge_predictions: bool, - on_gpu: bool, - ) -> list | dict: - """Predict on Tile and WSIs. 
+ def save_wsi_output( + self: EngineABC, + processed_output: Path, + output_type: str, + **kwargs: Unpack[EngineABCRunParams], + ) -> Path: + """Aggregate the output at the WSI level and save to file. Args: - imgs (list, ndarray): - List of inputs to process. when using `patch` mode, the - input must be either a list of images, a list of image - file paths or a numpy array of an image list. When using - `tile` or `wsi` mode, the input must be a list of file - paths. - masks (list or None): - List of masks. Only utilised when processing image tiles - and whole-slide images. Patches are only processed if - they are within a masked area. If not provided, then a - tissue mask will be automatically generated for - whole-slide images or the entire image is processed for - image tiles. - labels: - List of labels. If using `tile` or `wsi` mode, then only - a single label per image tile or whole-slide image is - supported. - mode (str): - Type of input to process. Choose from either `patch`, - `tile` or `wsi`. - return_probabilities (bool): - Whether to return per-class probabilities. - on_gpu (bool): - Whether to run model on the GPU. - ioconfig (IOPatchPredictorConfig): - Patch Predictor IO configuration.. - merge_predictions (bool): - Whether to merge the predictions to form a 2-dimensional - map. This is only applicable for `mode='wsi'` or - `mode='tile'`. - save_dir (str or pathlib.Path): - Output directory when processing multiple tiles and - whole-slide images. By default, it is folder `output` - where the running script is invoked. - save_output (bool): - Whether to save output for a single file. default=False - highest_input_resolution (list(dict)): - Highest available input resolution. - - - Returns: - dict: - Results are saved to `save_dir` and a dictionary indicating save - location for each input is returned. The dict is in the following - format: - - img_path: path of the input image. - - raw: path to save location for raw prediction, - saved in .json. - - merged: path to .npy contain merged - predictions if - `merge_predictions` is `True`. + processed_output (Path): + Path to Zarr file with intermediate results. + output_type (str): + The desired output type for resulting patch dataset. + **kwargs (EngineABCRunParams): + Keyword Args to update setup_patch_dataset() method attributes. + + Returns: (AnnotationStore or Path): + If the output_type is "AnnotationStore", the function returns the patch + predictor output as an SQLiteStore containing Annotations stored in a `.db` + file. Otherwise, the function defaults to returning patch predictor output + stored in a `.zarr` file. 
""" - # return coordinates of patches processed within a tile / whole-slide image - return_coordinates = True - - input_is_path_like = isinstance(imgs[0], (str, Path)) - default_save_dir = ( - imgs[0].parent / "output" if input_is_path_like else Path.cwd() + return super().save_wsi_output( + processed_output=processed_output, + output_type=output_type, + **kwargs, ) - save_dir = default_save_dir if save_dir is None else Path(save_dir) - - # None if no output - outputs = None - - self._ioconfig = ioconfig - # generate a list of output file paths if number of input images > 1 - file_dict = OrderedDict() - - if len(imgs) > 1: - save_output = True - - for idx, img_path in enumerate(imgs): - img_path_ = Path(img_path) - img_label = None if labels is None else labels[idx] - img_mask = None if masks is None else masks[idx] - - dataset = WSIPatchDataset( - img_path_, - mode=mode, - mask_path=img_mask, - patch_input_shape=ioconfig.patch_input_shape, - stride_shape=ioconfig.stride_shape, - resolution=ioconfig.input_resolutions[0]["resolution"], - units=ioconfig.input_resolutions[0]["units"], - ) - output_model = self._predict_engine( - dataset, - return_labels=False, - return_probabilities=return_probabilities, - return_coordinates=return_coordinates, - on_gpu=on_gpu, - ) - output_model["label"] = img_label - # add extra information useful for downstream analysis - output_model["pretrained_model"] = self.model - output_model["resolution"] = highest_input_resolution["resolution"] - output_model["units"] = highest_input_resolution["units"] - - outputs = [output_model] # assign to a list - merged_prediction = None - if merge_predictions: - merged_prediction = self.merge_predictions( - img_path_, - output_model, - resolution=output_model["resolution"], - units=output_model["units"], - post_proc_func=self.model.postproc, - ) - outputs.append(merged_prediction) - - if save_output: - # dynamic 0 padding - img_code = f"{idx:0{len(str(len(imgs)))}d}" - - save_info = {} - save_path = save_dir / img_code - raw_save_path = f"{save_path}.raw.json" - save_info["raw"] = raw_save_path - save_as_json(output_model, raw_save_path) - if merge_predictions: - merged_file_path = f"{save_path}.merged.npy" - np.save(merged_file_path, merged_prediction) - save_info["merged"] = merged_file_path - file_dict[str(img_path_)] = save_info - - return file_dict if save_output else outputs def run( self: EngineABC, @@ -740,102 +428,51 @@ def run( save_dir: os | Path | None = None, # None will not save output overwrite: bool = False, output_type: str = "dict", - **kwargs: dict, - ) -> AnnotationStore | str: - """Run engine.""" - super().run( - images=images, - masks=masks, - labels=labels, - ioconfig=ioconfig, - patch_mode=patch_mode, - save_dir=save_dir, - overwrite=overwrite, - output_type=output_type, - **kwargs, - ) - - def predict( # noqa: PLR0913 - self: PatchPredictor, - imgs: list, - masks: list | None = None, - labels: list | None = None, - mode: str = "patch", - ioconfig: IOPatchPredictorConfig | None = None, - patch_input_shape: tuple[int, int] | None = None, - stride_shape: tuple[int, int] | None = None, - resolution: Resolution | None = None, - units: Units = None, - *, - return_probabilities: bool = False, - return_labels: bool = False, - on_gpu: bool = True, - merge_predictions: bool = False, - save_dir: str | Path | None = None, - save_output: bool = False, - ) -> np.ndarray | list | dict: - """Make a prediction for a list of input data. 
+ **kwargs: Unpack[EngineABCRunParams], + ) -> AnnotationStore | Path | str | dict: + """Run the engine on input images. Args: - imgs (list, ndarray): + images (list, ndarray): List of inputs to process. when using `patch` mode, the input must be either a list of images, a list of image - file paths or a numpy array of an image list. When using - `tile` or `wsi` mode, the input must be a list of file - paths. - masks (list): - List of masks. Only utilised when processing image tiles - and whole-slide images. Patches are only processed if - they are within a masked area. If not provided, then a - tissue mask will be automatically generated for - whole-slide images or the entire image is processed for - image tiles. - labels: - List of labels. If using `tile` or `wsi` mode, then only - a single label per image tile or whole-slide image is - supported. - mode (str): - Type of input to process. Choose from either `patch`, - `tile` or `wsi`. - return_probabilities (bool): - Whether to return per-class probabilities. - return_labels (bool): - Whether to return the labels with the predictions. - on_gpu (bool): - Whether to run model on the GPU. + file paths or a numpy array of an image list. + masks (list | None): + List of masks. Only utilised when patch_mode is False. + Patches are only generated within a masked area. + If not provided, then a tissue mask will be automatically + generated for whole slide images. + labels (list | None): + List of labels. Only a single label per image is supported. + patch_mode (bool): + Whether to treat input image as a patch or WSI. + default = True. ioconfig (IOPatchPredictorConfig): - Patch Predictor IO configuration. - patch_input_shape (tuple): - Size of patches input to the model. Patches are at - requested read resolution, not with respect to level 0, - and must be positive. - stride_shape (tuple): - Stride using during tile and WSI processing. Stride is - at requested read resolution, not with respect to - level 0, and must be positive. If not provided, - `stride_shape=patch_input_shape`. - resolution (Resolution): - Resolution used for reading the image. Please see - :obj:`WSIReader` for details. - units (Units): - Units of resolution used for reading the image. Choose - from either `level`, `power` or `mpp`. Please see - :obj:`WSIReader` for details. - merge_predictions (bool): - Whether to merge the predictions to form a 2-dimensional - map. This is only applicable for `mode='wsi'` or - `mode='tile'`. + IO configuration. save_dir (str or pathlib.Path): - Output directory when processing multiple tiles and - whole-slide images. By default, it is folder `output` - where the running script is invoked. - save_output (bool): - Whether to save output for a single file. default=False + Output directory to save the results. + If save_dir is not provided when patch_mode is False, + then for a single image the output is created in the current directory. + If there are multiple WSIs as input then the user must provide + path to save directory otherwise an OSError will be raised. + overwrite (bool): + Whether to overwrite the results. Default = False. + output_type (str): + The format of the output type. "output_type" can be + "zarr" or "AnnotationStore". Default value is "zarr". + When saving in the zarr format the output is saved using the + `python zarr library `__ + as a zarr group. 
If the required output type is an "AnnotationStore" + then the output will be intermediately saved as zarr but converted + to :class:`AnnotationStore` and saved as a `.db` file + at the end of the loop. + **kwargs (EngineABCRunParams): + Keyword Args to update :class:`EngineABC` attributes during runtime. Returns: (:class:`numpy.ndarray`, dict): Model predictions of the input dataset. If multiple - image tiles or whole-slide images are provided as input, + whole slide images are provided as input, or save_output is True, then results are saved to `save_dir` and a dictionary indicating save location for each input is returned. @@ -850,79 +487,34 @@ def predict( # noqa: PLR0913 Examples: >>> wsis = ['wsi1.svs', 'wsi2.svs'] - >>> predictor = PatchPredictor( - ... pretrained_model="resnet18-kather100k") - >>> output = predictor.predict(wsis, mode="wsi") + >>> class PatchPredictor(EngineABC): + >>> # Define all Abstract methods. + >>> ... + >>> predictor = PatchPredictor(model="resnet18-kather100k") + >>> output = predictor.run(image_patches, patch_mode=True) + >>> output + ... "/path/to/Output.db" + >>> output = predictor.run( + >>> image_patches, + >>> patch_mode=True, + >>> output_type="zarr") + >>> output + ... "/path/to/Output.zarr" + >>> output = predictor.run(wsis, patch_mode=False) >>> output.keys() ... ['wsi1.svs', 'wsi2.svs'] >>> output['wsi1.svs'] - ... {'raw': '0.raw.json', 'merged': '0.merged.npy'} - >>> output['wsi2.svs'] - ... {'raw': '1.raw.json', 'merged': '1.merged.npy'} + ... {'/path/to/wsi1.db'} """ - if mode not in ["patch", "wsi", "tile"]: - msg = f"{mode} is not a valid mode. Use either `patch`, `tile` or `wsi`" - raise ValueError( - msg, - ) - if mode == "patch": - return self._predict_patch( - imgs, - labels, - return_probabilities, - return_labels, - on_gpu, - ) - - if not isinstance(imgs, list): - msg = "Input to `tile` and `wsi` mode must be a list of file paths." - raise TypeError( - msg, - ) - - if mode == "wsi" and masks is not None and len(masks) != len(imgs): - msg = f"len(masks) != len(imgs) : {len(masks)} != {len(imgs)}" - raise ValueError( - msg, - ) - - ioconfig = self._update_ioconfig( - ioconfig, - patch_input_shape, - stride_shape, - resolution, - units, - ) - if mode == "tile": - logger.warning( - "WSIPatchDataset only reads image tile at " - '`units="baseline"`. 
Resolutions will be converted ' - "to baseline value.", - stacklevel=2, - ) - ioconfig = ioconfig.to_baseline() - - fx_list = ioconfig.scale_to_highest( - ioconfig.input_resolutions, - ioconfig.input_resolutions[0]["units"], - ) - fx_list = zip(fx_list, ioconfig.input_resolutions) - fx_list = sorted(fx_list, key=lambda x: x[0]) - highest_input_resolution = fx_list[0][1] - - save_dir = self._prepare_save_dir(save_dir, imgs) - - return self._predict_tile_wsi( - imgs, - masks, - labels, - mode, - return_probabilities, - on_gpu, - ioconfig, - merge_predictions, - save_dir, - save_output, - highest_input_resolution, + return super().run( + images=images, + masks=masks, + labels=labels, + ioconfig=ioconfig, + patch_mode=patch_mode, + save_dir=save_dir, + overwrite=overwrite, + output_type=output_type, + **kwargs, ) diff --git a/tiatoolbox/utils/misc.py b/tiatoolbox/utils/misc.py index 59a14c48c..5a3c8c7fe 100644 --- a/tiatoolbox/utils/misc.py +++ b/tiatoolbox/utils/misc.py @@ -1201,33 +1201,40 @@ def add_from_dat( def patch_predictions_as_annotations( - preds: list, + preds: list | np.ndarray, keys: list, class_dict: dict, - class_probs: list, + class_probs: list | np.ndarray, patch_coords: list, classes_predicted: list, labels: list, ) -> list: """Helper function to generate annotation per patch predictions.""" annotations = [] - for i, pred in enumerate(preds): + for i, probs in enumerate(class_probs): if "probabilities" in keys: - props = { - f"prob_{class_dict[j]}": class_probs[i][j] for j in classes_predicted - } + props = {f"prob_{class_dict[j]}": probs[j] for j in classes_predicted} else: props = {} if "labels" in keys: props["label"] = class_dict[labels[i]] - props["type"] = class_dict[pred] + if len(preds) > 0: + props["type"] = class_dict[preds[i]] annotations.append(Annotation(Polygon.from_bounds(*patch_coords[i]), props)) return annotations +def _get_zarr_array(zarr_array: zarr.core.Array | np.ndarray) -> np.ndarray: + """Converts a zarr array into a numpy array.""" + if isinstance(zarr_array, zarr.core.Array): + return zarr_array[:] + + return zarr_array + + def dict_to_store( - patch_output: dict, + patch_output: dict | zarr.group, scale_factor: tuple[float, float], class_dict: dict | None = None, save_path: Path | None = None, @@ -1235,7 +1242,7 @@ def dict_to_store( """Converts (and optionally saves) output of TIAToolbox engines as AnnotationStore. Args: - patch_output (dict): + patch_output (dict | zarr.Group): A dictionary with "probabilities", "predictions", "coordinates", and "labels" keys. scale_factor (tuple[float, float]): @@ -1260,9 +1267,10 @@ def dict_to_store( # we cant create annotations without coordinates msg = "Patch output must contain coordinates." 
raise ValueError(msg) + # get relevant keys - class_probs = patch_output.get("probabilities", []) - preds = patch_output.get("predictions", []) + class_probs = _get_zarr_array(patch_output.get("probabilities", [])) + preds = _get_zarr_array(patch_output.get("predictions", [])) patch_coords = np.array(patch_output.get("coordinates", [])) if not np.all(np.array(scale_factor) == 1): @@ -1301,7 +1309,7 @@ def dict_to_store( # if a save director is provided, then dump store into a file if save_path: - # ensure parent directory exisits + # ensure parent directory exists save_path.parent.absolute().mkdir(parents=True, exist_ok=True) # ensure proper db extension save_path = save_path.parent.absolute() / (save_path.stem + ".db") @@ -1341,15 +1349,15 @@ def dict_to_zarr( save_path = save_path.parent.absolute() / (save_path.stem + ".zarr") # save to zarr - predictions_array = np.array(raw_predictions["predictions"]) + probabilities_array = np.array(raw_predictions["probabilities"]) z = zarr.open( - save_path, + str(save_path), mode="w", - shape=predictions_array.shape, + shape=probabilities_array.shape, chunks=chunks, compressor=compressor, ) - z[:] = predictions_array + z[:] = probabilities_array return save_path @@ -1463,7 +1471,8 @@ def write_to_zarr_in_cache_mode( Zarr group name consisting of zarr(s) to save the batch output values. output_data_to_save (dict): - Output data from the Engine to save to Zarr. + Output data from the Engine to save to Zarr. Expects the data saved in + dictionary to be a numpy array. **kwargs (dict): Keyword Args to update zarr_group attributes. @@ -1486,6 +1495,8 @@ def write_to_zarr_in_cache_mode( ) zarr_dataset[:] = data_to_save + return zarr_group + # case 2 - append to existing zarr group for key in output_data_to_save: zarr_group[key].append(output_data_to_save[key])
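To close the loop on the `dict_to_store` changes above, here is a minimal conversion example mirroring the updated unit test in `tests/test_utils.py`; the probabilities, predictions and coordinates are dummy values.

```python
import numpy as np

from tiatoolbox.utils import misc

# Raw engine-style output: per-patch probabilities, argmax predictions and
# patch bounds (x_min, y_min, x_max, y_max).
patch_output = {
    "probabilities": np.array([(0.99, 0.01), (0.01, 0.99), (0.99, 0.01)]),
    "predictions": np.array([1, 0, 1]),
    "coordinates": np.array([(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)]),
}
class_dict = {0: "class0", 1: "class1"}

# scale_factor of (1.0, 1.0) keeps the coordinates unchanged; passing a
# save_path would additionally dump the store to a .db file.
store = misc.dict_to_store(patch_output, (1.0, 1.0), class_dict=class_dict)
print(len(store))  # one annotation per patch -> 3
```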