Skip to content

Commit

Permalink
πŸ§‘β€πŸ’» Define PatchPredictor (#783)
Browse files Browse the repository at this point in the history
- Redesigns PatchPredictor engine using the new EngineABC base class.
- The WSIs are now processed using the same code as for the processing the patches using WSI based dataloader.
- The intermediate output is saved as zarr for the WSIs to resolve memory issues.
- The output of model architectures should now be a dictionary.
- The output can be specified as AnnotationStore for visualisation using TIAViz.

---------

Co-authored-by: abishekrajvg <abishekraj6797@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 20, 2024
1 parent 05b1b1d commit 81e575d
Show file tree
Hide file tree
Showing 13 changed files with 983 additions and 1,748 deletions.
107 changes: 78 additions & 29 deletions tests/engines/test_engine_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import copy
import logging
import shutil
from pathlib import Path
from typing import TYPE_CHECKING, NoReturn
Expand All @@ -11,15 +12,21 @@
import pytest
import torchvision.models as torch_models
import zarr
from typing_extensions import Unpack

from tiatoolbox.models.architecture import (
fetch_pretrained_weights,
get_pretrained_model,
)
from tiatoolbox.models.architecture.vanilla import CNNModel
from tiatoolbox.models.dataset import PatchDataset, WSIPatchDataset
from tiatoolbox.models.engine.engine_abc import EngineABC, prepare_engines_save_dir
from tiatoolbox.models.engine.engine_abc import (
EngineABC,
EngineABCRunParams,
prepare_engines_save_dir,
)
from tiatoolbox.models.engine.io_config import ModelIOConfigABC
from tiatoolbox.utils.misc import write_to_zarr_in_cache_mode

if TYPE_CHECKING:
import torch.nn
Expand Down Expand Up @@ -57,31 +64,38 @@ def get_dataloader(

def save_wsi_output(
self: EngineABC,
raw_output: dict,
processed_output: dict,
save_dir: Path,
**kwargs: dict,
) -> Path:
"""Test post_process_wsi."""
return super().save_wsi_output(
raw_output,
processed_output,
save_dir=save_dir,
**kwargs,
)

def post_process_wsi(
self: EngineABC,
raw_predictions: dict | Path,
**kwargs: Unpack[EngineABCRunParams],
) -> dict | Path:
"""Post process WSI output."""
return super().post_process_wsi(
raw_predictions=raw_predictions,
**kwargs,
)

def infer_wsi(
self: EngineABC,
dataloader: torch.utils.data.DataLoader,
img_label: str,
highest_input_resolution: list[dict],
save_dir: Path,
save_path: Path,
**kwargs: dict,
) -> dict | np.ndarray:
"""Test infer_wsi."""
return super().infer_wsi(
dataloader,
img_label,
highest_input_resolution,
save_dir,
save_path,
**kwargs,
)

Expand Down Expand Up @@ -115,13 +129,34 @@ def test_incorrect_ioconfig() -> NoReturn:
"""Test EngineABC initialization with incorrect ioconfig."""
model = torch_models.resnet18()
engine = TestEngineABC(model=model)

with pytest.raises(
ValueError,
match=r".*provide a valid ModelIOConfigABC.*",
match=r".*Must provide.*`ioconfig`.*",
):
engine.run(images=[], masks=[], ioconfig=None)


def test_incorrect_output_type() -> NoReturn:
"""Test EngineABC for incorrect output type."""
pretrained_model = "alexnet-kather100k"

# Test engine run without ioconfig
eng = TestEngineABC(model=pretrained_model)

with pytest.raises(
TypeError,
match=r".*output_type must be 'dict' or 'zarr' or 'annotationstore*",
):
_ = eng.run(
images=np.zeros((10, 224, 224, 3), dtype=np.uint8),
on_gpu=False,
patch_mode=True,
ioconfig=None,
output_type="random",
)


def test_pretrained_ioconfig() -> NoReturn:
"""Test EngineABC initialization with pretrained model name in the toolbox."""
pretrained_model = "alexnet-kather100k"
Expand All @@ -134,7 +169,7 @@ def test_pretrained_ioconfig() -> NoReturn:
patch_mode=True,
ioconfig=None,
)
assert "predictions" in out
assert "probabilities" in out
assert "labels" not in out


Expand All @@ -153,7 +188,7 @@ def test_ioconfig() -> NoReturn:
ioconfig=ioconfig,
)

assert "predictions" in out
assert "probabilities" in out
assert "labels" not in out


Expand Down Expand Up @@ -260,7 +295,7 @@ def test_engine_initalization() -> NoReturn:
assert isinstance(eng, EngineABC)


