Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

⚡️Add torch.compile Functionality #716

Open
wants to merge 103 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
103 commits
Select commit Hold shift + click to select a range
58dbdc2
⚡️ Add torch.compile decorators
Sep 25, 2023
29fd380
✅ Add simple compute time test
Sep 28, 2023
85d47b8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2023
3b0667d
♻️ refactor test and add disable `torch.compile`
Sep 29, 2023
6ebaec1
Fix conflicts
Sep 29, 2023
168f8ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2023
c55f00b
💚 Fix CI `no_gpu` error and move timed
Sep 29, 2023
5318a37
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2023
b3a8cb8
🎨 Minor improvements
Oct 2, 2023
6112670
Merge branch 'develop' into enhance-torch-compile
shaneahmed Oct 5, 2023
e97f4a5
Merge branch 'develop' into enhance-torch-compile
shaneahmed Oct 6, 2023
cdaade2
Merge branch 'enhance-torch-compile' of https://github.com/TissueImag…
Oct 6, 2023
dc806c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 6, 2023
f302edb
🔥 Remove `torch.compile` test for now
Oct 6, 2023
d2cf661
🔀 merge changes
Oct 6, 2023
6ab0d8f
⚡️ Add `torch.compile` to SemanticSegmentor
Oct 19, 2023
cd554fb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 19, 2023
d541058
⚡️ Add `torch.compiled` to PatchPredictor
Oct 20, 2023
b660edc
Merge branch 'enhance-torch-compile'
Oct 20, 2023
3c3305f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2023
1694d17
🔥 Temp disable `torch.compile` SemanticSegmentor
Oct 20, 2023
f50b083
Merge branch 'enhance-torch-compile' of https://github.com/TissueImag…
Oct 20, 2023
7910039
💡 Remove chanage to `__init__.py`
Oct 20, 2023
5b513e2
🔥 Temp remove `torch.compile` SemanticSegmentor
Oct 20, 2023
aaf076e
🚨 Fix `ruff` linter errors
Oct 20, 2023
2d736c0
🚨 Temp disable cyclomatic complexity check
Oct 27, 2023
65c2c53
🚨 Cont. temp disable cyclomatic complexity check
Oct 27, 2023
9cc0168
🚨 Revert `max-args` back to 10
Nov 3, 2023
d66a7bd
✏️ Add text to notebook
Nov 8, 2023
820c7b9
⏪ Remove unnecessary line in example notebook
Nov 10, 2023
afa81f6
⚡️ Add `torch-compile` to `SemanticSegmentor`
Nov 16, 2023
e5eae50
🚧 Add 'rcParam` as config
Nov 16, 2023
6506799
🚧 Add `torch.compile` mode to `rcParam`
Nov 17, 2023
fc7120e
🐛 Fix argument mishap
Nov 17, 2023
08a3cf8
🚨 Fix linter errors
Nov 17, 2023
a13ee34
Merge branch 'develop' into enhance-torch-compile
Nov 17, 2023
1c0173d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2023
b0a276e
🐛 Fix `rcParam` definition
Nov 17, 2023
13eafe2
Merge branch 'enhance-torch-compile' of https://github.com/TissueImag…
Nov 17, 2023
a138bc9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2023
32b002e
🐛 Fix `rcParam` definition
Nov 17, 2023
33df9a9
Merge branch 'enhance-torch-compile' of https://github.com/TissueImag…
Nov 17, 2023
a003742
🚨 Fix `ruff` lint errors
Nov 23, 2023
df62c5d
Merge branch 'develop' into enhance-torch-compile
shaneahmed Nov 24, 2023
54d4b06
Merge branch 'develop' into enhance-torch-compile
shaneahmed Jan 19, 2024
8bc5328
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2024
ddfc15c
Merge branch 'develop' into enhance-torch-compile
Abdol Jan 25, 2024
c2c0e89
🚧 Supress `TorchDynamo` errors and disable `torch.compile` by default
Jan 25, 2024
957f847
🚧 Remove a problematic `torch.compile` defintion
Jan 26, 2024
f898bc5
🐛 Fix linter error about importing protected members
Jan 26, 2024
6b1520a
🚨 Further linter error fix
Jan 26, 2024
5b32712
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2024
6c6ce66
🚨 Further linter error fix
Jan 26, 2024
1574859
Merge branch 'enhance-torch-compile' of https://github.com/TissueImag…
Jan 26, 2024
7be70e5
Merge branch 'develop' into enhance-torch-compile
Abdol Jan 26, 2024
5d8cc6a
🚧 Remove `torch.compile` definitions from main PR
Jan 26, 2024
90d5648
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2024
9377b8d
🐛 Fix missing attribute in WSI registration from previous commit
Jan 26, 2024
b5fd57f
Merge branch 'develop' into enhance-torch-compile
shaneahmed Feb 2, 2024
8d6b788
Merge branch 'develop' into enhance-torch-compile
shaneahmed Feb 21, 2024
fb032d4
Merge branch 'develop' into enhance-torch-compile
shaneahmed Mar 15, 2024
b2f57ee
Merge branch 'develop' into enhance-torch-compile
shaneahmed Mar 19, 2024
252c7f9
⚡️ Add `torch.compile` to `PatchPredictor` (#776)
Abdol Mar 19, 2024
cf22502
Merge branch 'develop' into enhance-torch-compile
shaneahmed Mar 22, 2024
1d40585
📝 Fix docstrings
Abdol Mar 22, 2024
7d08c34
Merge branch 'develop' into enhance-torch-compile
shaneahmed Apr 23, 2024
6ee4353
Merge branch 'develop' into enhance-torch-compile
shaneahmed Apr 29, 2024
76d8e7e
Merge branch 'develop' into enhance-torch-compile
Abdol May 10, 2024
a767843
Merge branch 'develop' into enhance-torch-compile
Abdol May 14, 2024
e533b85
⚡️Refine `torch.compile` and Add to WSI Registration (#800)
Abdol Jun 14, 2024
7df7c62
Merge branch 'develop' into enhance-torch-compile
Abdol Jun 14, 2024
d97501c
Merge branch 'develop' into enhance-torch-compile
Abdol Jun 24, 2024
5667e54
Merge branch 'develop' into enhance-torch-compile
Abdol Jun 25, 2024
02b8771
Merge branch 'develop' into enhance-torch-compile
Abdol Jun 28, 2024
2729a06
Merge branch 'develop' into enhance-torch-compile
Abdol Jul 9, 2024
92b75df
Merge branch 'develop' into enhance-torch-compile
Abdol Jul 29, 2024
7cf2714
Merge branch 'develop' into enhance-torch-compile
shaneahmed Aug 9, 2024
14c9409
Merge branch 'develop' into enhance-torch-compile
shaneahmed Aug 29, 2024
a086195
Merge branch 'develop' into enhance-torch-compile
shaneahmed Sep 19, 2024
8cc2fb4
Merge branch 'develop' into enhance-torch-compile
Abdol Sep 26, 2024
ba1776e
🚧 Add `torch.compile` to SemanticSegmentor
Abdol Sep 26, 2024
dd02dd3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2024
36d3f81
Merge branch 'develop' into enhance-torch-compile
shaneahmed Oct 2, 2024
6a5cc1d
🐛 Fix `torch.compile` assertion error
Abdol Oct 10, 2024
67fef3f
✅ Add test for SemanticSegmentor with `torch.compile`
Abdol Oct 10, 2024
24bd96b
Merge branch 'develop' into enhance-torch-compile
shaneahmed Oct 18, 2024
7d1850b
Merge branch 'develop' into enhance-torch-compile
Abdol Oct 20, 2024
fd97f07
Merge branch 'develop' into enhance-torch-compile
Abdol Oct 25, 2024
e5be778
fix DeepSource error
Jiaqi-Lv Nov 1, 2024
62a9009
fix deepsource error
Jiaqi-Lv Nov 1, 2024
2d15229
Apply suggestions from code review
Abdol Nov 4, 2024
8cd748a
Update semantic_segmentor.py as per code review
Abdol Nov 11, 2024
14450f2
Apply suggestions from code review
Abdol Nov 11, 2024
dcfb18a
Merge branch 'develop' into enhance-torch-compile
Abdol Nov 11, 2024
6580a01
try fixing testcov
Jiaqi-Lv Nov 13, 2024
152d698
Adjust spacing in github workflows
shaneahmed Nov 13, 2024
11fc009
Apply suggestions from code review
Abdol Nov 15, 2024
8790efa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 15, 2024
6545920
Update utils.py to address review comments
Abdol Nov 15, 2024
5c6928f
:memo: Add `torch.compile` mode descriptions
Abdol Nov 15, 2024
77830ae
:bug: Fix E501 Line too long
shaneahmed Nov 15, 2024
8f07b0a
Merge branch 'refs/heads/develop' into enhance-torch-compile
shaneahmed Nov 15, 2024
020d9ef
:bug: Fix `test_torch_compile_compatibility`
shaneahmed Nov 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@

import os
import shutil
import time
from pathlib import Path
from typing import Callable

import pytest
import torch

import tiatoolbox
from tiatoolbox import logger
from tiatoolbox.data import _fetch_remote_sample
from tiatoolbox.utils.env_detection import running_on_ci
from tiatoolbox.utils.env_detection import has_gpu, running_on_ci

# -------------------------------------------------------------------------------------
# Generate Parameterized Tests
Expand Down Expand Up @@ -568,3 +570,35 @@ def __exit__(self: chdir, *excinfo: object) -> None:
os.chdir(self._old_cwd.pop())

return chdir


# -------------------------------------------------------------------------------------
# Utility functions
# -------------------------------------------------------------------------------------


def timed(fn: Callable) -> (Callable, float):
"""A decorator that times the execution of a function.

Args:
fn (Callable): The function to be timed.

Returns:
A tuple containing the result of the function
and the time taken to execute it in seconds.
"""
Abdol marked this conversation as resolved.
Show resolved Hide resolved
compile_time = 0.0
if has_gpu():
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
result = fn()
end.record()
torch.cuda.synchronize()
compile_time = start.elapsed_time(end) / 1000
else:
start = time.time()
result = fn()
end = time.time()
compile_time = end - start
return result, compile_time
28 changes: 26 additions & 2 deletions tests/test_wsi_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import numpy as np
import pytest

from tests.conftest import timed
from tiatoolbox import logger
from tiatoolbox.tools.registration.wsi_registration import (
AffineWSITransformer,
DFBRegister,
Expand All @@ -21,10 +23,32 @@
RNG = np.random.default_rng() # Numpy Random Generator


def test_extract_features(dfbr_features: Path) -> None:
def test_extract_features_time(dfbr_features: Path, test_count: int = 25) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the right place to add this test?

"""Compute time test for CNN based feature extraction function."""
compile_time = 0.0
eager_compile_time = 0.0
for _ in range(test_count):
_, _compile_time = timed(lambda: test_extract_features(dfbr_features))
compile_time += _compile_time
for _ in range(test_count):
_, _compile_time = timed(
lambda: test_extract_features(dfbr_features, compiled=False),
)
eager_compile_time += _compile_time
compile_time /= test_count
eager_compile_time /= test_count
logger.info("Time taken for feature extraction (torch.compile): %f", compile_time)
logger.info(
"Time taken for feature extraction (eager execution): %f",
eager_compile_time,
)
assert compile_time < eager_compile_time


def test_extract_features(dfbr_features: Path, *, compiled: bool = True) -> None:
"""Test for CNN based feature extraction function."""
# dfbr (deep feature based registration).
dfbr = DFBRegister()
dfbr = DFBRegister(compiled=compiled)
fixed_img = np.repeat(
np.expand_dims(
np.repeat(
Expand Down
23 changes: 18 additions & 5 deletions tiatoolbox/tools/registration/wsi_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,14 +328,18 @@ class DFBRFeatureExtractor(torch.nn.Module):

"""

def __init__(self: torch.nn.Module) -> None:
def __init__(self: torch.nn.Module, *, compiled: bool = True) -> None:
"""Initialize :class:`DFBRFeatureExtractor`."""
super().__init__()
output_layers_id: list[str] = ["16", "23", "30"]
output_layers_key: list[str] = ["block3_pool", "block4_pool", "block5_pool"]
self.features: dict = dict.fromkeys(output_layers_key, None)
self.pretrained: torch.nn.Sequential = torchvision.models.vgg16(
weights=VGG16_Weights.IMAGENET1K_V1,
self.compiled = compiled
self.pretrained: torch.nn.Sequential = torch.compile(
torchvision.models.vgg16(
weights=VGG16_Weights.IMAGENET1K_V1,
),
disable=not compiled,
).features
self.f_hooks = [
getattr(self.pretrained, layer).register_forward_hook(
Expand All @@ -356,6 +360,7 @@ def forward_hook(self: torch.nn.Module, layer_name: str) -> None:

"""

@torch.compile(disable=not self.compiled)
def hook(
_module: torch.nn.MaxPool2d,
_module_input: tuple[torch.Tensor],
Expand Down Expand Up @@ -426,11 +431,19 @@ class DFBRegister:

"""

def __init__(self: DFBRegister, patch_size: tuple[int, int] = (224, 224)) -> None:
def __init__(
self: DFBRegister,
patch_size: tuple[int, int] = (224, 224),
*,
compiled: bool = True,
) -> None:
"""Initialize :class:`DFBRegister`."""
self.patch_size = patch_size
self.x_scale, self.y_scale = [], []
self.feature_extractor = DFBRFeatureExtractor()
self.compiled = compiled
self.feature_extractor = DFBRFeatureExtractor(
compiled=compiled,
)

# Make this function private when full pipeline is implemented.
def extract_features(
Expand Down
9 changes: 8 additions & 1 deletion tiatoolbox/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@
visualization,
)

from .misc import download_data, imread, imwrite, save_as_json, save_yaml, unzip_data
from .misc import (
download_data,
imread,
imwrite,
save_as_json,
save_yaml,
unzip_data,
)

__all__ = [
"imread",
Expand Down
Loading