Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🆕 Integrate Foundation Models Available VIA timm: UNI, Virchow, Hibou, H-optimus-0 #856

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
80 changes: 80 additions & 0 deletions tiatoolbox/models/architecture/vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING

import numpy as np
import timm
import torch
import torchvision.models as torch_models
from torch import nn
Expand Down Expand Up @@ -79,6 +80,45 @@ def _get_architecture(
return model.features


def _get_timm_architecture(
arch_name: str,
weights: str or WeightsEnum = "DEFAULT",
shaneahmed marked this conversation as resolved.
Show resolved Hide resolved
**kwargs: dict,
) -> list[nn.Sequential, ...] | nn.Sequential:
"""Get architecture and weights for pathology-specific timm models.
Args:
arch_name (str):
Architecture name.
weights:
path to pretrained weights
kwargs (dict):
Key-word arguments.
Returns:
A ready-to-use timm model.
"""
if arch_name == "uni_v1":
# UNI tile encoder: https://huggingface.co/MahmoodLab/UNI
feat_extract = timm.create_model(
"hf-hub:MahmoodLab/UNI",
pretrained=True,
init_values=1e-5,
dynamic_img_size=True,
)
elif arch_name == "prov-gigapath" and timm.__version__ > "1.0.0":
# ProViT-GigaPath tile encoder: https://huggingface.co/prov-gigapath/prov-gigapath
# There is a bug in timm version 0.9.8 (https://github.com/prov-gigapath/prov-gigapath/issues/2). Version 1.0.3 has been tested successfully.
feat_extract = timm.create_model(
"hf_hub:prov-gigapath/prov-gigapath", pretrained=True
)
else:
msg = f"Architecture {arch_name} not supported. If you are loading timm models, only timm version > `1.0.3` are supported."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is too long and causing the tests to fail. If you do pre-commit install in your development environment then it will test formatting before making a commit. You can also run ruff manually in your terminal before making a commit.

Suggested change
msg = f"Architecture {arch_name} not supported. If you are loading timm models, only timm version > `1.0.3` are supported."
msg = f"Architecture {arch_name} not supported."
f"If you are loading timm models, only timm version > `1.0.3` are supported."

raise ValueError(msg)

return feat_extract


class CNNModel(ModelABC):
"""Retrieve the model backbone and attach an extra FCN to perform classification.
Expand Down Expand Up @@ -268,3 +308,43 @@ def infer_batch(
output = model(img_patches_device)
# Output should be a single tensor or scalar
return [output.cpu().numpy()]


class TimmBackbone(CNNBackbone):
    """Retrieve the pathology-specific tile encoder from timm.

    This is a wrapper for pretrained models within timm.

    Args:
        backbone (str):
            Model name. Currently, the tool supports following
            model names and their default associated weights from timm.
                - "uni_v1"
                - "prov-gigapath"

    Examples:
        >>> # Creating UNI tile encoder
        >>> model = TimmBackbone(backbone="uni_v1")
        >>> model.eval()  # set to evaluation mode
        >>> # dummy sample in NCHW (channels-first) form
        >>> samples = torch.rand(4, 3, 224, 224)
        >>> features = model(samples)
        >>> features.shape  # feature vector
        torch.Size([4, 1024])

    """

    def __init__(self: TimmBackbone, backbone: str) -> None:
        """Initialize :class:`TimmBackbone`.

        Args:
            backbone (str):
                Name of the timm model to load; passed unchanged to
                :func:`_get_timm_architecture`.

        """
        # `super(CNNBackbone, self)` starts the method lookup *after*
        # CNNBackbone, so CNNBackbone's own `__init__` is skipped and only
        # its parents are initialised.
        # NOTE(review): presumably this avoids CNNBackbone building its own
        # backbone from the same name - confirm against CNNBackbone.__init__.
        super(CNNBackbone, self).__init__()
        self.feat_extract = _get_timm_architecture(backbone)

    def forward(self: TimmBackbone, imgs: torch.Tensor) -> torch.Tensor:
        """Pass input data through the model.

        Args:
            imgs (torch.Tensor):
                Model input.

        Returns:
            torch.Tensor:
                Output of the tile encoder with every dimension after the
                first flattened into a single one.

        """
        feats = self.feat_extract(imgs)
        return torch.flatten(feats, 1)
Loading