Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

⚡️Refine torch.compile and Add to WSI Registration #800

Merged
merged 26 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
cdac067
⚡️ Add `torch.compile` to WSI registration
Mar 21, 2024
b599b66
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 21, 2024
1b67690
Merge branch 'enhance-torch-compile' into enhance-torch-compile-regis…
shaneahmed Mar 22, 2024
5676c3a
Merge branch 'enhance-torch-compile' into enhance-torch-compile-regis…
Abdol Mar 22, 2024
56c4e12
✅ Add DFBR with torch.compile test
Apr 11, 2024
a831b7b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
36239b6
♻️ Refactor test
Apr 11, 2024
680461c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
62a7ccf
Merge branch 'enhance-torch-compile' into enhance-torch-compile-regis…
Abdol Apr 23, 2024
56172f7
Merge branch 'enhance-torch-compile' into enhance-torch-compile-regis…
Abdol Apr 29, 2024
7fd34ed
Merge branch 'enhance-torch-compile' into enhance-torch-compile-regis…
Abdol May 10, 2024
a93edd6
🚸 Add warning for incompatible GPUs
May 10, 2024
cdf530d
♻️ 🚸 Remove disable `torch.compile` and merge with options and enable…
May 10, 2024
663c05b
🚸 ✅ Add test for `torch.compile` and extend check for no available cuda
May 10, 2024
4f79f1a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 10, 2024
e7304f1
🚸 Extend `torch.compile` for next gen GPUs
May 10, 2024
5c95702
🔀 Merge branch 'enhance-torch-compile-registration' of https://github…
May 10, 2024
f34f81c
✏️ Fix typo in `torch.compile` compatbility check
May 13, 2024
52e7e69
🚨 Potential fix to linting error when using `logger`
May 13, 2024
f030687
🚨 Potential fix to lint error (try 2)
May 13, 2024
43881b1
🐛 Move `torch.compile` comptability check function to avoid `test_log…
May 13, 2024
3d31568
✅ Move test to `test_utils`
May 13, 2024
e8335a8
🐛 ✅ Update tests for coverage and fix compatibility check condition s…
May 13, 2024
8aa3458
🐛 Fix a bug where `torch.compile` mode does not change in a test
May 14, 2024
f054d15
✅ Add a seperate test for disabling `torch.compile`
May 14, 2024
646d3da
Merge branch 'enhance-torch-compile' into enhance-torch-compile-regis…
Abdol May 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions tests/models/test_patch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,18 +1239,16 @@ def test_patch_predictor_torch_compile(
sample_patch2: Path,
tmp_path: Path,
) -> None:
"""Test torch.compile functionality.
"""Test PatchPredictor with with torch.compile functionality.

Args:
sample_patch1 (Path): Path to sample patch 1.
sample_patch2 (Path): Path to sample patch 2.
tmp_path (Path): Path to temporary directory.

"""
torch_compile_enabled = rcParam["enable_torch_compile"]
torch_compile_mode = rcParam["torch_compile_mode"]
torch._dynamo.reset()
rcParam["enable_torch_compile"] = True
# Test torch.compile with default mode
rcParam["torch_compile_mode"] = "default"
_, compile_time = timed(
test_patch_predictor_api,
Expand Down Expand Up @@ -1278,4 +1276,4 @@ def test_patch_predictor_torch_compile(
)
logger.info("torch.compile max-autotune mode: %s", compile_time)
torch._dynamo.reset()
rcParam["enable_torch_compile"] = torch_compile_enabled
rcParam["torch_compile_mode"] = torch_compile_mode
38 changes: 29 additions & 9 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from shapely.geometry import Polygon

from tests.test_annotation_stores import cell_polygon
from tiatoolbox import utils
from tiatoolbox import rcParam, utils
from tiatoolbox.annotation.storage import SQLiteStore
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.models.architecture.utils import compile_model
Expand Down Expand Up @@ -1825,15 +1825,35 @@ def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None:

def test_torch_compile_already_compiled() -> None:
"""Test that torch_compile does not recompile a model that is already compiled."""
# Create a simple model
torch_compile_modes = [
"default",
"reduce-overhead",
"max-autotune",
"max-autotune-no-cudagraphs",
]
current_torch_compile_mode = rcParam["torch_compile_mode"]
model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))

for mode in torch_compile_modes:
torch._dynamo.reset()
rcParam["torch_compile_mode"] = mode
compiled_model = compile_model(model, mode=mode)
recompiled_model = compile_model(compiled_model, mode=mode)
assert compiled_model == recompiled_model

torch._dynamo.reset()
rcParam["torch_compile_mode"] = current_torch_compile_mode


def test_torch_compile_disable() -> None:
"""Test torch_compile's disable mode."""
model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))
compiled_model = compile_model(model, mode="disable")
assert model == compiled_model

# Compile the model
compiled_model = compile_model(model)

# Compile the model again
recompiled_model = compile_model(compiled_model)
def test_torch_compile_compatibility() -> None:
"""Test if torch-compile compatibility is checked correctly."""
from tiatoolbox.models.architecture.utils import is_torch_compile_compatible

# Check that the recompiled model
# is the same as the original compiled model
assert compiled_model == recompiled_model
assert isinstance(is_torch_compile_compatible(), bool)
70 changes: 70 additions & 0 deletions tests/test_wsi_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
import cv2
import numpy as np
import pytest
import torch

from tests.conftest import timed
from tiatoolbox import logger, rcParam
from tiatoolbox.tools.registration.wsi_registration import (
AffineWSITransformer,
DFBRegister,
Expand Down Expand Up @@ -576,3 +579,70 @@ def test_affine_wsi_transformer(sample_ome_tiff: Path) -> None:
expected = cv2.rotate(expected, cv2.ROTATE_90_CLOCKWISE)

assert np.sum(expected - output) == 0


def test_dfbr_feature_extractor_torch_compile(dfbr_features: Path) -> None:
"""Test DFBRFeatureExtractor with torch.compile functionality.
Args:
dfbr_features (Path): Path to the expected features.
"""
shaneahmed marked this conversation as resolved.
Show resolved Hide resolved

def _extract_features() -> tuple:
dfbr = DFBRegister()
fixed_img = np.repeat(
np.expand_dims(
np.repeat(
np.expand_dims(np.arange(0, 64, 1, dtype=np.uint8), axis=1),
64,
axis=1,
),
axis=2,
),
3,
axis=2,
)
output = dfbr.extract_features(fixed_img, fixed_img)
pool3_feat = output["block3_pool"][0, :].detach().numpy()
pool4_feat = output["block4_pool"][0, :].detach().numpy()
pool5_feat = output["block5_pool"][0, :].detach().numpy()

return pool3_feat, pool4_feat, pool5_feat

torch_compile_mode = rcParam["torch_compile_mode"]
torch._dynamo.reset()
rcParam["torch_compile_mode"] = "default"
(pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features)
_pool3_feat, _pool4_feat, _pool5_feat = np.load(
str(dfbr_features),
allow_pickle=True,
)
assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4
assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4
assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4
logger.info("torch.compile default mode: %s", compile_time)
torch._dynamo.reset()
rcParam["torch_compile_mode"] = "reduce-overhead"
(pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features)
_pool3_feat, _pool4_feat, _pool5_feat = np.load(
str(dfbr_features),
allow_pickle=True,
)
assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4
assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4
assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4
logger.info("torch.compile reduce-overhead mode: %s", compile_time)
torch._dynamo.reset()
rcParam["torch_compile_mode"] = "max-autotune"
(pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features)
_pool3_feat, _pool4_feat, _pool5_feat = np.load(
str(dfbr_features),
allow_pickle=True,
)
assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4
assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4
assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4
logger.info("torch.compile max-autotune mode: %s", compile_time)
torch._dynamo.reset()
rcParam["torch_compile_mode"] = torch_compile_mode
7 changes: 2 additions & 5 deletions tiatoolbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ class _RcParam(TypedDict):

TIATOOLBOX_HOME: Path
pretrained_model_info: dict[str, dict]
enable_torch_compile: bool
torch_compile_mode: str


Expand Down Expand Up @@ -104,11 +103,9 @@ def read_registry_files(path_to_registry: str | Path) -> dict:
"pretrained_model_info": read_registry_files(
"data/pretrained_model.yaml",
), # Load a dictionary of sample files data (names and urls)
"enable_torch_compile": False,
# Disable `torch-compile`` by default
"torch_compile_mode": "default",
# Set ``torch-compile`` mode to ``default`` by default
# Options: default”, “reduce-overhead”, “max-autotune
# Set `torch-compile` mode to `default`
# Options: `disable`, `default`, `reduce-overhead`, `max-autotune`
# or “max-autotune-no-cudagraphs”
}

