
🆕 Integrate Foundation Models Available via timm: UNI, Prov-GigaPath, H-optimus-0 #856

Merged
Changes from 2 commits (54 commits total)

Commits
5bf228f
add `_get_timm_architecture()` and `TimmBackbone`
GeorgeBatch Sep 2, 2024
1bc3424
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 2, 2024
f1ec821
inherit `TimmBackbone` from `CNNBackbone`
GeorgeBatch Sep 3, 2024
d823c1e
update link: change GitHub to HuggingFace for UNI model
GeorgeBatch Sep 3, 2024
7aa3ca3
Merge branch 'develop' into enhance-add-timm-feature-extractors
shaneahmed Sep 20, 2024
934bd9a
Merge branch 'TissueImageAnalytics:develop' into enhance-add-timm-fea…
GeorgeBatch Sep 24, 2024
c4bbb40
Apply suggestions from code review: prov-gigapath version of timm
GeorgeBatch Sep 24, 2024
2c9a47f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 24, 2024
6ec9ca1
remove explicit assert statement re prov-gigapath version of `timm`
GeorgeBatch Sep 24, 2024
5445ff5
Merge branch 'TissueImageAnalytics:develop' into enhance-add-timm-fea…
GeorgeBatch Oct 16, 2024
142eaf6
remove unused arguments in ; fix formatting
GeorgeBatch Oct 16, 2024
36126c5
remove unused arguments from docstring of _get_timm_architecture
GeorgeBatch Oct 16, 2024
1b666b8
improve error message in _get_timm_architecture()
GeorgeBatch Oct 16, 2024
9c494a2
simplify inheritance code of TimmBackbone
GeorgeBatch Oct 16, 2024
5b5f674
add TimmModel class inheriting from CNNModel
GeorgeBatch Oct 16, 2024
7f51fec
add tests for TimmModel - fail because backbone argument is not speci…
GeorgeBatch Oct 16, 2024
40f7f95
TEMPORARY FIX: pass backbone="alexnet" for TimmModel and TimmBackbone…
GeorgeBatch Oct 16, 2024
7ec12b7
Make consistent with `_get_architecture` error message
GeorgeBatch Oct 17, 2024
4eb7f8c
fix `__init__` methods of `TimmModel` and `TimmBackbone`
GeorgeBatch Oct 17, 2024
7b0c0e8
add `timm>=1.0.3` as a requirement, otherwise the build is failing
GeorgeBatch Oct 17, 2024
deea74f
introduce a keyword-only `pretrained` argument for `timm` models
GeorgeBatch Oct 17, 2024
3437c1a
add support for `H-optimus-0`
GeorgeBatch Oct 17, 2024
a47300c
add a comment with HuggingFace link for H-Optimus-0
GeorgeBatch Oct 17, 2024
373ee8a
replace "uni_v1" with "UNI" for consistency with HuggingFace name
GeorgeBatch Oct 17, 2024
3263e3c
add "efficientnet_b{i}" for i in [0, 1, ..., 7]
GeorgeBatch Oct 18, 2024
d037dda
fix typo and shorten ProViT-GigaPath comments
GeorgeBatch Oct 18, 2024
e2f4a9d
Merge branch 'TissueImageAnalytics:develop' into enhance-add-timm-fea…
GeorgeBatch Oct 19, 2024
6afecd2
:fire: No need to check for timm version
shaneahmed Oct 24, 2024
2b5f9ce
:fire: No need to check for timm version
shaneahmed Oct 24, 2024
dad9f8e
:recycle: Define `postproc` and `infer_batch` used commonly by CNNMod…
shaneahmed Oct 24, 2024
f893a83
:recycle: Refactor `CNNBackBone` and `TimmBackbone`.
shaneahmed Oct 24, 2024
4293044
:bug: Fix deepsource error
shaneahmed Oct 24, 2024
cc31dc7
:pushpin: Pin `numpy` version <2.0
shaneahmed Oct 24, 2024
9ae4d6e
:white_check_mark: Skip Coverage for `UNI`, `Prov-GigaPath`, `H-Opti…
shaneahmed Oct 24, 2024
0166e61
Merge branch 'TissueImageAnalytics:develop' into enhance-add-timm-fea…
GeorgeBatch Oct 24, 2024
617c93f
🔥 No need to check for timm version
GeorgeBatch Oct 25, 2024
d30d7d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2024
d08127f
add test for TimmBackbone
GeorgeBatch Oct 29, 2024
85ad67b
show how thumbnails and masks can be saved using `tiatoolbox.utils.mi…
GeorgeBatch Nov 4, 2024
961a9ba
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 4, 2024
18070c5
fix typo
GeorgeBatch Nov 4, 2024
c270d7d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 4, 2024
ad9053d
add Google Colab link
GeorgeBatch Nov 5, 2024
9eebbb4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2024
dabf9a9
add TimmBackbone example as a comment under `model = CNNBackbone("res…
GeorgeBatch Nov 5, 2024
5c93e9a
add TimmBackbone example as a comment under model = CNNBackbone("resn…
GeorgeBatch Nov 5, 2024
1ab884e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2024
276da9c
Attempt to fix ruff error: commented code (TimmBackbone)
GeorgeBatch Nov 5, 2024
eba8d16
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2024
6450d01
Attempt to fix ruff error: commented code (TimmBackbone) - inference …
GeorgeBatch Nov 5, 2024
ce2ef64
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2024
3631b34
[skip ci] :memo: Update 03-tissue-masking.ipynb
shaneahmed Nov 15, 2024
6f476c1
Merge branch 'develop' into enhance-add-timm-feature-extractors
shaneahmed Nov 15, 2024
c159182
[skip ci] :memo: Update the jupyter notebooks.
shaneahmed Nov 15, 2024
117 changes: 117 additions & 0 deletions tiatoolbox/models/architecture/vanilla.py
@@ -5,6 +5,7 @@
from typing import TYPE_CHECKING

import numpy as np
import timm
import torch
import torchvision.models as torch_models
from torch import nn
@@ -79,6 +80,48 @@ def _get_architecture(
    return model.features


def _get_timm_architecture(
    arch_name: str,
    weights: str | WeightsEnum = "DEFAULT",
    **kwargs: dict,
) -> nn.Module:
    """Get architecture and weights for pathology-specific timm models.

    Args:
        arch_name (str):
            Architecture name.
        weights (str or WeightsEnum):
            Path to pretrained weights. Defaults to "DEFAULT".
        kwargs (dict):
            Key-word arguments.

    Returns:
        A ready-to-use timm model.
    """
    if arch_name == "uni_v1":
        # UNI tile encoder: https://github.com/mahmoodlab/UNI
        feat_extract = timm.create_model(
            "hf-hub:MahmoodLab/UNI",
            pretrained=True,
            init_values=1e-5,
            dynamic_img_size=True,
        )
    elif arch_name == "prov-gigapath":
        # Prov-GigaPath tile encoder: https://huggingface.co/prov-gigapath/prov-gigapath
        # does not work with timm==0.9.8, needs timm==1.0.3:
        # https://github.com/prov-gigapath/prov-gigapath/issues/2
        assert (
            timm.__version__ > "1.0.0"
        ), "There is a bug in version `timm==0.9.8`. Tested to work from version `timm==1.0.3`"
        feat_extract = timm.create_model(
            "hf_hub:prov-gigapath/prov-gigapath", pretrained=True
        )
    else:
        msg = f"Architecture {arch_name} not supported"
        raise ValueError(msg)

    return feat_extract
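
A minimal usage sketch for the new helper, assuming `timm>=1.0.3` is installed and that access to the gated UNI weights has been granted on the Hugging Face Hub (e.g. after `huggingface-cli login`):

```python
# Sketch: assumes timm>=1.0.3 and an authenticated Hugging Face session
# with access to the gated model weights.
from tiatoolbox.models.architecture.vanilla import _get_timm_architecture

# Download and build the pretrained UNI tile encoder from the Hugging Face Hub.
uni_encoder = _get_timm_architecture("uni_v1")
uni_encoder.eval()  # feature extraction only, no training
```

In practice the helper is private; the `TimmBackbone` wrapper below is the intended entry point.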


class CNNModel(ModelABC):
"""Retrieve the model backbone and attach an extra FCN to perform classification.

@@ -268,3 +311,77 @@ def infer_batch(
            output = model(img_patches_device)
        # Output should be a single tensor or scalar
        return [output.cpu().numpy()]


class TimmBackbone(ModelABC):
    """Retrieve the pathology-specific tile encoder from timm.

    This is a wrapper for pretrained models within timm.

    Args:
        backbone (str):
            Model name. Currently, the tool supports the following
            model names and their default associated weights from timm.
            - "uni_v1"
            - "prov-gigapath"

    Examples:
        >>> # Creating UNI tile encoder
        >>> model = TimmBackbone(backbone="uni_v1")
        >>> model.eval()  # set to evaluation mode
        >>> # dummy sample in NCHW form
        >>> samples = torch.rand(4, 3, 224, 224)
        >>> features = model(samples)
        >>> features.shape  # feature vector
        torch.Size([4, 1024])
    """

    def __init__(self: TimmBackbone, backbone: str) -> None:
        """Initialize :class:`TimmBackbone`."""
        super().__init__()
        self.feat_extract = _get_timm_architecture(backbone)

    def forward(self: TimmBackbone, imgs: torch.Tensor) -> torch.Tensor:
        """Pass input data through the model.

        Args:
            imgs (torch.Tensor):
                Model input.

        """
        feats = self.feat_extract(imgs)
        return torch.flatten(feats, 1)

    @staticmethod
    def infer_batch(
        model: nn.Module,
        batch_data: torch.Tensor,
        *,
        on_gpu: bool,
    ) -> list[np.ndarray, ...]:
        """Run inference on an input batch.

        Contains logic for forward operation as well as i/o aggregation.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (torch.Tensor):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            on_gpu (bool):
                Whether to run inference on a GPU.

        """
        img_patches_device = batch_data.to(select_device(on_gpu=on_gpu)).type(
            torch.float32,
        )  # to NCHW
        img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous()

        # Inference mode
        model.eval()
        # Do not compute the gradient (not training)
        with torch.inference_mode():
            output = model(img_patches_device)
        # Output should be a single tensor or scalar
        return [output.cpu().numpy()]
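
Putting the pieces together, a minimal end-to-end sketch (assuming access to the gated UNI weights on the Hugging Face Hub). Note that `infer_batch` expects NHWC batches and permutes them to NCHW internally, so the dummy data is created channels-last here.

```python
import torch

from tiatoolbox.models.architecture.vanilla import TimmBackbone

# Sketch: assumes the gated UNI weights are accessible on the Hugging Face Hub.
model = TimmBackbone(backbone="uni_v1")

# infer_batch expects NHWC input and converts it to NCHW before the forward pass.
dummy_batch = torch.rand(4, 224, 224, 3)
features = TimmBackbone.infer_batch(model, dummy_batch, on_gpu=False)
print(features[0].shape)  # expected: (4, 1024) for the UNI tile encoder
```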