diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py
index 6c04c6b68..5fd930138 100644
--- a/tests/models/test_patch_predictor.py
+++ b/tests/models/test_patch_predictor.py
@@ -1239,7 +1239,7 @@ def test_patch_predictor_torch_compile(
     sample_patch2: Path,
     tmp_path: Path,
 ) -> None:
-    """Test torch.compile functionality.
+    """Test PatchPredictor with torch.compile functionality.
 
     Args:
         sample_patch1 (Path): Path to sample patch 1.
@@ -1247,10 +1247,8 @@ def test_patch_predictor_torch_compile(
         tmp_path (Path): Path to temporary directory.
 
     """
-    torch_compile_enabled = rcParam["enable_torch_compile"]
+    torch_compile_mode = rcParam["torch_compile_mode"]
     torch._dynamo.reset()
-    rcParam["enable_torch_compile"] = True
-    # Test torch.compile with default mode
     rcParam["torch_compile_mode"] = "default"
     _, compile_time = timed(
         test_patch_predictor_api,
@@ -1278,4 +1276,4 @@ def test_patch_predictor_torch_compile(
     )
     logger.info("torch.compile max-autotune mode: %s", compile_time)
     torch._dynamo.reset()
-    rcParam["enable_torch_compile"] = torch_compile_enabled
+    rcParam["torch_compile_mode"] = torch_compile_mode
diff --git a/tests/test_utils.py b/tests/test_utils.py
index cf76028aa..d35ae31e5 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -19,7 +19,7 @@
 from shapely.geometry import Polygon
 
 from tests.test_annotation_stores import cell_polygon
-from tiatoolbox import utils
+from tiatoolbox import rcParam, utils
 from tiatoolbox.annotation.storage import SQLiteStore
 from tiatoolbox.models.architecture import fetch_pretrained_weights
 from tiatoolbox.models.architecture.utils import compile_model
@@ -1825,15 +1825,35 @@ def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None:
 
 def test_torch_compile_already_compiled() -> None:
     """Test that torch_compile does not recompile a model that is already compiled."""
-    # Create a simple model
+    torch_compile_modes = [
+        "default",
+        "reduce-overhead",
+        "max-autotune",
+        "max-autotune-no-cudagraphs",
+    ]
+    current_torch_compile_mode = rcParam["torch_compile_mode"]
+
+    model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))
+
+    for mode in torch_compile_modes:
+        torch._dynamo.reset()
+        rcParam["torch_compile_mode"] = mode
+        compiled_model = compile_model(model, mode=mode)
+        recompiled_model = compile_model(compiled_model, mode=mode)
+        assert compiled_model == recompiled_model
+
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = current_torch_compile_mode
+
+
+def test_torch_compile_disable() -> None:
+    """Test torch_compile's disable mode."""
     model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))
+    compiled_model = compile_model(model, mode="disable")
+    assert model == compiled_model
 
-    # Compile the model
-    compiled_model = compile_model(model)
 
-    # Compile the model again
-    recompiled_model = compile_model(compiled_model)
+def test_torch_compile_compatibility() -> None:
+    """Test if torch-compile compatibility is checked correctly."""
+    from tiatoolbox.models.architecture.utils import is_torch_compile_compatible
 
-    # Check that the recompiled model
-    # is the same as the original compiled model
-    assert compiled_model == recompiled_model
+    assert isinstance(is_torch_compile_compatible(), bool)
diff --git a/tests/test_wsi_registration.py b/tests/test_wsi_registration.py
index 79abd3855..4e0f07366 100644
--- a/tests/test_wsi_registration.py
+++ b/tests/test_wsi_registration.py
@@ -5,7 +5,10 @@
 import cv2
 import numpy as np
 import pytest
+import torch
 
+from tests.conftest import timed
+from tiatoolbox import logger, rcParam
 from tiatoolbox.tools.registration.wsi_registration import (
     AffineWSITransformer,
     DFBRegister,
@@ -576,3 +579,70 @@ def test_affine_wsi_transformer(sample_ome_tiff: Path) -> None:
         expected = cv2.rotate(expected, cv2.ROTATE_90_CLOCKWISE)
 
     assert np.sum(expected - output) == 0
+
+
+def test_dfbr_feature_extractor_torch_compile(dfbr_features: Path) -> None:
+    """Test DFBRFeatureExtractor with torch.compile functionality.
+
+    Args:
+        dfbr_features (Path): Path to the expected features.
+
+    """
+
+    def _extract_features() -> tuple:
+        dfbr = DFBRegister()
+        fixed_img = np.repeat(
+            np.expand_dims(
+                np.repeat(
+                    np.expand_dims(np.arange(0, 64, 1, dtype=np.uint8), axis=1),
+                    64,
+                    axis=1,
+                ),
+                axis=2,
+            ),
+            3,
+            axis=2,
+        )
+        output = dfbr.extract_features(fixed_img, fixed_img)
+        pool3_feat = output["block3_pool"][0, :].detach().numpy()
+        pool4_feat = output["block4_pool"][0, :].detach().numpy()
+        pool5_feat = output["block5_pool"][0, :].detach().numpy()
+
+        return pool3_feat, pool4_feat, pool5_feat
+
+    torch_compile_mode = rcParam["torch_compile_mode"]
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = "default"
+    (pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features)
+    _pool3_feat, _pool4_feat, _pool5_feat = np.load(
+        str(dfbr_features),
+        allow_pickle=True,
+    )
+    assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4
+    assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4
+    assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4
+    logger.info("torch.compile default mode: %s", compile_time)
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = "reduce-overhead"
+    (pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features)
+    _pool3_feat, _pool4_feat, _pool5_feat = np.load(
+        str(dfbr_features),
+        allow_pickle=True,
+    )
+    assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4
+    assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4
+    assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4
+    logger.info("torch.compile reduce-overhead mode: %s", compile_time)
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = "max-autotune"
+    (pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features)
+    _pool3_feat, _pool4_feat, _pool5_feat = np.load(
+        str(dfbr_features),
+        allow_pickle=True,
+    )
+    assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4
+    assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4
+    assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4
+    logger.info("torch.compile max-autotune mode: %s", compile_time)
+    torch._dynamo.reset()
+    rcParam["torch_compile_mode"] = torch_compile_mode
diff --git a/tiatoolbox/__init__.py b/tiatoolbox/__init__.py
index d82f785b5..80ad1b4ff 100644
--- a/tiatoolbox/__init__.py
+++ b/tiatoolbox/__init__.py
@@ -73,7 +73,6 @@ class _RcParam(TypedDict):
 
     TIATOOLBOX_HOME: Path
    pretrained_model_info: dict[str, dict]
-    enable_torch_compile: bool
     torch_compile_mode: str
 
 
@@ -104,11 +103,9 @@ def read_registry_files(path_to_registry: str | Path) -> dict:
     "pretrained_model_info": read_registry_files(
         "data/pretrained_model.yaml",
     ),  # Load a dictionary of sample files data (names and urls)
-    "enable_torch_compile": False,
-    # Disable ``torch-compile`` by default
     "torch_compile_mode": "default",
-    # Set ``torch-compile`` mode to ``default`` by default
-    # Options: “default”, “reduce-overhead”, “max-autotune”
+    # Set `torch-compile` mode to `default`
+    # Options: `disable`, `default`, `reduce-overhead`, `max-autotune`
     # or “max-autotune-no-cudagraphs”
 }
diff --git a/tiatoolbox/models/architecture/utils.py b/tiatoolbox/models/architecture/utils.py
index 94f970df8..2150c31a1 100644
--- a/tiatoolbox/models/architecture/utils.py
+++ b/tiatoolbox/models/architecture/utils.py
@@ -12,11 +12,41 @@
 from tiatoolbox import logger
 
 
+def is_torch_compile_compatible() -> bool:
+    """Check if the current GPU is compatible with torch-compile.
+
+    Returns:
+        bool:
+            True if the GPU is compatible with torch-compile, False
+            otherwise.
+
+    """
+    if torch.cuda.is_available():  # pragma: no cover
+        device_cap = torch.cuda.get_device_capability()
+        if device_cap not in ((7, 0), (8, 0), (9, 0)):
+            logger.warning(
+                "GPU is not compatible with torch.compile. "
+                "Compatible GPUs include NVIDIA V100, A100, and H100. "
+                "Speedup numbers may be lower than expected.",
+                stacklevel=2,
+            )
+            return False
+    else:
+        logger.warning(
+            "No GPU detected or CUDA not installed, "
+            "torch.compile is only supported on selected NVIDIA GPUs. "
+            "Speedup numbers may be lower than expected.",
+            stacklevel=2,
+        )
+        return False
+
+    return True  # pragma: no cover
+
+
 def compile_model(
     model: nn.Module | None = None,
     *,
     mode: str = "default",
-    disable: bool = False,
 ) -> Callable:
     """A decorator to compile a model using torch-compile.
 
@@ -25,19 +55,20 @@ def compile_model(
             Model to be compiled.
         mode (str):
             Mode to be used for torch-compile. Available modes are
-            `default`, `reduce-overhead`, `max-autotune`, and
+            `disable`, `default`, `reduce-overhead`, `max-autotune`, and
             `max-autotune-no-cudagraphs`.
-        disable (bool):
-            If True, torch-compile will be disabled.
 
     Returns:
         Callable:
             Compiled model.
 
     """
-    if disable:
+    if mode == "disable":
         return model
 
+    # Check if GPU is compatible with torch.compile
+    is_torch_compile_compatible()
+
     # This check will be removed when torch.compile is supported in Python 3.12+
     if sys.version_info >= (3, 12):  # pragma: no cover
         logger.warning(
@@ -54,7 +85,7 @@ def compile_model(
         )
         return model
 
-    return torch.compile(model, mode=mode, disable=disable)
+    return torch.compile(model, mode=mode)
 
 
 def centre_crop(
diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py
index 2aede1393..da4420cb0 100644
--- a/tiatoolbox/models/engine/patch_predictor.py
+++ b/tiatoolbox/models/engine/patch_predictor.py
@@ -255,7 +255,6 @@ def __init__(
             compile_model(  # for runtime, such as after wrapping with nn.DataParallel
                 model,
                 mode=rcParam["torch_compile_mode"],
-                disable=not rcParam["enable_torch_compile"],
             )
         )
         self.pretrained_model = pretrained_model
diff --git a/tiatoolbox/tools/registration/wsi_registration.py b/tiatoolbox/tools/registration/wsi_registration.py
index 74d6a8f16..74e35ffa5 100644
--- a/tiatoolbox/tools/registration/wsi_registration.py
+++ b/tiatoolbox/tools/registration/wsi_registration.py
@@ -16,7 +16,8 @@
 from skimage.util import img_as_float
 from torchvision.models import VGG16_Weights
 
-from tiatoolbox import logger
+from tiatoolbox import logger, rcParam
+from tiatoolbox.models.architecture.utils import compile_model
 from tiatoolbox.tools.patchextraction import PatchExtractor
 from tiatoolbox.utils.metrics import dice
 from tiatoolbox.utils.transforms import imresize
@@ -338,8 +339,9 @@ def __init__(self: torch.nn.Module) -> None:
         output_layers_id: list[str] = ["16", "23", "30"]
         output_layers_key: list[str] = ["block3_pool", "block4_pool", "block5_pool"]
         self.features: dict = dict.fromkeys(output_layers_key, None)
-        self.pretrained: torch.nn.Sequential = torchvision.models.vgg16(
-            weights=VGG16_Weights.IMAGENET1K_V1,
+        self.pretrained: torch.nn.Sequential = compile_model(
+            torchvision.models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1),
+            mode=rcParam["torch_compile_mode"],
         ).features
         self.f_hooks = [
             getattr(self.pretrained, layer).register_forward_hook(
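
Reviewer note: with `enable_torch_compile` removed, compilation is controlled by the single `rcParam["torch_compile_mode"]` setting, and the old `disable=True` path is now spelled `mode="disable"`. Below is a minimal usage sketch of the new surface, assuming this branch is installed; the toy `torch.nn.Sequential` model mirrors the one used in the tests above.

```python
# Sketch only: mode names and the compile_model signature are taken
# from the diff above, not from a released tiatoolbox version.
import torch

from tiatoolbox import rcParam
from tiatoolbox.models.architecture.utils import compile_model

# One rcParam now controls everything. Valid values:
# "disable", "default", "reduce-overhead", "max-autotune",
# "max-autotune-no-cudagraphs".
rcParam["torch_compile_mode"] = "reduce-overhead"

model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))

# "disable" short-circuits and returns the model unchanged (the old
# disable=True path); any other mode is forwarded to
# torch.compile(model, mode=mode).
compiled = compile_model(model, mode=rcParam["torch_compile_mode"])
```

Folding `disable` into the mode string keeps a single source of truth, so callers such as `PatchPredictor` and the DFBR feature extractor only need to read `rcParam["torch_compile_mode"]`.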