Expand Down
43 changes: 37 additions & 6 deletions tiatoolbox/models/architecture/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,41 @@
from tiatoolbox import logger


def is_torch_compile_compatible() -> bool:
"""Check if the current GPU is compatible with torch-compile.

Returns:
bool:
True if the GPU is compatible with torch-compile, False
otherwise.

"""
if torch.cuda.is_available(): # pragma: no cover
device_cap = torch.cuda.get_device_capability()
if device_cap not in ((7, 0), (8, 0), (9, 0)):
logger.warning(
"GPU is not compatible with torch.compile. "
"Compatible GPUs include NVIDIA V100, A100, and H100. "
"Speedup numbers may be lower than expected.",
stacklevel=2,
)
return False
else:
logger.warning(
"No GPU detected or cuda not installed, "
"torch.compile is only supported on selected NVIDIA GPUs. "
"Speedup numbers may be lower than expected.",
stacklevel=2,
)
return False

return True # pragma: no cover


def compile_model(
model: nn.Module | None = None,
*,
mode: str = "default",
disable: bool = False,
) -> Callable:
"""A decorator to compile a model using torch-compile.

Expand All @@ -25,19 +55,20 @@ def compile_model(
Model to be compiled.
mode (str):
Mode to be used for torch-compile. Available modes are
`default`, `reduce-overhead`, `max-autotune`, and
`disable`, `default`, `reduce-overhead`, `max-autotune`, and
`max-autotune-no-cudagraphs`.
disable (bool):
If True, torch-compile will be disabled.

Returns:
Callable:
Compiled model.

"""
if disable:
if mode == "disable":
return model

