-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
21 changed files
with
620 additions
and
73 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,24 @@ | ||
# specific to vit pretrain | ||
paramwise_cfg = dict(custom_keys={ | ||
'.cls_token': dict(decay_mult=0.0), | ||
'.pos_embed': dict(decay_mult=0.0) | ||
}) | ||
|
||
# optimizer | ||
optimizer = dict(type='AdamW', lr=0.003, weight_decay=0.3) | ||
optimizer = dict( | ||
type='AdamW', | ||
lr=0.003, | ||
weight_decay=0.3, | ||
paramwise_cfg=paramwise_cfg, | ||
) | ||
optimizer_config = dict(grad_clip=dict(max_norm=1.0)) | ||
|
||
# specific to vit pretrain | ||
paramwise_cfg = dict( | ||
custom_keys={ | ||
'.backbone.cls_token': dict(decay_mult=0.0), | ||
'.backbone.pos_embed': dict(decay_mult=0.0) | ||
}) | ||
# learning policy | ||
lr_config = dict( | ||
policy='CosineAnnealing', | ||
min_lr=0, | ||
warmup='linear', | ||
warmup_iters=10000, | ||
warmup_ratio=1e-4) | ||
warmup_ratio=1e-4, | ||
) | ||
runner = dict(type='EpochBasedRunner', max_epochs=300) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Training data-efficient image transformers & distillation through attention | ||
<!-- {DeiT} --> | ||
<!-- [ALGORITHM] --> | ||
|
||
## Abstract | ||
|
||
<!-- [ABSTRACT] --> | ||
Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. However, these visual transformers are pre-trained with hundreds of millions of images using an expensive infrastructure, thereby limiting their adoption. In this work, we produce a competitive convolution-free transformer by training on Imagenet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop evaluation) on ImageNet with no external data. More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the interest of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets for both Imagenet (where we obtain up to 85.2% accuracy) and when transferring to other tasks. We share our code and models. | ||
|
||
<!-- [IMAGE] --> | ||
<div align=center> | ||
<img src="https://user-images.githubusercontent.com/26739999/143225703-c287c29e-82c9-4c85-a366-dfae30d198cd.png" width="40%"/> | ||
</div> | ||
|
||
## Citation | ||
```{latex} | ||
@InProceedings{pmlr-v139-touvron21a, | ||
title = {Training data-efficient image transformers & distillation through attention}, | ||
author = {Touvron, Hugo and Cord, Matthieu and Douze, Matthijs and Massa, Francisco and Sablayrolles, Alexandre and Jegou, Herve}, | ||
booktitle = {International Conference on Machine Learning}, | ||
pages = {10347--10357}, | ||
year = {2021}, | ||
volume = {139}, | ||
month = {July} | ||
} | ||
``` | ||
|
||
## Pretrained models | ||
|
||
The pre-trained models are converted from the [official repo](https://github.com/facebookresearch/deit). And the teacher of the distilled version DeiT is RegNetY-16GF. | ||
|
||
### ImageNet-1k | ||
|
||
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | | ||
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:| | ||
| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) | | ||
| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth) | | ||
| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) | | ||
| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth) | | ||
| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) | | ||
| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth) | | ||
|
||
*Models with \* are converted from other repos.* | ||
|
||
## Fine-tuned models | ||
|
||
The fine-tuned models are converted from the [official repo](https://github.com/facebookresearch/deit). | ||
|
||
### ImageNet-1k | ||
|
||
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | | ||
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:| | ||
| DeiT-base 384px\* | 86.86 | 49.37 | 83.04 | 96.31 | [config](configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) | | ||
| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211124-91e88933.pth) | | ||
|
||
*Models with \* are converted from other repos.* | ||
|
||
```{warning} | ||
MMClassification doesn't support training the distilled version DeiT. | ||
And we provide distilled version checkpoints for inference only. | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
_base_ = './deit-base_ft-16xb32_in1k-384px.py' | ||
|
||
# model settings | ||
model = dict( | ||
backbone=dict(type='DistilledVisionTransformer'), | ||
head=dict(type='DeiTClsHead'), | ||
# Change to the path of the pretrained model | ||
# init_cfg=dict(type='Pretrained', checkpoint=''), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
_base_ = './deit-small_pt-4xb256_in1k.py' | ||
|
||
# model settings | ||
model = dict( | ||
backbone=dict(type='DistilledVisionTransformer', arch='deit-base'), | ||
head=dict(type='DeiTClsHead', in_channels=768), | ||
) | ||
|
||
# data settings | ||
data = dict(samples_per_gpu=64, workers_per_gpu=5) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
_base_ = [ | ||
'../_base_/datasets/imagenet_bs64_swin_384.py', | ||
'../_base_/schedules/imagenet_bs4096_AdamW.py', | ||
'../_base_/default_runtime.py' | ||
] | ||
|
||
# model settings | ||
model = dict( | ||
type='ImageClassifier', | ||
backbone=dict( | ||
type='VisionTransformer', | ||
arch='deit-base', | ||
img_size=384, | ||
patch_size=16, | ||
), | ||
neck=None, | ||
head=dict( | ||
type='VisionTransformerClsHead', | ||
num_classes=1000, | ||
in_channels=768, | ||
loss=dict( | ||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), | ||
), | ||
# Change to the path of the pretrained model | ||
# init_cfg=dict(type='Pretrained', checkpoint=''), | ||
) | ||
|
||
# data settings | ||
data = dict(samples_per_gpu=32, workers_per_gpu=5) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
_base_ = './deit-small_pt-4xb256_in1k.py' | ||
|
||
# model settings | ||
model = dict( | ||
backbone=dict(type='VisionTransformer', arch='deit-base'), | ||
head=dict(type='VisionTransformerClsHead', in_channels=768), | ||
) | ||
|
||
# data settings | ||
data = dict(samples_per_gpu=64, workers_per_gpu=5) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
_base_ = './deit-small_pt-4xb256_in1k.py' | ||
|
||
# model settings | ||
model = dict( | ||
backbone=dict(type='DistilledVisionTransformer', arch='deit-small'), | ||
head=dict(type='DeiTClsHead', in_channels=384), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
_base_ = [ | ||
'../_base_/datasets/imagenet_bs64_pil_resize_autoaug.py', | ||
'../_base_/schedules/imagenet_bs4096_AdamW.py', | ||
'../_base_/default_runtime.py' | ||
] | ||
|
||
# model settings | ||
model = dict( | ||
type='ImageClassifier', | ||
backbone=dict( | ||
type='VisionTransformer', | ||
arch='deit-small', | ||
img_size=224, | ||
patch_size=16), | ||
neck=None, | ||
head=dict( | ||
type='VisionTransformerClsHead', | ||
num_classes=1000, | ||
in_channels=384, | ||
loss=dict( | ||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), | ||
), | ||
init_cfg=[ | ||
dict(type='TruncNormal', layer='Linear', std=.02), | ||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.), | ||
]) | ||
|
||
# data settings | ||
data = dict(samples_per_gpu=256, workers_per_gpu=5) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
_base_ = './deit-small_pt-4xb256_in1k.py' | ||
|
||
# model settings | ||
model = dict( | ||
backbone=dict(type='DistilledVisionTransformer', arch='deit-tiny'), | ||
head=dict(type='DeiTClsHead', in_channels=192), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
_base_ = './deit-small_pt-4xb256_in1k.py' | ||
|
||
# model settings | ||
model = dict( | ||
backbone=dict(type='VisionTransformer', arch='deit-tiny'), | ||
head=dict(type='VisionTransformerClsHead', in_channels=192), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
Collections: | ||
- Name: DeiT | ||
Metadata: | ||
Training Data: ImageNet-1k | ||
Architecture: | ||
- Layer Normalization | ||
- Scaled Dot-Product Attention | ||
- Attention Dropout | ||
- Multi-Head Attention | ||
Paper: | ||
URL: https://arxiv.org/abs/2012.12877 | ||
Title: "Training data-efficient image transformers & distillation through attention" | ||
README: configs/deit/README.md | ||
|
||
Models: | ||
- Name: deit-tiny_3rdparty_pt-4xb256_in1k | ||
Metadata: | ||
FLOPs: 1080000000 | ||
Parameters: 5720000 | ||
In Collection: DeiT | ||
Results: | ||
- Dataset: ImageNet-1k | ||
Metrics: | ||
Top 1 Accuracy: 72.13 | ||
Top 5 Accuracy: 91.13 | ||
Task: Image Classification | ||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth | ||
Converted From: | ||
Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth | ||
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L63 | ||
Config: configs/deit/deit-tiny_pt-4xb256_in1k.py | ||
- Name: deit-tiny-distilled_3rdparty_pt-4xb256_in1k | ||
Metadata: | ||
FLOPs: 1080000000 | ||
Parameters: 5720000 | ||
In Collection: DeiT | ||
Results: | ||
- Dataset: ImageNet-1k | ||
Metrics: | ||
Top 1 Accuracy: 74.51 | ||
Top 5 Accuracy: 91.90 | ||
Task: Image Classification | ||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth | ||
Converted From: | ||
Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth | ||
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L108 | ||
Config: configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py | ||
- Name: deit-small_3rdparty_pt-4xb256_in1k | ||
Metadata: | ||
FLOPs: 4240000000 | ||
Parameters: 22050000 | ||
In Collection: DeiT | ||
Results: | ||
- Dataset: ImageNet-1k | ||
Metrics: | ||
Top 1 Accuracy: 79.83 | ||
Top 5 Accuracy: 94.95 | ||
Task: Image Classification | ||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth | ||
Converted From: | ||
Weights: https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth | ||
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L78 | ||
Config: configs/deit/deit-small_pt-4xb256_in1k.py | ||
- Name: deit-small-distilled_3rdparty_pt-4xb256_in1k | ||
Metadata: | ||
FLOPs: 4240000000 | ||
Parameters: 22050000 | ||
In Collection: DeiT | ||
Results: | ||
- Dataset: ImageNet-1k | ||
Metrics: | ||
Top 1 Accuracy: 81.17 | ||
Top 5 Accuracy: 95.40 | ||
Task: Image Classification | ||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth | ||
Converted From: | ||
Weights: https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth | ||
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L123 | ||
Config: configs/deit/deit-small-distilled_pt-4xb256_in1k.py | ||
- Name: deit-base_3rdparty_pt-16xb64_in1k | ||
Metadata: | ||
FLOPs: 16860000000 | ||
Parameters: 86570000 | ||
In Collection: DeiT | ||
Results: | ||
- Dataset: ImageNet-1k | ||
Metrics: | ||
Top 1 Accuracy: 81.79 | ||
Top 5 Accuracy: 95.59 | ||
Task: Image Classification | ||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth | ||
Converted From: | ||
Weights: https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth | ||
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L93 | ||
Config: configs/deit/deit-base_pt-16xb64_in1k.py | ||
- Name: deit-base-distilled_3rdparty_pt-16xb64_in1k | ||
Metadata: | ||
FLOPs: 16860000000 | ||
Parameters: 86570000 | ||
In Collection: DeiT | ||
Results: | ||
- Dataset: ImageNet-1k | ||
Metrics: | ||
Top 1 Accuracy: 83.33 | ||
Top 5 Accuracy: 96.49 | ||
Task: Image Classification | ||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth | ||
Converted From: | ||
Weights: https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth | ||
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L138 | ||
Config: configs/deit/deit-base-distilled_pt-16xb64_in1k.py | ||
- Name: deit-base_3rdparty_ft-16xb32_in1k-384px | ||
Metadata: | ||
FLOPs: 49370000000 | ||
Parameters: 86860000 | ||
In Collection: DeiT | ||
Results: | ||
- Dataset: ImageNet-1k | ||
Metrics: | ||
Top 1 Accuracy: 83.04 | ||
Top 5 Accuracy: 96.31 | ||
Task: Image Classification | ||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth | ||
Converted From: | ||
Weights: https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth | ||
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L153 | ||
Config: configs/deit/deit-base_ft-16xb32_in1k-384px.py | ||
- Name: deit-base-distilled_3rdparty_ft-16xb32_in1k-384px | ||
Metadata: | ||
FLOPs: 49370000000 | ||
Parameters: 86860000 | ||
In Collection: DeiT | ||
Results: | ||
- Dataset: ImageNet-1k | ||
Metrics: | ||
Top 1 Accuracy: 85.55 | ||
Top 5 Accuracy: 97.35 | ||
Task: Image Classification | ||
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211124-91e88933.pth | ||
Converted From: | ||
Weights: https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth | ||
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L168 | ||
Config: configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.