Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add PatchClassifier Engine #865

Open
wants to merge 8 commits into
base: dev-define-engines-abc
Choose a base branch
from
108 changes: 108 additions & 0 deletions tests/engines/test_patch_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Test for Patch Classifier."""

from __future__ import annotations

import shutil
from pathlib import Path

import numpy as np
import zarr

from tiatoolbox.models.engine.patch_classifier import PatchClassifier
from tiatoolbox.utils import env_detection as toolbox_env

device = "cuda" if toolbox_env.has_gpu() else "cpu"


def _test_classifier_output(
    inputs: list,
    model: str,
    probabilities_check: list | None = None,
    classification_check: list | None = None,
    tmp_path: Path | None = None,
) -> None:
    """Check the classifier output of one pretrained model against references.

    Args:
        inputs (list):
            Inputs forwarded to :meth:`PatchClassifier.run`.
        model (str):
            Name of the pretrained model handed to :class:`PatchClassifier`.
        probabilities_check (list | None):
            Expected maximum probability for each input. Indexed for every
            input, so it has to be supplied despite the ``None`` default.
        classification_check (list | None):
            Expected entry of ``output["predictions"]`` for each input.
            Indexed for every input as well.
        tmp_path (Path | None):
            If given, the engine runs with ``cache_mode=True`` and saves to
            ``tmp_path / "output"``; the returned location is opened with
            zarr and the directory is deleted once the checks have passed.

    """
    use_cache = tmp_path is not None
    cache_mode = True if use_cache else None
    save_dir = tmp_path / "output" if use_cache else None

    engine = PatchClassifier(model=model, batch_size=32, verbose=False)

    # `device` is the module-level choice: "cuda" when a GPU is detected,
    # otherwise "cpu".
    result = engine.run(
        inputs,
        return_labels=False,
        device=device,
        cache_mode=cache_mode,
        save_dir=save_dir,
    )

    # With a save directory the engine output is opened as a zarr store.
    if use_cache:
        result = zarr.open(result, mode="r")

    probabilities = result["probabilities"]
    classification = result["predictions"]

    for idx, class_probs in enumerate(probabilities):
        top_prob = max(class_probs)
        # The highest class probability must match the reference within 1e-3.
        assert np.abs(top_prob - probabilities_check[idx]) <= 1e-3, (
            model,
            top_prob,
            probabilities_check[idx],
            class_probs,
            classification_check[idx],
        )
        # The predicted class must equal the reference label exactly.
        assert classification[idx] == classification_check[idx], (
            model,
            top_prob,
            probabilities_check[idx],
            class_probs,
            classification_check[idx],
        )

    # Remove the saved output so the next call can reuse the same directory.
    if save_dir:
        shutil.rmtree(save_dir)


def test_patch_predictor_kather100k_output(
    sample_patch1: Path,
    sample_patch2: Path,
    tmp_path: Path,
) -> None:
    """Check every Kather100K classifier against its reference output.

    Each pretrained model is evaluated on the two sample patches, first
    without a save directory and then again in cache mode using `tmp_path`.
    """
    inputs = [Path(sample_patch1), Path(sample_patch2)]

    # Model name -> expected maximum probability for (patch1, patch2).
    pretrained_info = {
        "alexnet-kather100k": [1.0, 0.9999735355377197],
        "resnet18-kather100k": [1.0, 0.9999911785125732],
        "resnet34-kather100k": [1.0, 0.9979840517044067],
        "resnet50-kather100k": [1.0, 0.9999986886978149],
        "resnet101-kather100k": [1.0, 0.9999932050704956],
        "resnext50_32x4d-kather100k": [1.0, 0.9910059571266174],
        "resnext101_32x8d-kather100k": [1.0, 0.9999971389770508],
        "wide_resnet50_2-kather100k": [1.0, 0.9953408241271973],
        "wide_resnet101_2-kather100k": [1.0, 0.9999831914901733],
        "densenet121-kather100k": [1.0, 1.0],
        "densenet161-kather100k": [1.0, 0.9999959468841553],
        "densenet169-kather100k": [1.0, 0.9999934434890747],
        "densenet201-kather100k": [1.0, 0.9999983310699463],
        "mobilenet_v2-kather100k": [0.9999998807907104, 0.9999126195907593],
        "mobilenet_v3_large-kather100k": [0.9999996423721313, 0.9999878406524658],
        "mobilenet_v3_small-kather100k": [0.9999998807907104, 0.9999997615814209],
        "googlenet-kather100k": [1.0, 0.9999639987945557],
    }

    # First pass: no save directory (tmp_path=None). Second pass: cache mode.
    for cache_dir in (None, tmp_path):
        for model, expected_prob in pretrained_info.items():
            _test_classifier_output(
                inputs,
                model,
                probabilities_check=expected_prob,
                classification_check=[6, 3],
                tmp_path=cache_dir,
            )
138 changes: 15 additions & 123 deletions tests/engines/test_patch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,79 +239,19 @@ def test_wsi_predictor_api(
shutil.rmtree(_kwargs["save_dir"], ignore_errors=True)


def _test_predictor_output(
    inputs: list,
    model: str,
    probabilities_check: list | None = None,
    predictions_check: list | None = None,
) -> None:
    """Check the predictor output of one pretrained model against references.

    Args:
        inputs (list):
            Inputs forwarded to :meth:`PatchPredictor.run`.
        model (str):
            Name of the pretrained model handed to :class:`PatchPredictor`.
        probabilities_check (list | None):
            Expected maximum probability for each input. Indexed for every
            input, so it has to be supplied despite the ``None`` default.
        predictions_check (list | None):
            Expected arg-max class index for each input. Indexed for every
            input as well.

    """
    engine = PatchPredictor(model=model, batch_size=32, verbose=False)

    # NOTE(review): the previous comment said "don't run test on GPU", but the
    # device is whatever the module-level `device` holds - confirm.
    result = engine.run(
        inputs,
        return_probabilities=True,
        return_labels=False,
        device=device,
    )

    for idx, class_probs in enumerate(result["probabilities"]):
        top_prob = max(class_probs)
        # The highest class probability must match the reference within 1e-3.
        assert np.abs(top_prob - probabilities_check[idx]) <= 1e-3, (
            model,
            top_prob,
            probabilities_check[idx],
            class_probs,
            predictions_check[idx],
        )
        # The most probable class must equal the reference label.
        assert np.argmax(class_probs) == predictions_check[idx], (
            model,
            top_prob,
            probabilities_check[idx],
            class_probs,
            predictions_check[idx],
        )


def test_patch_predictor_kather100k_output(
    sample_patch1: Path,
    sample_patch2: Path,
) -> None:
    """Check Kather100K patch prediction models against reference output.

    Models are evaluated in dictionary order on the two sample patches. When
    `toolbox_env.running_on_ci()` is true the loop stops after the first model.
    """
    inputs = [Path(sample_patch1), Path(sample_patch2)]

    # Model name -> expected maximum probability for (patch1, patch2).
    pretrained_info = {
        "alexnet-kather100k": [1.0, 0.9999735355377197],
        "resnet18-kather100k": [1.0, 0.9999911785125732],
        "resnet34-kather100k": [1.0, 0.9979840517044067],
        "resnet50-kather100k": [1.0, 0.9999986886978149],
        "resnet101-kather100k": [1.0, 0.9999932050704956],
        "resnext50_32x4d-kather100k": [1.0, 0.9910059571266174],
        "resnext101_32x8d-kather100k": [1.0, 0.9999971389770508],
        "wide_resnet50_2-kather100k": [1.0, 0.9953408241271973],
        "wide_resnet101_2-kather100k": [1.0, 0.9999831914901733],
        "densenet121-kather100k": [1.0, 1.0],
        "densenet161-kather100k": [1.0, 0.9999959468841553],
        "densenet169-kather100k": [1.0, 0.9999934434890747],
        "densenet201-kather100k": [1.0, 0.9999983310699463],
        "mobilenet_v2-kather100k": [0.9999998807907104, 0.9999126195907593],
        "mobilenet_v3_large-kather100k": [0.9999996423721313, 0.9999878406524658],
        "mobilenet_v3_small-kather100k": [0.9999998807907104, 0.9999997615814209],
        "googlenet-kather100k": [1.0, 0.9999639987945557],
    }

    for model_name, reference_probs in pretrained_info.items():
        _test_predictor_output(
            inputs,
            model_name,
            probabilities_check=reference_probs,
            predictions_check=[6, 3],
        )
        # Keep CI runtime short: one model is enough there.
        if toolbox_env.running_on_ci():
            break
def _extract_probabilities_from_annotation_store(dbfile: str) -> dict:
    """Helper function to extract probabilities from Annotation Store.

    Reads the ``properties`` column of the ``annotations`` table, decodes each
    value as JSON and removes its ``"prob_0"`` key.

    Args:
        dbfile (str):
            Path to the SQLite database file to read.

    Returns:
        dict:
            Decoded properties of the last row read, without ``"prob_0"``.
            Empty if the table has no rows.

    Raises:
        KeyError:
            If a decoded properties dictionary has no ``"prob_0"`` key.

    """
    probs_dict = {}
    con = sqlite3.connect(dbfile)
    try:
        # Fetch everything up front so the connection can be released before
        # any JSON decoding happens.
        rows = con.execute("SELECT properties FROM annotations").fetchall()
    finally:
        # Always close the handle; leaving it open keeps the database file
        # locked on some platforms and leaks a file descriptor.
        con.close()

    # Each row is a 1-tuple holding the JSON-encoded properties. Later rows
    # overwrite earlier ones, so only the last row's properties are returned.
    for (json_str,) in rows:
        probs_dict = json.loads(json_str)
        probs_dict.pop("prob_0")

    return probs_dict


def _validate_probabilities(predictions: list | dict) -> bool:
Expand All @@ -330,13 +270,13 @@ def test_wsi_predictor_zarr(sample_wsi_dict: dict, tmp_path: Path) -> None:
"""Test normal run of patch predictor for WSIs."""
mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"])

predictor = PatchPredictor(
classifier = PatchPredictor(
model="alexnet-kather100k",
batch_size=32,
verbose=False,
)
# don't run test on GPU
output = predictor.run(
output = classifier.run(
images=[mini_wsi_svs],
return_probabilities=True,
return_labels=False,
Expand All @@ -357,54 +297,6 @@ def test_wsi_predictor_zarr(sample_wsi_dict: dict, tmp_path: Path) -> None:
assert _validate_probabilities(predictions=output_["probabilities"])


def test_wsi_predictor_zarr_baseline(sample_wsi_dict: dict, tmp_path: Path) -> None:
    """Run the patch predictor on a WSI at baseline resolution 1.0.

    Checks that an output exists for the slide and that the saved zarr arrays
    have the expected shapes and valid probabilities.
    """
    mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"])
    engine = PatchPredictor(model="alexnet-kather100k", batch_size=32, verbose=False)

    # NOTE(review): the previous comment said "don't run test on GPU", but the
    # device is whatever the module-level `device` holds - confirm.
    output = engine.run(
        images=[mini_wsi_svs],
        return_probabilities=True,
        return_labels=False,
        device=device,
        patch_mode=False,
        save_dir=tmp_path / "wsi_out_check",
        units="baseline",
        resolution=1.0,
    )

    saved_path = output[mini_wsi_svs]
    assert saved_path.exists()

    output_ = zarr.open(saved_path)

    # number of patches x classes
    assert output_["probabilities"].shape == (244, 9)
    assert output_["probabilities"].ndim == 2
    # number of patches x [start_x, start_y, end_x, end_y]
    assert output_["coordinates"].shape == (244, 4)
    assert output_["coordinates"].ndim == 2
    assert _validate_probabilities(predictions=output_["probabilities"])


def _extract_probabilities_from_annotation_store(dbfile: str) -> dict:
    """Helper function to extract probabilities from Annotation Store.

    Reads the ``properties`` column of the ``annotations`` table, decodes each
    value as JSON and removes its ``"prob_0"`` key.

    Args:
        dbfile (str):
            Path to the SQLite database file to read.

    Returns:
        dict:
            Decoded properties of the last row read, without ``"prob_0"``.
            Empty if the table has no rows.

    Raises:
        KeyError:
            If a decoded properties dictionary has no ``"prob_0"`` key.

    """
    probs_dict = {}
    con = sqlite3.connect(dbfile)
    try:
        # Fetch everything up front so the connection can be released before
        # any JSON decoding happens.
        rows = con.execute("SELECT properties FROM annotations").fetchall()
    finally:
        # Always close the handle; leaving it open keeps the database file
        # locked on some platforms and leaks a file descriptor.
        con.close()

    # Each row is a 1-tuple holding the JSON-encoded properties. Later rows
    # overwrite earlier ones, so only the last row's properties are returned.
    for (json_str,) in rows:
        probs_dict = json.loads(json_str)
        probs_dict.pop("prob_0")

    return probs_dict


def test_engine_run_wsi_annotation_store(
sample_wsi_dict: dict,
tmp_path: Path,
Expand Down
14 changes: 0 additions & 14 deletions tiatoolbox/models/engine/engine_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,6 @@ class EngineABCRunParams(TypedDict, total=False):
Input IO configuration (:class:`ModelIOConfigABC`) to run the Engine.
return_labels (bool):
Whether to return the labels with the predictions.
merge_predictions (bool):
Whether to merge the predictions to form a 2-dimensional
map into a single file from a WSI.
This is only applicable if `patch_mode` is False in inference.
num_loader_workers (int):
Number of workers used in :class:`torch.utils.data.DataLoader`.
num_post_proc_workers (int):
Expand Down Expand Up @@ -165,7 +161,6 @@ class EngineABCRunParams(TypedDict, total=False):
class_dict: dict
device: str
ioconfig: ModelIOConfigABC
merge_predictions: bool
num_loader_workers: int
num_post_proc_workers: int
output_file: str
Expand Down Expand Up @@ -248,10 +243,6 @@ class EngineABC(ABC):
Runtime ioconfig.
return_labels (bool):
Whether to return the labels with the predictions.
merge_predictions (bool):
Whether to merge the predictions to form a 2-dimensional
map. This is only applicable if `patch_mode` is False in inference.
Default is False.
resolution (Resolution):
Resolution used for reading the image. Please see
:obj:`WSIReader` for details.
Expand Down Expand Up @@ -293,8 +284,6 @@ class EngineABC(ABC):
Number of workers to postprocess the results of the model.
return_labels (bool):
Whether to return the output labels. Default value is False.
merge_predictions (bool):
Whether to merge WSI predictions into a single file. Default value is False.
resolution (Resolution):
Resolution used for reading the image. Please see
:class:`WSIReader` for details.
Expand Down Expand Up @@ -374,7 +363,6 @@ def __init__(
self.cache_mode: bool = False
self.cache_size: int = self.batch_size if self.batch_size else 10000
self.labels: list | None = None
self.merge_predictions: bool = False
self.num_loader_workers = num_loader_workers
self.num_post_proc_workers = num_post_proc_workers
self.patch_input_shape: IntPair | None = None
Expand Down Expand Up @@ -1194,8 +1182,6 @@ def run(
- img_path: path of the input image.
- raw: path to save location for raw prediction,
saved in .json.
- merged: path to the .npy file containing merged
predictions if `merge_predictions` is `True`.

Examples:
>>> wsis = ['wsi1.svs', 'wsi2.svs']
Expand Down
Loading
Loading