Skip to content

Commit

Permalink
[Feature] Support EfficientFormer. (open-mmlab#954)
Browse files Browse the repository at this point in the history
* add efficient backbone

* Update Readme and metafile

* Add unit tests

* fix confict

* fix lint

* update efficientformer head unit tests

* update README

* fix unit test

* fix Readme

* fix example

* fix typo

* recover api modification

* Update EfficiemtFormer Backbone

* fix unit tests

* add efficientformer to readme and model zoo
  • Loading branch information
Ezra-Yu committed Sep 6, 2022
1 parent 1bbb761 commit 2c6b375
Show file tree
Hide file tree
Showing 15 changed files with 1,228 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
- [x] [CSPNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/cspnet)
- [x] [PoolFormer](https://github.com/open-mmlab/mmclassification/tree/master/configs/poolformer)
- [x] [MViT](https://github.com/open-mmlab/mmclassification/tree/master/configs/mvit)
- [x] [EfficientFormer](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientformer)

</details>

Expand Down
47 changes: 47 additions & 0 deletions configs/efficientformer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# EfficientFormer

> [EfficientFormer: Vision Transformers at MobileNet Speed](https://arxiv.org/abs/2206.01191)
<!-- [ALGORITHM] -->

## Abstract

Vision Transformers (ViT) have shown rapid progress in computer vision tasks, achieving promising results on various benchmarks. However, due to the massive number of parameters and model design, e.g., attention mechanism, ViT-based models are generally times slower than lightweight convolutional networks. Therefore, the deployment of ViT for real-time applications is particularly challenging, especially on resource-constrained hardware such as mobile devices. Recent efforts try to reduce the computation complexity of ViT through network architecture search or hybrid design with MobileNet block, yet the inference speed is still unsatisfactory. This leads to an important question: can transformers run as fast as MobileNet while obtaining high performance? To answer this, we first revisit the network architecture and operators used in ViT-based models and identify inefficient designs. Then we introduce a dimension-consistent pure transformer (without MobileNet blocks) as a design paradigm. Finally, we perform latency-driven slimming to get a series of final models dubbed EfficientFormer. Extensive experiments show the superiority of EfficientFormer in performance and speed on mobile devices. Our fastest model, EfficientFormer-L1, achieves 79.2% top-1 accuracy on ImageNet-1K with only 1.6 ms inference latency on iPhone 12 (compiled with CoreML), which runs as fast as MobileNetV2×1.4 (1.6 ms, 74.7% top-1), and our largest model, EfficientFormer-L7, obtains 83.3% accuracy with only 7.0 ms latency. Our work proves that properly designed transformers can reach extremely low latency on mobile devices while maintaining high performance.

<div align=center>
<img src="https://user-images.githubusercontent.com/18586273/180713426-9d3d77e3-3584-42d8-9098-625b4170d796.png" width="100%"/>
</div>

## Results and models

### ImageNet-1k

| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
| :------------------: | :-------: | :------: | :-------: | :-------: | :---------------------------------------------------------------------: | :------------------------------------------------------------------------: |
| EfficientFormer-l1\* | 12.19 | 1.30 | 80.46 | 94.99 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientformer/efficientformer-l1_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l1_3rdparty_in1k_20220803-d66e61df.pth) |
| EfficientFormer-l3\* | 31.41 | 3.93 | 82.45 | 96.18 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientformer/efficientformer-l3_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l3_3rdparty_in1k_20220803-dde1c8c5.pth) |
| EfficientFormer-l7\* | 82.23 | 10.16 | 83.40 | 96.60 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientformer/efficientformer-l7_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l7_3rdparty_in1k_20220803-41a552bb.pth) |

*Models with * are converted from the [official repo](https://github.com/snap-research/EfficientFormer). The config files of these models are only for inference. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*

## Citation

```bibtex
@misc{https://doi.org/10.48550/arxiv.2206.01191,
doi = {10.48550/ARXIV.2206.01191},
url = {https://arxiv.org/abs/2206.01191},
author = {Li, Yanyu and Yuan, Geng and Wen, Yang and Hu, Eric and Evangelidis, Georgios and Tulyakov, Sergey and Wang, Yanzhi and Ren, Jian},
keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {EfficientFormer: Vision Transformers at MobileNet Speed},
publisher = {arXiv},
year = {2022},
copyright = {Creative Commons Attribution 4.0 International}
}
```
24 changes: 24 additions & 0 deletions configs/efficientformer/efficientformer-l1_8xb128_in1k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
_base_ = [
'../_base_/datasets/imagenet_bs128_poolformer_small_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]

model = dict(
type='ImageClassifier',
backbone=dict(
type='EfficientFormer',
arch='l1',
drop_path_rate=0,
init_cfg=[
dict(
type='TruncNormal',
layer=['Conv2d', 'Linear'],
std=.02,
bias=0.),
dict(type='Constant', layer=['GroupNorm'], val=1., bias=0.),
dict(type='Constant', layer=['LayerScale'], val=1e-5)
]),
neck=dict(type='GlobalAveragePooling', dim=1),
head=dict(
type='EfficientFormerClsHead', in_channels=448, num_classes=1000))
24 changes: 24 additions & 0 deletions configs/efficientformer/efficientformer-l3_8xb128_in1k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
_base_ = [
'../_base_/datasets/imagenet_bs128_poolformer_small_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]

model = dict(
type='ImageClassifier',
backbone=dict(
type='EfficientFormer',
arch='l3',
drop_path_rate=0,
init_cfg=[
dict(
type='TruncNormal',
layer=['Conv2d', 'Linear'],
std=.02,
bias=0.),
dict(type='Constant', layer=['GroupNorm'], val=1., bias=0.),
dict(type='Constant', layer=['LayerScale'], val=1e-5)
]),
neck=dict(type='GlobalAveragePooling', dim=1),
head=dict(
type='EfficientFormerClsHead', in_channels=512, num_classes=1000))
24 changes: 24 additions & 0 deletions configs/efficientformer/efficientformer-l7_8xb128_in1k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
_base_ = [
'../_base_/datasets/imagenet_bs128_poolformer_small_224.py',
'../_base_/schedules/imagenet_bs1024_adamw_swin.py',
'../_base_/default_runtime.py',
]

model = dict(
type='ImageClassifier',
backbone=dict(
type='EfficientFormer',
arch='l7',
drop_path_rate=0,
init_cfg=[
dict(
type='TruncNormal',
layer=['Conv2d', 'Linear'],
std=.02,
bias=0.),
dict(type='Constant', layer=['GroupNorm'], val=1., bias=0.),
dict(type='Constant', layer=['LayerScale'], val=1e-5)
]),
neck=dict(type='GlobalAveragePooling', dim=1),
head=dict(
type='EfficientFormerClsHead', in_channels=768, num_classes=1000))
67 changes: 67 additions & 0 deletions configs/efficientformer/metafile.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
Collections:
- Name: EfficientFormer
Metadata:
Training Data: ImageNet-1k
Architecture:
- Pooling
- 1x1 Convolution
- LayerScale
- MetaFormer
Paper:
URL: https://arxiv.org/pdf/2206.01191.pdf
Title: "EfficientFormer: Vision Transformers at MobileNet Speed"
README: configs/efficientformer/README.md
Code:
Version: v0.24.0
URL: https://github.com/open-mmlab/mmclassification/blob/v0.24.0/mmcls/models/backbones/efficientformer.py

Models:
- Name: efficientformer-l1_3rdparty_8xb128_in1k
Metadata:
FLOPs: 1304601088 # 1.3G
Parameters: 12278696 # 12M
In Collections: EfficientFormer
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 80.46
Top 5 Accuracy: 94.99
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l1_3rdparty_in1k_20220803-d66e61df.pth
Config: configs/efficientformer/efficientformer-l1_8xb128_in1k.py
Converted From:
Weights: https://drive.google.com/file/d/11SbX-3cfqTOc247xKYubrAjBiUmr818y/view?usp=sharing
Code: https://github.com/snap-research/EfficientFormer
- Name: efficientformer-l3_3rdparty_8xb128_in1k
Metadata:
Training Data: ImageNet-1k
FLOPs: 3737045760 # 3.7G
Parameters: 31406000 # 31M
In Collections: EfficientFormer
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.45
Top 5 Accuracy: 96.18
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l3_3rdparty_in1k_20220803-dde1c8c5.pth
Config: configs/efficientformer/efficientformer-l3_8xb128_in1k.py
Converted From:
Weights: https://drive.google.com/file/d/1OyyjKKxDyMj-BcfInp4GlDdwLu3hc30m/view?usp=sharing
Code: https://github.com/snap-research/EfficientFormer
- Name: efficientformer-l7_3rdparty_8xb128_in1k
Metadata:
FLOPs: 10163951616 # 10.2G
Parameters: 82229328 # 82M
In Collections: EfficientFormer
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.40
Top 5 Accuracy: 96.60
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l7_3rdparty_in1k_20220803-41a552bb.pth
Config: configs/efficientformer/efficientformer-l7_8xb128_in1k.py
Converted From:
Weights: https://drive.google.com/file/d/1cVw-pctJwgvGafeouynqWWCwgkcoFMM5/view?usp=sharing
Code: https://github.com/snap-research/EfficientFormer
1 change: 1 addition & 0 deletions docs/en/api/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ Backbones
VAN
VGG
VisionTransformer
EfficientFormer

.. _necks:

Expand Down
3 changes: 3 additions & 0 deletions docs/en/model_zoo.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
| MViTv2-small\* | 34.87 | 7.00 | 83.63 | 96.51 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mvit/mvitv2-small_8xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mvit/mvitv2-small_3rdparty_in1k_20220722-986bd741.pth) |
| MViTv2-base\* | 51.47 | 10.20 | 84.34 | 96.86 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mvit/mvitv2-base_8xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mvit/mvitv2-base_3rdparty_in1k_20220722-9c4f0a17.pth) |
| MViTv2-large\* | 217.99 | 42.10 | 85.25 | 97.14 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mvit/mvitv2-large_8xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mvit/mvitv2-large_3rdparty_in1k_20220722-2b57b983.pth) |
| EfficientFormer-l1\* | 12.19 | 1.30 | 80.46 | 94.99 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientformer/efficientformer-l1_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l1_3rdparty_in1k_20220803-d66e61df.pth) |
| EfficientFormer-l3\* | 31.41 | 3.93 | 82.45 | 96.18 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientformer/efficientformer-l3_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l3_3rdparty_in1k_20220803-dde1c8c5.pth) |
| EfficientFormer-l7\* | 82.23 | 10.16 | 83.40 | 96.60 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientformer/efficientformer-l7_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l7_3rdparty_in1k_20220803-41a552bb.pth) |

*Models with * are converted from other repos, others are trained by ourselves.*

Expand Down
3 changes: 2 additions & 1 deletion mmcls/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .cspnet import CSPDarkNet, CSPNet, CSPResNet, CSPResNeXt
from .deit import DistilledVisionTransformer
from .densenet import DenseNet
from .efficientformer import EfficientFormer
from .efficientnet import EfficientNet
from .hrnet import HRNet
from .lenet import LeNet5
Expand Down Expand Up @@ -44,5 +45,5 @@
'Res2Net', 'RepVGG', 'Conformer', 'MlpMixer', 'DistilledVisionTransformer',
'PCPVT', 'SVT', 'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c',
'ConvMixer', 'CSPDarkNet', 'CSPResNet', 'CSPResNeXt', 'CSPNet',
'RepMLPNet', 'PoolFormer', 'DenseNet', 'VAN', 'MViT'
'RepMLPNet', 'PoolFormer', 'DenseNet', 'VAN', 'MViT', 'EfficientFormer'
]
Loading

0 comments on commit 2c6b375

Please sign in to comment.