Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

[feat] Add Dino #259

Merged
merged 2 commits into from
May 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions flash/vision/backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from functools import partial
from typing import Tuple

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, rank_zero_warn
from torch import nn as nn
Expand Down Expand Up @@ -180,3 +181,32 @@ def _fn_timm(
IMAGE_CLASSIFIER_BACKBONES(
fn=catch_url_error(partial(_fn_timm, model_name)), name=model_name, namespace="vision", package="timm"
)


# Paper: Emerging Properties in Self-Supervised Vision Transformers
# https://arxiv.org/abs/2104.14294 from Mathilde Caron and al. (29 Apr 2021)
# weights from https://github.com/facebookresearch/dino
def dino_deits16(*_, **__):
backbone = torch.hub.load('facebookresearch/dino:main', 'dino_deits16')
return backbone, 384


def dino_deits8(*_, **__):
backbone = torch.hub.load('facebookresearch/dino:main', 'dino_deits8')
return backbone, 384


def dino_vitb16(*_, **__):
backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
return backbone, 768


def dino_vitb8(*_, **__):
backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8')
return backbone, 768


IMAGE_CLASSIFIER_BACKBONES(dino_deits16)
IMAGE_CLASSIFIER_BACKBONES(dino_deits8)
IMAGE_CLASSIFIER_BACKBONES(dino_vitb16)
IMAGE_CLASSIFIER_BACKBONES(dino_vitb8)
8 changes: 3 additions & 5 deletions flash/vision/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,10 @@ def __init__(
self.backbone, num_features = self.backbones.get(backbone)(pretrained=pretrained, **backbone_kwargs)

head = head(num_features, num_classes) if isinstance(head, FunctionType) else head
self.head = head or nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(),
nn.Linear(num_features, num_classes),
)
self.head = head or nn.Sequential(nn.Linear(num_features, num_classes), )

def forward(self, x) -> torch.Tensor:
x = self.backbone(x)
if x.dim() == 4:
x = x.mean(-1).mean(-1)
return self.head(x)
2 changes: 1 addition & 1 deletion flash_examples/finetuning/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def fn_resnet(pretrained: bool = True):
print(ImageClassifier.available_backbones())

# 4. Build the model
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
model = ImageClassifier(backbone="dino_vitb16", num_classes=datamodule.num_classes)

# 5. Create the trainer.
trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1)
Expand Down