
⚡️Add torch.compile Functionality #716

Open: wants to merge 103 commits into develop from enhance-torch-compile
Changes from 64 commits (103 commits total)
58dbdc2
⚡️ Add torch.compile decorators
Sep 25, 2023
29fd380
✅ Add simple compute time test
Sep 28, 2023
85d47b8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2023
3b0667d
♻️ refactor test and add disable `torch.compile`
Sep 29, 2023
6ebaec1
Fix conflicts
Sep 29, 2023
168f8ef
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2023
c55f00b
💚 Fix CI `no_gpu` error and move timed
Sep 29, 2023
5318a37
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 29, 2023
b3a8cb8
🎨 Minor improvements
Oct 2, 2023
6112670
Merge branch 'develop' into enhance-torch-compile
shaneahmed Oct 5, 2023
e97f4a5
Merge branch 'develop' into enhance-torch-compile
shaneahmed Oct 6, 2023
cdaade2
Merge branch 'enhance-torch-compile' of https://github.com/TissueImag…
Oct 6, 2023
dc806c9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 6, 2023
f302edb
🔥 Remove `torch.compile` test for now
Oct 6, 2023
d2cf661
🔀 merge changes
Oct 6, 2023
6ab0d8f
⚡️ Add `torch.compile` to SemanticSegmentor
Oct 19, 2023
cd554fb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 19, 2023
d541058
⚡️ Add `torch.compile` to PatchPredictor
Oct 20, 2023
b660edc
Merge branch 'enhance-torch-compile'
Oct 20, 2023
3c3305f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2023
1694d17
🔥 Temp disable `torch.compile` SemanticSegmentor
Oct 20, 2023
f50b083
Merge branch 'enhance-torch-compile' of https://github.com/TissueImag…
Oct 20, 2023
7910039
💡 Remove change to `__init__.py`
Oct 20, 2023
5b513e2
🔥 Temp remove `torch.compile` SemanticSegmentor
Oct 20, 2023
aaf076e
🚨 Fix `ruff` linter errors
Oct 20, 2023
2d736c0
🚨 Temp disable cyclomatic complexity check
Oct 27, 2023
65c2c53
🚨 Cont. temp disable cyclomatic complexity check
Oct 27, 2023
9cc0168
🚨 Revert `max-args` back to 10
Nov 3, 2023
d66a7bd
✏️ Add text to notebook
Nov 8, 2023
820c7b9
⏪ Remove unnecessary line in example notebook
Nov 10, 2023
afa81f6
⚡️ Add `torch-compile` to `SemanticSegmentor`
Nov 16, 2023
e5eae50
🚧 Add `rcParam` as config
Nov 16, 2023
6506799
🚧 Add `torch.compile` mode to `rcParam`
Nov 17, 2023
fc7120e
🐛 Fix argument mishap
Nov 17, 2023
08a3cf8
🚨 Fix linter errors
Nov 17, 2023
a13ee34
Merge branch 'develop' into enhance-torch-compile
Nov 17, 2023
1c0173d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2023
b0a276e
🐛 Fix `rcParam` definition
Nov 17, 2023
13eafe2
Merge branch 'enhance-torch-compile' of https://github.com/TissueImag…
Nov 17, 2023
a138bc9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2023
32b002e
🐛 Fix `rcParam` definition
Nov 17, 2023
33df9a9
Merge branch 'enhance-torch-compile' of https://github.com/TissueImag…
Nov 17, 2023
a003742
🚨 Fix `ruff` lint errors
Nov 23, 2023
df62c5d
Merge branch 'develop' into enhance-torch-compile
shaneahmed Nov 24, 2023
54d4b06
Merge branch 'develop' into enhance-torch-compile
shaneahmed Jan 19, 2024
8bc5328
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2024
ddfc15c
Merge branch 'develop' into enhance-torch-compile
Abdol Jan 25, 2024
c2c0e89
🚧 Suppress `TorchDynamo` errors and disable `torch.compile` by default
Jan 25, 2024
957f847
🚧 Remove a problematic `torch.compile` definition
Jan 26, 2024
f898bc5
🐛 Fix linter error about importing protected members
Jan 26, 2024
6b1520a
🚨 Further linter error fix
Jan 26, 2024
5b32712
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2024
6c6ce66
🚨 Further linter error fix
Jan 26, 2024
1574859
Merge branch 'enhance-torch-compile' of https://github.com/TissueImag…
Jan 26, 2024
7be70e5
Merge branch 'develop' into enhance-torch-compile
Abdol Jan 26, 2024
5d8cc6a
🚧 Remove `torch.compile` definitions from main PR
Jan 26, 2024
90d5648
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2024
9377b8d
🐛 Fix missing attribute in WSI registration from previous commit
Jan 26, 2024
b5fd57f
Merge branch 'develop' into enhance-torch-compile
shaneahmed Feb 2, 2024
8d6b788
Merge branch 'develop' into enhance-torch-compile
shaneahmed Feb 21, 2024
fb032d4
Merge branch 'develop' into enhance-torch-compile
shaneahmed Mar 15, 2024
b2f57ee
Merge branch 'develop' into enhance-torch-compile
shaneahmed Mar 19, 2024
252c7f9
⚡️ Add `torch.compile` to `PatchPredictor` (#776)
Abdol Mar 19, 2024
cf22502
Merge branch 'develop' into enhance-torch-compile
shaneahmed Mar 22, 2024
1d40585
📝 Fix docstrings
Abdol Mar 22, 2024
7d08c34
Merge branch 'develop' into enhance-torch-compile
shaneahmed Apr 23, 2024
6ee4353
Merge branch 'develop' into enhance-torch-compile
shaneahmed Apr 29, 2024
76d8e7e
Merge branch 'develop' into enhance-torch-compile
Abdol May 10, 2024
a767843
Merge branch 'develop' into enhance-torch-compile
Abdol May 14, 2024
e533b85
⚡️Refine `torch.compile` and Add to WSI Registration (#800)
Abdol Jun 14, 2024
7df7c62
Merge branch 'develop' into enhance-torch-compile
Abdol Jun 14, 2024
d97501c
Merge branch 'develop' into enhance-torch-compile
Abdol Jun 24, 2024
5667e54
Merge branch 'develop' into enhance-torch-compile
Abdol Jun 25, 2024
02b8771
Merge branch 'develop' into enhance-torch-compile
Abdol Jun 28, 2024
2729a06
Merge branch 'develop' into enhance-torch-compile
Abdol Jul 9, 2024
92b75df
Merge branch 'develop' into enhance-torch-compile
Abdol Jul 29, 2024
7cf2714
Merge branch 'develop' into enhance-torch-compile
shaneahmed Aug 9, 2024
14c9409
Merge branch 'develop' into enhance-torch-compile
shaneahmed Aug 29, 2024
a086195
Merge branch 'develop' into enhance-torch-compile
shaneahmed Sep 19, 2024
8cc2fb4
Merge branch 'develop' into enhance-torch-compile
Abdol Sep 26, 2024
ba1776e
🚧 Add `torch.compile` to SemanticSegmentor
Abdol Sep 26, 2024
dd02dd3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2024
36d3f81
Merge branch 'develop' into enhance-torch-compile
shaneahmed Oct 2, 2024
6a5cc1d
🐛 Fix `torch.compile` assertion error
Abdol Oct 10, 2024
67fef3f
✅ Add test for SemanticSegmentor with `torch.compile`
Abdol Oct 10, 2024
24bd96b
Merge branch 'develop' into enhance-torch-compile
shaneahmed Oct 18, 2024
7d1850b
Merge branch 'develop' into enhance-torch-compile
Abdol Oct 20, 2024
fd97f07
Merge branch 'develop' into enhance-torch-compile
Abdol Oct 25, 2024
e5be778
fix DeepSource error
Jiaqi-Lv Nov 1, 2024
62a9009
fix deepsource error
Jiaqi-Lv Nov 1, 2024
2d15229
Apply suggestions from code review
Abdol Nov 4, 2024
8cd748a
Update semantic_segmentor.py as per code review
Abdol Nov 11, 2024
14450f2
Apply suggestions from code review
Abdol Nov 11, 2024
dcfb18a
Merge branch 'develop' into enhance-torch-compile
Abdol Nov 11, 2024
6580a01
try fixing testcov
Jiaqi-Lv Nov 13, 2024
152d698
Adjust spacing in github workflows
shaneahmed Nov 13, 2024
11fc009
Apply suggestions from code review
Abdol Nov 15, 2024
8790efa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 15, 2024
6545920
Update utils.py to address review comments
Abdol Nov 15, 2024
5c6928f
:memo: Add `torch.compile` mode descriptions
Abdol Nov 15, 2024
77830ae
:bug: Fix E501 Line too long
shaneahmed Nov 15, 2024
8f07b0a
Merge branch 'refs/heads/develop' into enhance-torch-compile
shaneahmed Nov 15, 2024
020d9ef
:bug: Fix `test_torch_compile_compatibility`
shaneahmed Nov 15, 2024
2 changes: 1 addition & 1 deletion .github/workflows/mypy-type-check.yml
@@ -6,7 +6,7 @@ on:
push:
branches: [ develop, pre-release, master, main ]
pull_request:
branches: [ develop, pre-release, master, main ]
branches: [ develop, pre-release, master, main, enhance-torch-compile ]

jobs:

2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
@@ -8,7 +8,7 @@ on:
branches: [ develop, pre-release, master, main ]
tags: v*
pull_request:
branches: [ develop, pre-release, master, main ]
branches: [ develop, pre-release, master, main, enhance-torch-compile ]

jobs:
build:
37 changes: 36 additions & 1 deletion tests/conftest.py
@@ -4,15 +4,17 @@

import os
import shutil
import time
from pathlib import Path
from typing import Callable

import pytest
import torch

import tiatoolbox
from tiatoolbox import logger
from tiatoolbox.data import _fetch_remote_sample
from tiatoolbox.utils.env_detection import running_on_ci
from tiatoolbox.utils.env_detection import has_gpu, running_on_ci

# -------------------------------------------------------------------------------------
# Generate Parameterized Tests
@@ -578,3 +580,36 @@ def data_path(tmp_path_factory: pytest.TempPathFactory) -> dict[str, object]:
(tmp_path / "slides").mkdir()
(tmp_path / "overlays").mkdir()
return {"base_path": tmp_path}


# -------------------------------------------------------------------------------------
# Utility functions
# -------------------------------------------------------------------------------------


def timed(fn: Callable, *args: object) -> tuple[object, float]:
"""Time the execution of a function.

Args:
fn (Callable): The function to be timed.
args (object): Arguments to be passed to the function.

Returns:
A tuple containing the result of the function
and the time taken to execute it in seconds.
"""
compile_time = 0.0
if has_gpu():
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
result = fn(*args)
end.record()
torch.cuda.synchronize()
compile_time = start.elapsed_time(end) / 1000
else:
start = time.time()
result = fn(*args)
end = time.time()
compile_time = end - start
return result, compile_time
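The `timed` helper above falls back to wall-clock timing when no GPU is present. A minimal CPU-only sketch of that fallback path, using only the standard library (`timed_cpu` is an illustrative name, not part of the PR; the real helper additionally uses `torch.cuda.Event` for GPU timing):

```python
import time
from typing import Any, Callable


def timed_cpu(fn: Callable[..., Any], *args: Any) -> tuple[Any, float]:
    """Run ``fn(*args)`` and return (result, elapsed seconds).

    CPU-only sketch of the ``timed`` fixture above; ``perf_counter``
    is a monotonic clock, so the difference is a valid duration.
    """
    start = time.perf_counter()
    result = fn(*args)
    elapsed = time.perf_counter() - start
    return result, elapsed


if __name__ == "__main__":
    result, elapsed = timed_cpu(sum, [1, 2, 3])
    print(result, elapsed >= 0.0)
```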
54 changes: 53 additions & 1 deletion tests/models/test_patch_predictor.py
@@ -13,7 +13,8 @@
import torch
from click.testing import CliRunner

from tiatoolbox import cli
from tests.conftest import timed
from tiatoolbox import cli, logger, rcParam
from tiatoolbox.models import IOPatchPredictorConfig, PatchPredictor
from tiatoolbox.models.architecture.vanilla import CNNModel
from tiatoolbox.models.dataset import (
@@ -1226,3 +1227,54 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -
assert tmp_path.joinpath("2.merged.npy").exists()
assert tmp_path.joinpath("2.raw.json").exists()
assert tmp_path.joinpath("results.json").exists()


# -------------------------------------------------------------------------------------
# torch.compile
# -------------------------------------------------------------------------------------


def test_patch_predictor_torch_compile(
sample_patch1: Path,
sample_patch2: Path,
tmp_path: Path,
) -> None:
"""Test torch.compile functionality.

Args:
sample_patch1 (Path): Path to sample patch 1.
sample_patch2 (Path): Path to sample patch 2.
tmp_path (Path): Path to temporary directory.
"""
torch_compile_enabled = rcParam["enable_torch_compile"]
torch._dynamo.reset()
rcParam["enable_torch_compile"] = True
# Test torch.compile with default mode
rcParam["torch_compile_mode"] = "default"
_, compile_time = timed(
test_patch_predictor_api,
sample_patch1,
sample_patch2,
tmp_path,
)
logger.info("torch.compile default mode: %s", compile_time)
torch._dynamo.reset()
rcParam["torch_compile_mode"] = "reduce-overhead"
_, compile_time = timed(
test_patch_predictor_api,
sample_patch1,
sample_patch2,
tmp_path,
)
logger.info("torch.compile reduce-overhead mode: %s", compile_time)
torch._dynamo.reset()
rcParam["torch_compile_mode"] = "max-autotune"
_, compile_time = timed(
test_patch_predictor_api,
sample_patch1,
sample_patch2,
tmp_path,
)
logger.info("torch.compile max-autotune mode: %s", compile_time)
torch._dynamo.reset()
rcParam["enable_torch_compile"] = torch_compile_enabled
18 changes: 18 additions & 0 deletions tests/test_utils.py
@@ -13,6 +13,7 @@
import numpy as np
import pandas as pd
import pytest
import torch
from PIL import Image
from requests import HTTPError
from shapely.geometry import Polygon
@@ -21,6 +22,7 @@
from tiatoolbox import utils
from tiatoolbox.annotation.storage import SQLiteStore
from tiatoolbox.models.architecture import fetch_pretrained_weights
from tiatoolbox.models.architecture.utils import compile_model
from tiatoolbox.utils import misc
from tiatoolbox.utils.exceptions import FileNotSupportedError
from tiatoolbox.utils.transforms import locsize2bounds
@@ -1819,3 +1821,19 @@ def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None:
# check correct error is raised if coordinates are missing
with pytest.raises(ValueError, match="coordinates"):
misc.dict_to_store(patch_output, (1.0, 1.0))


def test_torch_compile_already_compiled() -> None:
"""Test that torch_compile does not recompile a model that is already compiled."""
# Create a simple model
model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))

# Compile the model
compiled_model = compile_model(model)

# Compile the model again
recompiled_model = compile_model(compiled_model)

# Check that the recompiled model
# is the same as the original compiled model
assert compiled_model == recompiled_model
8 changes: 8 additions & 0 deletions tiatoolbox/__init__.py
@@ -73,6 +73,8 @@ class _RcParam(TypedDict):

TIATOOLBOX_HOME: Path
pretrained_model_info: dict[str, dict]
enable_torch_compile: bool
torch_compile_mode: str


def read_registry_files(path_to_registry: str | Path) -> dict:
@@ -102,6 +104,12 @@ def read_registry_files(path_to_registry: str | Path) -> dict:
"pretrained_model_info": read_registry_files(
"data/pretrained_model.yaml",
), # Load a dictionary of sample files data (names and urls)
"enable_torch_compile": False,
# Disable `torch.compile` by default
"torch_compile_mode": "default",
# Set `torch.compile` mode to `default` by default
# Options: "default", "reduce-overhead", "max-autotune"
# or "max-autotune-no-cudagraphs"
}
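The two new `rcParam` keys accept a boolean and one of four mode strings. A hypothetical validation helper (not part of the PR; `set_torch_compile_mode` and `ALLOWED_COMPILE_MODES` are illustrative names) sketching how a setter could guard against mode typos before they reach `torch.compile`:

```python
# The four mode strings torch.compile documents for its `mode` argument.
ALLOWED_COMPILE_MODES = {
    "default",
    "reduce-overhead",
    "max-autotune",
    "max-autotune-no-cudagraphs",
}


def set_torch_compile_mode(rc: dict, mode: str) -> None:
    """Store `mode` in an rcParam-style dict, rejecting unknown strings."""
    if mode not in ALLOWED_COMPILE_MODES:
        msg = f"Unknown torch.compile mode: {mode!r}"
        raise ValueError(msg)
    rc["torch_compile_mode"] = mode


rc = {"enable_torch_compile": False, "torch_compile_mode": "default"}
set_torch_compile_mode(rc, "max-autotune")
```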


1 change: 1 addition & 0 deletions tiatoolbox/models/architecture/__init__.py
@@ -150,6 +150,7 @@ def get_pretrained_model(
model.load_state_dict(saved_state_dict, strict=True)

# !

io_info = info["ioconfig"]
creator = locate(f"tiatoolbox.models.engine.{io_info['class']}")

50 changes: 50 additions & 0 deletions tiatoolbox/models/architecture/utils.py
@@ -2,10 +2,60 @@

from __future__ import annotations

import sys
from typing import Callable

import numpy as np
import torch
from torch import nn

from tiatoolbox import logger


def compile_model(
model: nn.Module | None = None,
*,
mode: str = "default",
disable: bool = False,
) -> Callable:
"""A helper to compile a model using ``torch.compile``.

Args:
model (torch.nn.Module):
Model to be compiled.
mode (str):
Mode to be used for torch-compile. Available modes are
`default`, `reduce-overhead`, `max-autotune`, and
`max-autotune-no-cudagraphs`.
disable (bool):
If True, torch-compile will be disabled.

Returns:
Callable:
Compiled model.

"""
if disable:
return model

# This check will be removed when torch.compile is supported in Python 3.12+
if sys.version_info >= (3, 12): # pragma: no cover
logger.warning(
"torch-compile is currently not supported in Python 3.12+.",
)
return model

if isinstance(
model,
torch._dynamo.eval_frame.OptimizedModule, # skipcq: PYL-W0212 # noqa: SLF001
):
logger.warning(
"The model is already compiled.",
)
return model

return torch.compile(model, mode=mode, disable=disable)


def centre_crop(
img: np.ndarray | torch.tensor,
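`compile_model` returns its input unchanged when the model is already an `OptimizedModule`, so wrapping is idempotent. A torch-free sketch of that guard pattern, where `Wrapped` stands in for `torch._dynamo.eval_frame.OptimizedModule` (both names below are illustrative, not part of the PR):

```python
class Wrapped:
    """Stand-in for a compiled-module wrapper such as OptimizedModule."""

    def __init__(self, inner: object) -> None:
        self.inner = inner


def compile_once(model: object) -> object:
    """Wrap `model`, but never wrap an already-wrapped model twice."""
    if isinstance(model, Wrapped):
        return model  # already "compiled": return unchanged, like compile_model
    return Wrapped(model)


m = object()
c1 = compile_once(m)
c2 = compile_once(c1)  # second call is a no-op: c2 is c1
```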
2 changes: 1 addition & 1 deletion tiatoolbox/models/dataset/classification.py
@@ -163,7 +163,7 @@ class WSIPatchDataset(dataset_abc.PatchDatasetABC):

"""

def __init__( # skipcq: PY-R1000 # noqa: PLR0913, PLR0915
def __init__( # skipcq: PY-R1000 # noqa: PLR0915, PLR0913
self: WSIPatchDataset,
img_path: str | Path,
mode: str = "wsi",
2 changes: 1 addition & 1 deletion tiatoolbox/models/engine/multi_task_segmentor.py
@@ -49,7 +49,7 @@
# Python is yet to be able to natively pickle Object method/static method.
# Only top-level function is passable to multi-processing as caller.
# May need 3rd party libraries to use method/static method otherwise.
def _process_tile_predictions(
def _process_tile_predictions( # skipcq: PY-R1000
ioconfig: IOSegmentorConfig,
tile_bounds: IntBounds,
tile_flag: list,
11 changes: 9 additions & 2 deletions tiatoolbox/models/engine/patch_predictor.py
@@ -11,8 +11,9 @@
import torch
import tqdm

from tiatoolbox import logger
from tiatoolbox import logger, rcParam
from tiatoolbox.models.architecture import get_pretrained_model
from tiatoolbox.models.architecture.utils import compile_model
from tiatoolbox.models.dataset.classification import PatchDataset, WSIPatchDataset
from tiatoolbox.models.engine.semantic_segmentor import IOSegmentorConfig
from tiatoolbox.utils import misc, save_as_json
@@ -250,7 +251,13 @@ def __init__(

self.ioconfig = ioconfig # for storing original
self._ioconfig = None # for storing runtime
self.model = model # for runtime, such as after wrapping with nn.DataParallel
self.model = (
compile_model( # for runtime, such as after wrapping with nn.DataParallel
model,
mode=rcParam["torch_compile_mode"],
disable=not rcParam["enable_torch_compile"],
)
)
self.pretrained_model = pretrained_model
self.batch_size = batch_size
self.num_loader_worker = num_loader_workers
2 changes: 1 addition & 1 deletion tiatoolbox/models/engine/semantic_segmentor.py
@@ -563,7 +563,7 @@ def __init__(
self.masks = None

self.dataset_class: WSIStreamDataset = dataset_class
self.model = model # original copy
self.model = model
self.pretrained_model = pretrained_model
self.batch_size = batch_size
self.num_loader_workers = num_loader_workers
4 changes: 4 additions & 0 deletions tiatoolbox/models/models_abc.py
@@ -6,8 +6,12 @@
from typing import TYPE_CHECKING, Any, Callable

import torch
import torch._dynamo
from torch import device as torch_device

torch._dynamo.config.suppress_errors = True # skipcq: PYL-W0212 # noqa: SLF001


if TYPE_CHECKING: # pragma: no cover
from pathlib import Path

5 changes: 4 additions & 1 deletion tiatoolbox/tools/registration/wsi_registration.py
@@ -431,7 +431,10 @@ class DFBRegister:

"""

def __init__(self: DFBRegister, patch_size: tuple[int, int] = (224, 224)) -> None:
def __init__(
self: DFBRegister,
patch_size: tuple[int, int] = (224, 224),
) -> None:
"""Initialize :class:`DFBRegister`."""
self.patch_size = patch_size
self.x_scale: list[float] = []
9 changes: 8 additions & 1 deletion tiatoolbox/utils/__init__.py
@@ -10,7 +10,14 @@
visualization,
)

from .misc import download_data, imread, imwrite, save_as_json, save_yaml, unzip_data
from .misc import (
download_data,
imread,
imwrite,
save_as_json,
save_yaml,
unzip_data,
)

__all__ = [
"imread",