Skip to content

Commit

Permalink
🚧 Add torch.compile to SemanticSegmentor
Browse files Browse the repository at this point in the history
  • Loading branch information
Abdol committed Sep 26, 2024
1 parent 8cc2fb4 commit ba1776e
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions tiatoolbox/models/engine/semantic_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
import torch.utils.data as torch_data
import tqdm

from tiatoolbox import logger
from tiatoolbox import logger, rcParam
from tiatoolbox.models.architecture import get_pretrained_model
from tiatoolbox.models.architecture.utils import compile_model
from tiatoolbox.models.models_abc import IOConfigABC
from tiatoolbox.tools.patchextraction import PatchExtractor
from tiatoolbox.utils import imread, misc
Expand Down Expand Up @@ -563,7 +564,12 @@ def __init__(
self.masks = None

self.dataset_class: WSIStreamDataset = dataset_class
self.model = model
self.model = (
compile_model(
model,
mode=rcParam["torch_compile_mode"],
)
)
self.pretrained_model = pretrained_model
self.batch_size = batch_size
self.num_loader_workers = num_loader_workers
Expand Down

0 comments on commit ba1776e

Please sign in to comment.