# Check if GPU is compatible with torch.compile
is_torch_compile_compatible()

# This check will be removed when torch.compile is supported in Python 3.12+
if sys.version_info >= (3, 12): # pragma: no cover
logger.warning(
Expand All @@ -54,7 +85,7 @@ def compile_model(
)
return model

return torch.compile(model, mode=mode, disable=disable)
return torch.compile(model, mode=mode)


def centre_crop(
Expand Down
1 change: 0 additions & 1 deletion tiatoolbox/models/engine/patch_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,6 @@ def __init__(
compile_model( # for runtime, such as after wrapping with nn.DataParallel
model,
mode=rcParam["torch_compile_mode"],
disable=not rcParam["enable_torch_compile"],
)
)
self.pretrained_model = pretrained_model
Expand Down
8 changes: 5 additions & 3 deletions tiatoolbox/tools/registration/wsi_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from skimage.util import img_as_float
from torchvision.models import VGG16_Weights

from tiatoolbox import logger
from tiatoolbox import logger, rcParam
from tiatoolbox.models.architecture.utils import compile_model
from tiatoolbox.tools.patchextraction import PatchExtractor
from tiatoolbox.utils.metrics import dice
from tiatoolbox.utils.transforms import imresize
Expand Down Expand Up @@ -338,8 +339,9 @@ def __init__(self: torch.nn.Module) -> None:
output_layers_id: list[str] = ["16", "23", "30"]
output_layers_key: list[str] = ["block3_pool", "block4_pool", "block5_pool"]
self.features: dict = dict.fromkeys(output_layers_key, None)
self.pretrained: torch.nn.Sequential = torchvision.models.vgg16(
weights=VGG16_Weights.IMAGENET1K_V1,
self.pretrained: torch.nn.Sequential = compile_model(
torchvision.models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1),
mode=rcParam["torch_compile_mode"],
).features
self.f_hooks = [
getattr(self.pretrained, layer).register_forward_hook(
Expand Down
Loading