
Integrate foundation models available through timm: UNI, Virchow, Hibou, H-optimus-0, etc. #855

Open
GeorgeBatch opened this issue Aug 30, 2024 · 1 comment · May be fixed by #856

GeorgeBatch (Contributor) commented Aug 30, 2024

  • TIA Toolbox version: 1.5.1
  • Python version: 3.11
  • Operating System: Linux

Description

I think it would be useful to integrate pre-trained foundation models from other labs into tiatoolbox.models.architecture.vanilla.py.

Currently, the _get_architecture() function allows the use of models from torchvision.models.

A second function, _get_timm_architecture(), could be added to incorporate foundation models that are available through timm with weights hosted on the Hugging Face Hub. All the timm models I have used require users to sign a licence agreement with the authors, so the licensing question resolves itself: users cannot obtain the weights through TIA Toolbox unless the authors have first approved their access request.

What I Did

To add them myself, I copied the definition of CNNBackbone (reproduced below) and made two changes:

  1. set self.feat_extract = _get_timm_architecture(backbone) instead of _get_architecture(backbone);
  2. removed the global average pooling, because these pathology foundation models already map a batch of images to a feature vector of shape (batch_size, embedding_size).

# Imports needed by this excerpt of tiatoolbox/models/architecture/vanilla.py
# (_get_architecture() is defined earlier in the same module).
from __future__ import annotations

import numpy as np
import torch
from torch import nn

from tiatoolbox.models.models_abc import ModelABC
from tiatoolbox.utils.misc import select_device


class CNNBackbone(ModelABC):
    """Retrieve the model backbone and strip the classification layer.

    This is a wrapper for pretrained models within pytorch.

    Args:
        backbone (str):
            Model name. Currently, the tool supports the following
            model names and their default associated weights from pytorch.
                - "alexnet"
                - "resnet18"
                - "resnet34"
                - "resnet50"
                - "resnet101"
                - "resnext50_32x4d"
                - "resnext101_32x8d"
                - "wide_resnet50_2"
                - "wide_resnet101_2"
                - "densenet121"
                - "densenet161"
                - "densenet169"
                - "densenet201"
                - "inception_v3"
                - "googlenet"
                - "mobilenet_v2"
                - "mobilenet_v3_large"
                - "mobilenet_v3_small"

    Examples:
        >>> # Creating resnet50 architecture from default pytorch
        >>> # without the classification layer with its associated
        >>> # weights loaded
        >>> model = CNNBackbone(backbone="resnet50")
        >>> model.eval()  # set to evaluation mode
        >>> # dummy sample in NCHW form
        >>> samples = torch.rand(4, 3, 512, 512)
        >>> features = model(samples)
        >>> features.shape  # features after global average pooling
        torch.Size([4, 2048])

    """

    def __init__(self: CNNBackbone, backbone: str) -> None:
        """Initialize :class:`CNNBackbone`."""
        super().__init__()
        self.feat_extract = _get_architecture(backbone)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    # pylint: disable=W0221
    # because abc is generic, this is the actual definition
    def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor:
        """Pass input data through the model.

        Args:
            imgs (torch.Tensor):
                Model input.

        """
        feat = self.feat_extract(imgs)
        gap_feat = self.pool(feat)
        return torch.flatten(gap_feat, 1)

    @staticmethod
    def infer_batch(
        model: nn.Module,
        batch_data: torch.Tensor,
        *,
        on_gpu: bool,
    ) -> list[np.ndarray]:
        """Run inference on an input batch.

        Contains logic for forward operation as well as i/o aggregation.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (torch.Tensor):
                A batch of data generated by
                `torch.utils.data.DataLoader`.
            on_gpu (bool):
                Whether to run inference on a GPU.

        """
        img_patches_device = batch_data.to(select_device(on_gpu=on_gpu)).type(
            torch.float32,
        )
        # NHWC to NCHW
        img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous()

        # Inference mode
        model.eval()
        # Do not compute the gradient (not training)
        with torch.inference_mode():
            output = model(img_patches_device)
        # Output should be a single tensor or scalar
        return [output.cpu().numpy()]
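Applying the two changes listed above gives roughly the following class. This is a sketch under stated assumptions, not the implementation in the linked PR. The name TimmBackbone is hypothetical. It subclasses nn.Module instead of ModelABC so it runs without TIA Toolbox. It takes an already-built model instead of calling _get_timm_architecture(backbone), so it can be exercised without gated downloads. Its infer_batch takes a device string in place of on_gpu/select_device.

```python
from __future__ import annotations

import numpy as np
import torch
from torch import nn


class TimmBackbone(nn.Module):
    """Hypothetical wrapper for a model that already returns (N, embedding_size)."""

    def __init__(self, feat_extract: nn.Module) -> None:
        super().__init__()
        # In the proposal this would be _get_timm_architecture(backbone).
        self.feat_extract = feat_extract
        # No nn.AdaptiveAvgPool2d here: change 2 removes global average pooling.

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        # The foundation model's output is returned unchanged (no pooling/flattening).
        return self.feat_extract(imgs)

    @staticmethod
    def infer_batch(
        model: nn.Module,
        batch_data: torch.Tensor,
        *,
        device: str = "cpu",  # assumption: replaces on_gpu/select_device
    ) -> list[np.ndarray]:
        # NHWC batch from the DataLoader -> float32 NCHW on the target device.
        imgs = batch_data.to(device).type(torch.float32)
        imgs = imgs.permute(0, 3, 1, 2).contiguous()
        model.eval()
        # No gradients are needed for feature extraction.
        with torch.inference_mode():
            output = model(imgs)
        return [output.cpu().numpy()]
```

Any module that maps an NCHW batch to (batch_size, embedding_size), such as a timm model created with num_classes=0, can be passed in. The caller is responsible for placing the model on the same device as the batch.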

Suggestion

Would you be interested in adding this functionality? If yes, I can make a pull request.

shaneahmed (Member) commented:

This would be great. Please go ahead and create a PR. You can use logger.info to explain how to access weights.
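Following the suggestion to use logger.info, here is a sketch of a helper that tells users how to obtain gated weights. The function name log_weight_access and the name-to-page mapping are hypothetical. The URLs point to the UNI and H-optimus-0 model pages on the Hugging Face Hub, and `huggingface-cli login` and `huggingface_hub.login()` are the standard ways to authenticate. The function returns the message as well as logging it, which makes it easy to test.

```python
import logging

logger = logging.getLogger(__name__)

# Hypothetical mapping from backbone name to its gated Hugging Face model page.
_WEIGHT_PAGES = {
    "UNI": "https://huggingface.co/MahmoodLab/UNI",
    "H-optimus-0": "https://huggingface.co/bioptimus/H-optimus-0",
}


def log_weight_access(backbone: str) -> str:
    """Log (and return) instructions for obtaining gated model weights."""
    page = _WEIGHT_PAGES.get(backbone, "the model's Hugging Face page")
    msg = (
        f"The weights for {backbone!r} are gated. Request access at {page}, "
        "then authenticate with `huggingface-cli login` "
        "(or `huggingface_hub.login()`) before creating the model."
    )
    logger.info(msg)
    return msg
```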
