⚡️ Refine torch.compile and Add to WSI Registration #800

Merged
merged 26 commits into from Jun 14, 2024
Changes from 13 commits
26 commits
cdac067
⚡️ Add `torch.compile` to WSI registration
Mar 21, 2024
b599b66
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 21, 2024
1b67690
Merge branch 'enhance-torch-compile' into enhance-torch-compile-regis…
shaneahmed Mar 22, 2024
5676c3a
Merge branch 'enhance-torch-compile' into enhance-torch-compile-regis…
Abdol Mar 22, 2024
56c4e12
✅ Add DFBR with torch.compile test
Apr 11, 2024
a831b7b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
36239b6
♻️ Refactor test
Apr 11, 2024
680461c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2024
62a7ccf
Merge branch 'enhance-torch-compile' into enhance-torch-compile-regis…
Abdol Apr 23, 2024
56172f7
Merge branch 'enhance-torch-compile' into enhance-torch-compile-regis…
Abdol Apr 29, 2024
7fd34ed
Merge branch 'enhance-torch-compile' into enhance-torch-compile-regis…
Abdol May 10, 2024
a93edd6
🚸 Add warning for incompatible GPUs
May 10, 2024
cdf530d
♻️ 🚸 Remove disable `torch.compile` and merge with options and enable…
May 10, 2024
663c05b
🚸 ✅ Add test for `torch.compile` and extend check for no available cuda
May 10, 2024
4f79f1a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 10, 2024
e7304f1
🚸 Extend `torch.compile` for next gen GPUs
May 10, 2024
5c95702
🔀 Merge branch 'enhance-torch-compile-registration' of https://github…
May 10, 2024
f34f81c
✏️ Fix typo in `torch.compile` compatibility check
May 13, 2024
52e7e69
🚨 Potential fix to linting error when using `logger`
May 13, 2024
f030687
🚨 Potential fix to lint error (try 2)
May 13, 2024
43881b1
🐛 Move `torch.compile` compatibility check function to avoid `test_log…
May 13, 2024
3d31568
✅ Move test to `test_utils`
May 13, 2024
e8335a8
🐛 ✅ Update tests for coverage and fix compatibility check condition s…
May 13, 2024
8aa3458
🐛 Fix a bug where `torch.compile` mode does not change in a test
May 14, 2024
f054d15
✅ Add a separate test for disabling `torch.compile`
May 14, 2024
646d3da
Merge branch 'enhance-torch-compile' into enhance-torch-compile-regis…
Abdol May 14, 2024
8 changes: 3 additions & 5 deletions tests/models/test_patch_predictor.py
@@ -1239,18 +1239,16 @@ def test_patch_predictor_torch_compile(
sample_patch2: Path,
tmp_path: Path,
) -> None:
"""Test torch.compile functionality.
"""Test PatchPredictor with with torch.compile functionality.

Args:
sample_patch1 (Path): Path to sample patch 1.
sample_patch2 (Path): Path to sample patch 2.
tmp_path (Path): Path to temporary directory.

"""
torch_compile_enabled = rcParam["enable_torch_compile"]
torch_compile_mode = rcParam["torch_compile_mode"]
torch._dynamo.reset()
rcParam["enable_torch_compile"] = True
# Test torch.compile with default mode
rcParam["torch_compile_mode"] = "default"
_, compile_time = timed(
test_patch_predictor_api,
@@ -1278,4 +1276,4 @@ def test_patch_predictor_torch_compile(
)
logger.info("torch.compile max-autotune mode: %s", compile_time)
torch._dynamo.reset()
rcParam["enable_torch_compile"] = torch_compile_enabled
rcParam["torch_compile_mode"] = torch_compile_mode
69 changes: 69 additions & 0 deletions tests/test_wsi_registration.py
@@ -5,7 +5,10 @@
import cv2
import numpy as np
import pytest
import torch

from tests.conftest import timed
from tiatoolbox import logger, rcParam
from tiatoolbox.tools.registration.wsi_registration import (
AffineWSITransformer,
DFBRegister,
@@ -576,3 +579,69 @@ def test_affine_wsi_transformer(sample_ome_tiff: Path) -> None:
expected = cv2.rotate(expected, cv2.ROTATE_90_CLOCKWISE)

assert np.sum(expected - output) == 0


def test_dfbr_feature_extractor_torch_compile(dfbr_features: Path) -> None:
"""Test DFBRFeatureExtractor with torch.compile functionality.

Args:
dfbr_features (Path): Path to the expected features.
"""

def _extract_features() -> tuple:
dfbr = DFBRegister()
fixed_img = np.repeat(
np.expand_dims(
np.repeat(
np.expand_dims(np.arange(0, 64, 1, dtype=np.uint8), axis=1),
64,
axis=1,
),
axis=2,
),
3,
axis=2,
)
output = dfbr.extract_features(fixed_img, fixed_img)
pool3_feat = output["block3_pool"][0, :].detach().numpy()
pool4_feat = output["block4_pool"][0, :].detach().numpy()
pool5_feat = output["block5_pool"][0, :].detach().numpy()

return pool3_feat, pool4_feat, pool5_feat

torch_compile_mode = rcParam["torch_compile_mode"]
torch._dynamo.reset()
rcParam["torch_compile_mode"] = "default"
(pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features)
_pool3_feat, _pool4_feat, _pool5_feat = np.load(
str(dfbr_features),
allow_pickle=True,
)
assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4
assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4
assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4
logger.info("torch.compile default mode: %s", compile_time)
torch._dynamo.reset()
rcParam["torch_compile_mode"] = "reduce-overhead"
(pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features)
_pool3_feat, _pool4_feat, _pool5_feat = np.load(
str(dfbr_features),
allow_pickle=True,
)
assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4
assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4
assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4
logger.info("torch.compile reduce-overhead mode: %s", compile_time)
torch._dynamo.reset()
rcParam["torch_compile_mode"] = "max-autotune"
(pool3_feat, pool4_feat, pool5_feat), compile_time = timed(_extract_features)
_pool3_feat, _pool4_feat, _pool5_feat = np.load(
str(dfbr_features),
allow_pickle=True,
)
assert np.mean(np.abs(pool3_feat - _pool3_feat)) < 1.0e-4
assert np.mean(np.abs(pool4_feat - _pool4_feat)) < 1.0e-4
assert np.mean(np.abs(pool5_feat - _pool5_feat)) < 1.0e-4
logger.info("torch.compile max-autotune mode: %s", compile_time)
torch._dynamo.reset()
rcParam["torch_compile_mode"] = torch_compile_mode
32 changes: 26 additions & 6 deletions tiatoolbox/__init__.py
@@ -8,6 +8,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, TypedDict

import torch
import yaml

if TYPE_CHECKING: # pragma: no cover
@@ -73,10 +74,31 @@

TIATOOLBOX_HOME: Path
pretrained_model_info: dict[str, dict]
enable_torch_compile: bool
torch_compile_mode: str


def is_torch_compile_compatible() -> bool:
"""Check if the current GPU is compatible with torch-compile.

Returns:
bool:
True if the GPU is compatible with torch-compile, False
otherwise.

"""
if torch.cuda.is_available():
device_cap = torch.cuda.get_device_capability()
if device_cap not in ((7, 0), (8, 0), (9, 0)):
logger.warning(
"GPU is not compatible with torch.compile. "
"Compatible GPUs include NVIDIA V100, A100, and H100. "
"Speedup numbers may be lower than expected."
)
return False

return True
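For reference, the compute-capability tuples checked above map onto the GPUs named in the warning: (7, 0) is V100, (8, 0) is A100, and (9, 0) is H100. Inspecting the local device is straightforward:

```python
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"Compute capability: {major}.{minor}")  # e.g. 8.0 on an A100
```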


def read_registry_files(path_to_registry: str | Path) -> dict:
"""Reads registry files using importlib_resources.

@@ -104,11 +126,9 @@
"pretrained_model_info": read_registry_files(
"data/pretrained_model.yaml",
), # Load a dictionary of sample files data (names and urls)
"enable_torch_compile": False,
# Disable `torch-compile`` by default
"torch_compile_mode": "default",
# Set ``torch-compile`` mode to ``default`` by default
# Options: “default”, “reduce-overhead”, “max-autotune”
"torch_compile_mode": "default" if is_torch_compile_compatible() else "disable",
# Set `torch-compile` mode to `default` if the GPU is compatible, otherwise `disable`
# Options: `disable`, `default`, `reduce-overhead`, `max-autotune`,
# or `max-autotune-no-cudagraphs`
}
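Because `rcParam` is a module-level dictionary, users can still override the mode at runtime before a model is constructed; a short sketch:

```python
from tiatoolbox import rcParam

rcParam["torch_compile_mode"] = "disable"       # opt out of torch.compile entirely
rcParam["torch_compile_mode"] = "max-autotune"  # or request aggressive autotuning
```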

14 changes: 7 additions & 7 deletions tiatoolbox/models/architecture/utils.py
@@ -9,14 +9,13 @@
import torch
from torch import nn

from tiatoolbox import logger
from tiatoolbox import is_torch_compile_compatible, logger


def compile_model(
model: nn.Module | None = None,
*,
mode: str = "default",
disable: bool = False,
) -> Callable:
"""A decorator to compile a model using torch-compile.

@@ -25,19 +24,20 @@ def compile_model(
Model to be compiled.
mode (str):
Mode to be used for torch-compile. Available modes are
`default`, `reduce-overhead`, `max-autotune`, and
`disable`, `default`, `reduce-overhead`, `max-autotune`, and
`max-autotune-no-cudagraphs`.
disable (bool):
If True, torch-compile will be disabled.

Returns:
Callable:
Compiled model.

"""
if disable:
if mode == "disable":
return model

# Check if GPU is compatible with torch.compile
is_torch_compile_compatible()

# This check will be removed when torch.compile is supported in Python 3.12+
if sys.version_info >= (3, 12): # pragma: no cover
logger.warning(
@@ -54,7 +54,7 @@ def compile_model(
)
return model

return torch.compile(model, mode=mode, disable=disable)
return torch.compile(model, mode=mode)


def centre_crop(
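A minimal usage sketch of the updated `compile_model` signature, with a toy module standing in for a real architecture (the model and shapes here are illustrative):

```python
import torch
from torch import nn

from tiatoolbox import rcParam
from tiatoolbox.models.architecture.utils import compile_model

# Returns the model unchanged when the mode is "disable".
model = compile_model(
    nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()),  # toy stand-in model
    mode=rcParam["torch_compile_mode"],
)
output = model(torch.randn(1, 3, 64, 64))
```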
1 change: 0 additions & 1 deletion tiatoolbox/models/engine/patch_predictor.py
@@ -255,7 +255,6 @@ def __init__(
compile_model( # for runtime, such as after wrapping with nn.DataParallel
model,
mode=rcParam["torch_compile_mode"],
disable=not rcParam["enable_torch_compile"],
)
)
self.pretrained_model = pretrained_model
8 changes: 5 additions & 3 deletions tiatoolbox/tools/registration/wsi_registration.py
@@ -16,7 +16,8 @@
from skimage.util import img_as_float
from torchvision.models import VGG16_Weights

from tiatoolbox import logger
from tiatoolbox import logger, rcParam
from tiatoolbox.models.architecture.utils import compile_model
from tiatoolbox.tools.patchextraction import PatchExtractor
from tiatoolbox.utils.metrics import dice
from tiatoolbox.utils.transforms import imresize
@@ -338,8 +339,9 @@ def __init__(self: torch.nn.Module) -> None:
output_layers_id: list[str] = ["16", "23", "30"]
output_layers_key: list[str] = ["block3_pool", "block4_pool", "block5_pool"]
self.features: dict = dict.fromkeys(output_layers_key, None)
self.pretrained: torch.nn.Sequential = torchvision.models.vgg16(
weights=VGG16_Weights.IMAGENET1K_V1,
self.pretrained: torch.nn.Sequential = compile_model(
torchvision.models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1),
mode=rcParam["torch_compile_mode"],
).features
self.f_hooks = [
getattr(self.pretrained, layer).register_forward_hook(
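For background, `DFBRFeatureExtractor` captures intermediate VGG16 activations through forward hooks registered on the (now compiled) backbone. A standalone sketch of the same hook pattern, independent of the toolbox (layer index 16 is the max-pool closing block 3 of VGG16):

```python
import torch
import torchvision
from torchvision.models import VGG16_Weights

features: dict[str, torch.Tensor] = {}


def save_as(name: str):
    """Return a hook that stores the module's output under ``name``."""
    def hook(module: torch.nn.Module, inputs: tuple, output: torch.Tensor) -> None:
        features[name] = output
    return hook


backbone = torchvision.models.vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features
backbone[16].register_forward_hook(save_as("block3_pool"))  # pool after block 3

_ = backbone(torch.randn(1, 3, 224, 224))
print(features["block3_pool"].shape)  # torch.Size([1, 256, 28, 28])
```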