Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Reproduction] Reproduce training results of DeiT. #711

Merged
merged 6 commits into from
Mar 2, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions configs/_base_/schedules/imagenet_bs1024_adamw_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# lr = 5e-4 * 128 * 8 / 512 = 0.001
optimizer = dict(
type='AdamW',
lr=5e-4 * 128 * 8 / 512,
lr=5e-4 * 1024 / 512,
weight_decay=0.05,
eps=1e-8,
betas=(0.9, 0.999),
Expand All @@ -24,7 +24,7 @@
min_lr_ratio=1e-2,
warmup='linear',
warmup_ratio=1e-3,
warmup_iters=20 * 1252,
warmup_by_epoch=False)
warmup_iters=20,
warmup_by_epoch=True)

runner = dict(type='EpochBasedRunner', max_epochs=300)
7 changes: 4 additions & 3 deletions configs/deit/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ The teacher of the distilled version DeiT is RegNetY-16GF.

| Model | Pretrain | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------------------:|:------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
| DeiT-tiny\* | From scratch | 5.72 | 1.08 | 72.13 | 91.13 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) |
| DeiT-tiny | From scratch | 5.72 | 1.08 | 74.50 | 92.24 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.pth) | [log](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.log.json) |
| DeiT-tiny distilled\* | From scratch | 5.72 | 1.08 | 74.51 | 91.90 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211216-c429839a.pth) |
| DeiT-small\* | From scratch | 22.05 | 4.24 | 79.83 | 94.95 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) |
| DeiT-small | From scratch | 22.05 | 4.24 | 80.69 | 95.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_pt-4xb256_in1k_20220218-9425b9bb.pth) | [log](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_pt-4xb256_in1k_20220218-9425b9bb.log.json) |
| DeiT-small distilled\*| From scratch | 22.05 | 4.24 | 81.17 | 95.40 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211216-4de1d725.pth) |
| DeiT-base\* | From scratch | 86.57 | 16.86 | 81.79 | 95.59 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) |
| DeiT-base | From scratch | 86.57 | 16.86 | 81.76 | 95.81 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth) | [log](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.log.json) |
| DeiT-base\* | From scratch | 86.57 | 16.86 | 81.79 | 95.59 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) |
mzr1996 marked this conversation as resolved.
Show resolved Hide resolved
| DeiT-base distilled\* | From scratch | 86.57 | 16.86 | 83.33 | 96.49 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211216-42891296.pth) |
| DeiT-base 384px\* | ImageNet-1k | 86.86 | 49.37 | 83.04 | 96.31 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) |
| DeiT-base distilled 384px\* | ImageNet-1k | 86.86 | 49.37 | 85.55 | 97.35 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211216-e48d6000.pth) |
Expand Down
5 changes: 4 additions & 1 deletion configs/deit/deit-base_pt-16xb64_in1k.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

# model settings
model = dict(
backbone=dict(type='VisionTransformer', arch='deit-base'),
backbone=dict(
type='VisionTransformer', arch='deit-base', drop_path_rate=0.1),
head=dict(type='VisionTransformerClsHead', in_channels=768),
)

# data settings
data = dict(samples_per_gpu=64, workers_per_gpu=5)

custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
19 changes: 16 additions & 3 deletions configs/deit/deit-small_pt-4xb256_in1k.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_pil_resize_autoaug.py',
'../_base_/schedules/imagenet_bs4096_AdamW.py',
'../_base_/datasets/imagenet_bs64_swin_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py'
]

Expand All @@ -23,7 +23,20 @@
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=.02),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
])
],
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
]))

# data settings
data = dict(samples_per_gpu=256, workers_per_gpu=5)

paramwise_cfg = dict(
norm_decay_mult=0.0,
bias_decay_mult=0.0,
custom_keys={
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0)
})
optimizer = dict(paramwise_cfg=paramwise_cfg)
35 changes: 21 additions & 14 deletions configs/deit/metafile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,18 @@ Collections:
README: configs/deit/README.md

Models:
- Name: deit-tiny_3rdparty_pt-4xb256_in1k
- Name: deit-tiny_pt-4xb256_in1k
Metadata:
FLOPs: 1080000000
Parameters: 5720000
In Collection: DeiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 72.13
Top 5 Accuracy: 91.13
Top 1 Accuracy: 74.50
Top 5 Accuracy: 92.24
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth
Converted From:
Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L63
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.pth
Config: configs/deit/deit-tiny_pt-4xb256_in1k.py
- Name: deit-tiny-distilled_3rdparty_pt-4xb256_in1k
Metadata:
Expand All @@ -45,21 +42,18 @@ Models:
Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L108
Config: configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py
- Name: deit-small_3rdparty_pt-4xb256_in1k
- Name: deit-small_pt-4xb256_in1k
Metadata:
FLOPs: 4240000000
Parameters: 22050000
In Collection: DeiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 79.83
Top 5 Accuracy: 94.95
Top 1 Accuracy: 80.69
Top 5 Accuracy: 95.06
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth
Converted From:
Weights: https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L78
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small_pt-4xb256_in1k_20220218-9425b9bb.pth
Config: configs/deit/deit-small_pt-4xb256_in1k.py
- Name: deit-small-distilled_3rdparty_pt-4xb256_in1k
Metadata:
Expand All @@ -77,6 +71,19 @@ Models:
Weights: https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L123
Config: configs/deit/deit-small-distilled_pt-4xb256_in1k.py
- Name: deit-base_pt-16xb64_in1k
Metadata:
FLOPs: 16860000000
Parameters: 86570000
In Collection: DeiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 81.76
Top 5 Accuracy: 95.81
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth
Config: configs/deit/deit-base_pt-16xb64_in1k.py
- Name: deit-base_3rdparty_pt-16xb64_in1k
Metadata:
FLOPs: 16860000000
Expand Down
6 changes: 3 additions & 3 deletions docs/en/model_zoo.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
| T2T-ViT_t-24 | 64.00 | 12.69 | 82.71 | 96.09 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_8xb64_in1k_20211214-b2a68ae3.pth) | [log](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_8xb64_in1k_20211214-b2a68ae3.log.json)|
| Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth) |
| Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth) |
| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) |
| DeiT-tiny | 5.72 | 1.08 | 74.50 | 92.24 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.pth) | [log](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.log.json) |
| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211216-c429839a.pth) |
| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) |
| DeiT-small | 22.05 | 4.24 | 80.69 | 95.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_pt-4xb256_in1k_20220218-9425b9bb.pth) | [log](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_pt-4xb256_in1k_20220218-9425b9bb.log.json) |
| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211216-4de1d725.pth) |
| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) |
| DeiT-base | 86.57 | 16.86 | 81.76 | 95.81 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth) | [log](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.log.json) |
| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211216-42891296.pth) |
| DeiT-base 384px\* | 86.86 | 49.37 | 83.04 | 96.31 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) |
| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211216-e48d6000.pth) |
Expand Down