Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🆕 Integrate Foundation Models Available VIA timm: UNI, Virchow, Hibou, H-optimus-0 #856

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
83 changes: 83 additions & 0 deletions tiatoolbox/models/architecture/vanilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import TYPE_CHECKING

import numpy as np
import timm
import torch
import torchvision.models as torch_models
from torch import nn
Expand Down Expand Up @@ -79,6 +80,48 @@ def _get_architecture(
return model.features


def _get_timm_architecture(
arch_name: str,
weights: str or WeightsEnum = "DEFAULT",
shaneahmed marked this conversation as resolved.
Show resolved Hide resolved
**kwargs: dict,
) -> list[nn.Sequential, ...] | nn.Sequential:
"""Get architecture and weights for pathology-specific timm models.

Args:
arch_name (str):
Architecture name.
weights:
path to pretrained weights
kwargs (dict):
Key-word arguments.

Returns:
A ready-to-use timm model.
"""
if arch_name == "uni_v1":
# UNI tile encoder: https://huggingface.co/MahmoodLab/UNI
feat_extract = timm.create_model(
"hf-hub:MahmoodLab/UNI",
pretrained=True,
init_values=1e-5,
dynamic_img_size=True,
)
elif arch_name == "prov-gigapath":
GeorgeBatch marked this conversation as resolved.
Show resolved Hide resolved
# ProViT-GigaPath tile encoder: https://huggingface.co/prov-gigapath/prov-gigapath
# does not work with timm==0.9.8, needs timm==1.0.3: https://github.com/prov-gigapath/prov-gigapath/issues/2
assert (
GeorgeBatch marked this conversation as resolved.
Show resolved Hide resolved
timm.__version__ > "1.0.0"
), "There is a bug in version `timm==0.9.8`. Tested to work from version `timm==1.0.3`"
feat_extract = timm.create_model(
"hf_hub:prov-gigapath/prov-gigapath", pretrained=True
)
else:
msg = f"Architecture {arch_name} not supported"
GeorgeBatch marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(msg)

return feat_extract


class CNNModel(ModelABC):
"""Retrieve the model backbone and attach an extra FCN to perform classification.

Expand Down Expand Up @@ -268,3 +311,43 @@ def infer_batch(
output = model(img_patches_device)
# Output should be a single tensor or scalar
return [output.cpu().numpy()]


class TimmBackbone(CNNBackbone):
"""Retrieve the pathology-specific tile encoder from timm.

This is a wrapper for pretrained models within timm.

Args:
backbone (str):
Model name. Currently, the tool supports following
model names and their default associated weights from timm.
- "uni_v1"
- "prov-gigapath"

Examples:
>>> # Creating UNI tile encoder
>>> model = TimmBackbone(backbone="uni_v1")
>>> model.eval() # set to evaluation mode
>>> # dummy sample in NHWC form
>>> samples = torch.rand(4, 3, 224, 224)
>>> features = model(samples)
>>> features.shape # feature vector
torch.Size([4, 1024])
"""

def __init__(self: TimmBackbone, backbone: str) -> None:
"""Initialize :class:`TimmBackbone`."""
super(CNNBackbone, self).__init__()
self.feat_extract = _get_timm_architecture(backbone)

def forward(self: TimmBackbone, imgs: torch.Tensor) -> torch.Tensor:
"""Pass input data through the model.

Args:
imgs (torch.Tensor):
Model input.

"""
feats = self.feat_extract(imgs)
return torch.flatten(feats, 1)
Loading