Rework PyTorch Hub support code (#202)
Rework the support code for torch.hub.load() to allow reusing shared functions and eventually exposing more models.
1 parent 6a62615 · commit 9a4564c
Showing 5 changed files with 250 additions and 173 deletions.
@@ -0,0 +1,4 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
```
@@ -0,0 +1,84 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Union

import torch

from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name


class Weights(Enum):
    LVD142M = "LVD142M"


def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    from ..models import vision_transformer as vits

    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    model_name = _make_dinov2_model_name(arch_name, patch_size)
    vit_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        init_values=init_values,
        ffn_layer=ffn_layer,
        block_chunks=block_chunks,
    )
    vit_kwargs.update(**kwargs)
    model = vits.__dict__[arch_name](**vit_kwargs)

    if pretrained:
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        model.load_state_dict(state_dict, strict=False)

    return model


def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
    """
    return _make_dinov2_model(
        arch_name="vit_giant2", ffn_layer="swiglufused", weights=weights, pretrained=pretrained, **kwargs
    )
```
@@ -0,0 +1,147 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Union

import torch
import torch.nn as nn

from .backbones import _make_dinov2_model
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name


class Weights(Enum):
    IMAGENET1K = "IMAGENET1K"


def _make_dinov2_linear_classification_head(
    *,
    model_name: str = "dinov2_vitl14",
    embed_dim: int = 1024,
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.IMAGENET1K,
    **kwargs,
):
    if layers not in (1, 4):
        raise AssertionError(f"Unsupported number of layers: {layers}")
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    linear_head = nn.Linear((1 + layers) * embed_dim, 1_000)

    if pretrained:
        layers_str = str(layers) if layers == 4 else ""
        url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}_linear{layers_str}_head.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        linear_head.load_state_dict(state_dict, strict=False)

    return linear_head


class _LinearClassifierWrapper(nn.Module):
    def __init__(self, *, backbone: nn.Module, linear_head: nn.Module, layers: int = 4):
        super().__init__()
        self.backbone = backbone
        self.linear_head = linear_head
        self.layers = layers

    def forward(self, x):
        if self.layers == 1:
            x = self.backbone.forward_features(x)
            cls_token = x["x_norm_clstoken"]
            patch_tokens = x["x_norm_patchtokens"]
            # fmt: off
            linear_input = torch.cat([
                cls_token,
                patch_tokens.mean(dim=1),
            ], dim=1)
            # fmt: on
        elif self.layers == 4:
            x = self.backbone.get_intermediate_layers(x, n=4, return_class_token=True)
            # fmt: off
            linear_input = torch.cat([
                x[0][1],
                x[1][1],
                x[2][1],
                x[3][1],
                x[3][0].mean(dim=1),
            ], dim=1)
            # fmt: on
        else:
            assert False, f"Unsupported number of layers: {self.layers}"
        return self.linear_head(linear_input)


def _make_dinov2_linear_classifier(
    *,
    arch_name: str = "vit_large",
    layers: int = 4,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.IMAGENET1K,
    **kwargs,
):
    backbone = _make_dinov2_model(arch_name=arch_name, pretrained=pretrained, **kwargs)

    embed_dim = backbone.embed_dim
    patch_size = backbone.patch_size
    model_name = _make_dinov2_model_name(arch_name, patch_size)
    linear_head = _make_dinov2_linear_classification_head(
        model_name=model_name,
        embed_dim=embed_dim,
        layers=layers,
        pretrained=pretrained,
        weights=weights,
    )

    return _LinearClassifierWrapper(backbone=backbone, linear_head=linear_head, layers=layers)


def dinov2_vits14_lc(
    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_small", layers=layers, pretrained=pretrained, weights=weights, **kwargs
    )


def dinov2_vitb14_lc(
    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_base", layers=layers, pretrained=pretrained, weights=weights, **kwargs
    )


def dinov2_vitl14_lc(
    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_large", layers=layers, pretrained=pretrained, weights=weights, **kwargs
    )


def dinov2_vitg14_lc(
    *, layers: int = 4, pretrained: bool = True, weights: Union[Weights, str] = Weights.IMAGENET1K, **kwargs
):
    """
    Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally) pretrained on the LVD-142M dataset and trained on ImageNet-1k.
    """
    return _make_dinov2_linear_classifier(
        arch_name="vit_giant2", layers=layers, ffn_layer="swiglufused", pretrained=pretrained, weights=weights, **kwargs
    )
```
@@ -0,0 +1,13 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch.nn as nn

_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"


def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
    compact_arch_name = arch_name.replace("_", "")[:4]
    return f"dinov2_{compact_arch_name}{patch_size}"
```
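For reference, the helper strips underscores, truncates the architecture name to four characters, and appends the patch size. A couple of worked examples (the import path is an assumption, since file names are not shown in this diff):

```python
from dinov2.hub.utils import _make_dinov2_model_name  # hypothetical module path

# "vit_small"  -> "vitsmall"  -> "vits" -> "dinov2_vits14"
assert _make_dinov2_model_name("vit_small", 14) == "dinov2_vits14"
# "vit_giant2" -> "vitgiant2" -> "vitg" -> "dinov2_vitg14"
assert _make_dinov2_model_name("vit_giant2", 14) == "dinov2_vitg14"
```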