def test_engine_run(tmp_path: Path) -> NoReturn:
def test_engine_run(tmp_path: Path, sample_svs: Path) -> NoReturn:
"""Test engine run."""
eng = TestEngineABC(model="alexnet-kather100k")
assert isinstance(eng, EngineABC)
Expand Down Expand Up @@ -316,7 +351,7 @@ def test_engine_run(tmp_path: Path) -> NoReturn:
on_gpu=False,
patch_mode=True,
)
assert "predictions" in out
assert "probabilities" in out
assert "labels" not in out

eng = TestEngineABC(model="alexnet-kather100k")
Expand All @@ -325,7 +360,7 @@ def test_engine_run(tmp_path: Path) -> NoReturn:
on_gpu=False,
verbose=False,
)
assert "predictions" in out
assert "probabilities" in out
assert "labels" not in out

eng = TestEngineABC(model="alexnet-kather100k")
Expand All @@ -334,14 +369,14 @@ def test_engine_run(tmp_path: Path) -> NoReturn:
labels=list(range(10)),
on_gpu=False,
)
assert "predictions" in out
assert "probabilities" in out
assert "labels" in out

eng = TestEngineABC(model="alexnet-kather100k")

with pytest.raises(NotImplementedError):
eng.run(
images=np.zeros(shape=(10, 224, 224, 3)),
images=[sample_svs],
save_dir=tmp_path / "output",
patch_mode=False,
)
Expand All @@ -358,7 +393,7 @@ def test_engine_run_with_verbose() -> NoReturn:
on_gpu=False,
)

assert "predictions" in out
assert "probabilities" in out
assert "labels" in out


Expand Down Expand Up @@ -513,21 +548,28 @@ def test_eng_save_output(tmp_path: pytest.TempPathFactory) -> None:
save_path = tmp_path / "output.zarr"
_ = zarr.open(save_path, mode="w")
out = eng.save_wsi_output(
raw_output=save_path, save_path=save_path, output_type="zarr", save_dir=tmp_path
processed_output=save_path,
save_path=save_path,
output_type="zarr",
save_dir=tmp_path,
)

assert out.exists()
assert out.suffix == ".zarr"

# Test AnnotationStore
patch_output = {
"predictions": [1, 0, 1],
"coordinates": [(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)],
"other": "other",
"predictions": np.array([1, 0, 1]),
"coordinates": np.array([(0, 0, 1, 1), (1, 1, 2, 2), (2, 2, 3, 3)]),
}
class_dict = {0: "class0", 1: "class1"}
save_path = tmp_path / "output_db.zarr"
zarr_group = zarr.open(save_path, mode="w")
_ = write_to_zarr_in_cache_mode(
zarr_group=zarr_group, output_data_to_save=patch_output
)
out = eng.save_wsi_output(
raw_output=patch_output,
processed_output=save_path,
scale_factor=(1.0, 1.0),
class_dict=class_dict,
save_dir=tmp_path,
Expand All @@ -542,28 +584,35 @@ def test_eng_save_output(tmp_path: pytest.TempPathFactory) -> None:
match=r".*supports zarr and AnnotationStore as output_type.",
):
eng.save_wsi_output(
raw_output=save_path,
processed_output=save_path,
save_path=save_path,
output_type="dict",
save_dir=tmp_path,
)


def test_io_config_delegation(tmp_path: Path) -> None:
def test_io_config_delegation(tmp_path: Path, caplog: pytest.LogCaptureFixture) -> None:
"""Test for delegating args to io config."""
# test not providing config / full input info for not pretrained models
model = CNNModel("resnet50")
eng = TestEngineABC(model=model)
with pytest.raises(ValueError, match=r".*Please provide a valid ModelIOConfigABC*"):
eng.run(
np.zeros((10, 224, 224, 3)), patch_mode=True, save_dir=tmp_path / "dump"
)

kwargs = {
"patch_input_shape": [512, 512],
"resolution": 1.75,
"units": "mpp",
}
with caplog.at_level(logging.WARNING):
eng.run(
np.zeros((10, 224, 224, 3)),
patch_mode=True,
save_dir=tmp_path / "dump",
patch_input_shape=kwargs["patch_input_shape"],
resolution=kwargs["resolution"],
units=kwargs["units"],
)
assert "provide a valid ModelIOConfigABC" in caplog.text
shutil.rmtree(tmp_path / "dump", ignore_errors=True)

# test providing config / full input info for non pretrained models
ioconfig = ModelIOConfigABC(
Expand Down
Loading

0 comments on commit 81e575d

Please sign in to comment.