Skip to content

Commit

Permalink
✅ Add test for objective power.
Browse files Browse the repository at this point in the history
Signed-off-by: Shan E Ahmed Raza <13048456+shaneahmed@users.noreply.github.com>
  • Loading branch information
shaneahmed committed Aug 28, 2024
1 parent 94c2969 commit e4e6f22
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 3 deletions.
40 changes: 40 additions & 0 deletions tests/engines/test_patch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,46 @@ def test_engine_run_wsi_annotation_store(
shutil.rmtree(save_dir)


def test_engine_run_wsi_annotation_store_power(
sample_wsi_dict: dict,
tmp_path: Path,
) -> None:
"""Test the engine run for Whole slide images."""
# convert to pathlib Path to prevent wsireader complaint
mini_wsi_svs = Path(sample_wsi_dict["wsi2_4k_4k_svs"])
mini_wsi_msk = Path(sample_wsi_dict["wsi2_4k_4k_msk"])

eng = PatchPredictor(model="alexnet-kather100k")

patch_size = np.array([224, 224])
save_dir = f"{tmp_path}/model_wsi_output"

kwargs = {
"patch_input_shape": patch_size,
"stride_shape": patch_size,
"resolution": 20,
"save_dir": save_dir,
"units": "power",
}

output = eng.run(
images=[mini_wsi_svs],
masks=[mini_wsi_msk],
patch_mode=False,
output_type="AnnotationStore",
**kwargs,
)

output_ = output[mini_wsi_svs]

assert output_.exists()
assert output_.suffix == ".db"
predictions = _extract_probabilities_from_annotation_store(output_)
assert _validate_probabilities(predictions)

shutil.rmtree(save_dir)


# -------------------------------------------------------------------------------------
# Command Line Interface
# -------------------------------------------------------------------------------------
Expand Down
6 changes: 3 additions & 3 deletions tiatoolbox/models/engine/engine_abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,7 +1053,7 @@ def _calculate_scale_factor(dataloader: DataLoader) -> float | tuple[float, floa
# in this case dataloader resolution / slide resolution will be
# equal to dataloader resolution.

if dataloader_units in ["mpp", "level", "objective_power"]:
if dataloader_units in ["mpp", "level", "power"]:
wsimeta_dict = dataloader.dataset.reader.info.as_dict()

if dataloader_units == "mpp":
Expand All @@ -1065,8 +1065,8 @@ def _calculate_scale_factor(dataloader: DataLoader) -> float | tuple[float, floa
downsample_ratio = wsimeta_dict["level_downsamples"][dataloader_resolution]
return 1.0 / downsample_ratio, 1.0 / downsample_ratio

if dataloader_resolution == "objective_power":
slide_objective_power = wsimeta_dict["power"]
if dataloader_units == "power":
slide_objective_power = wsimeta_dict["objective_power"]
return (
dataloader_resolution / slide_objective_power,
dataloader_resolution / slide_objective_power,
Expand Down

0 comments on commit e4e6f22

Please sign in to comment.