From 2c6b375931d625a96f384d74ae30e3a4b2c399e6 Mon Sep 17 00:00:00 2001
From: Ezra-Yu <1105212286@qq.com>
Date: Tue, 16 Aug 2022 23:38:08 +0800
Subject: [PATCH] [Feature] Support EfficientFormer. (#954)
* add efficient backbone
* Update Readme and metafile
* Add unit tests
* fix confict
* fix lint
* update efficientformer head unit tests
* update README
* fix unit test
* fix Readme
* fix example
* fix typo
* recover api modification
* Update EfficiemtFormer Backbone
* fix unit tests
* add efficientformer to readme and model zoo
---
README.md | 1 +
configs/efficientformer/README.md | 47 ++
.../efficientformer-l1_8xb128_in1k.py | 24 +
.../efficientformer-l3_8xb128_in1k.py | 24 +
.../efficientformer-l7_8xb128_in1k.py | 24 +
configs/efficientformer/metafile.yml | 67 ++
docs/en/api/models.rst | 1 +
docs/en/model_zoo.md | 3 +
mmcls/models/backbones/__init__.py | 3 +-
mmcls/models/backbones/efficientformer.py | 637 ++++++++++++++++++
mmcls/models/heads/__init__.py | 3 +-
mmcls/models/heads/efficientformer_head.py | 96 +++
model-index.yml | 1 +
.../test_backbones/test_efficientformer.py | 241 +++++++
tests/test_models/test_heads.py | 59 +-
15 files changed, 1228 insertions(+), 3 deletions(-)
create mode 100644 configs/efficientformer/README.md
create mode 100644 configs/efficientformer/efficientformer-l1_8xb128_in1k.py
create mode 100644 configs/efficientformer/efficientformer-l3_8xb128_in1k.py
create mode 100644 configs/efficientformer/efficientformer-l7_8xb128_in1k.py
create mode 100644 configs/efficientformer/metafile.yml
create mode 100644 mmcls/models/backbones/efficientformer.py
create mode 100644 mmcls/models/heads/efficientformer_head.py
create mode 100644 tests/test_models/test_backbones/test_efficientformer.py
diff --git a/README.md b/README.md
index 09e227339c1..eec6036b944 100644
--- a/README.md
+++ b/README.md
@@ -143,6 +143,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
- [x] [CSPNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/cspnet)
- [x] [PoolFormer](https://github.com/open-mmlab/mmclassification/tree/master/configs/poolformer)
- [x] [MViT](https://github.com/open-mmlab/mmclassification/tree/master/configs/mvit)
+- [x] [EfficientFormer](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientformer)
diff --git a/configs/efficientformer/README.md b/configs/efficientformer/README.md
new file mode 100644
index 00000000000..ecd6b4927e5
--- /dev/null
+++ b/configs/efficientformer/README.md
@@ -0,0 +1,47 @@
+# EfficientFormer
+
+> [EfficientFormer: Vision Transformers at MobileNet Speed](https://arxiv.org/abs/2206.01191)
+
+
+
+## Abstract
+
+Vision Transformers (ViT) have shown rapid progress in computer vision tasks, achieving promising results on various benchmarks. However, due to the massive number of parameters and model design, e.g., attention mechanism, ViT-based models are generally times slower than lightweight convolutional networks. Therefore, the deployment of ViT for real-time applications is particularly challenging, especially on resource-constrained hardware such as mobile devices. Recent efforts try to reduce the computation complexity of ViT through network architecture search or hybrid design with MobileNet block, yet the inference speed is still unsatisfactory. This leads to an important question: can transformers run as fast as MobileNet while obtaining high performance? To answer this, we first revisit the network architecture and operators used in ViT-based models and identify inefficient designs. Then we introduce a dimension-consistent pure transformer (without MobileNet blocks) as a design paradigm. Finally, we perform latency-driven slimming to get a series of final models dubbed EfficientFormer. Extensive experiments show the superiority of EfficientFormer in performance and speed on mobile devices. Our fastest model, EfficientFormer-L1, achieves 79.2% top-1 accuracy on ImageNet-1K with only 1.6 ms inference latency on iPhone 12 (compiled with CoreML), which runs as fast as MobileNetV2×1.4 (1.6 ms, 74.7% top-1), and our largest model, EfficientFormer-L7, obtains 83.3% accuracy with only 7.0 ms latency. Our work proves that properly designed transformers can reach extremely low latency on mobile devices while maintaining high performance.
+
+
+
![](https://user-images.githubusercontent.com/18586273/180713426-9d3d77e3-3584-42d8-9098-625b4170d796.png)
+
+
+## Results and models
+
+### ImageNet-1k
+
+| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
+| :------------------: | :-------: | :------: | :-------: | :-------: | :---------------------------------------------------------------------: | :------------------------------------------------------------------------: |
+| EfficientFormer-l1\* | 12.19 | 1.30 | 80.46 | 94.99 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientformer/efficientformer-l1_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l1_3rdparty_in1k_20220803-d66e61df.pth) |
+| EfficientFormer-l3\* | 31.41 | 3.93 | 82.45 | 96.18 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientformer/efficientformer-l3_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l3_3rdparty_in1k_20220803-dde1c8c5.pth) |
+| EfficientFormer-l7\* | 82.23 | 10.16 | 83.40 | 96.60 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientformer/efficientformer-l7_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l7_3rdparty_in1k_20220803-41a552bb.pth) |
+
+*Models with * are converted from the [official repo](https://github.com/snap-research/EfficientFormer). The config files of these models are only for inference. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*
+
+## Citation
+
+```bibtex
+@misc{https://doi.org/10.48550/arxiv.2206.01191,
+ doi = {10.48550/ARXIV.2206.01191},
+
+ url = {https://arxiv.org/abs/2206.01191},
+
+ author = {Li, Yanyu and Yuan, Geng and Wen, Yang and Hu, Eric and Evangelidis, Georgios and Tulyakov, Sergey and Wang, Yanzhi and Ren, Jian},
+
+ keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences},
+
+ title = {EfficientFormer: Vision Transformers at MobileNet Speed},
+
+ publisher = {arXiv},
+
+ year = {2022},
+
+ copyright = {Creative Commons Attribution 4.0 International}
+}
+```
diff --git a/configs/efficientformer/efficientformer-l1_8xb128_in1k.py b/configs/efficientformer/efficientformer-l1_8xb128_in1k.py
new file mode 100644
index 00000000000..f5db2bfc63b
--- /dev/null
+++ b/configs/efficientformer/efficientformer-l1_8xb128_in1k.py
@@ -0,0 +1,24 @@
+_base_ = [
+ '../_base_/datasets/imagenet_bs128_poolformer_small_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+ '../_base_/default_runtime.py',
+]
+
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='EfficientFormer',
+ arch='l1',
+ drop_path_rate=0,
+ init_cfg=[
+ dict(
+ type='TruncNormal',
+ layer=['Conv2d', 'Linear'],
+ std=.02,
+ bias=0.),
+ dict(type='Constant', layer=['GroupNorm'], val=1., bias=0.),
+ dict(type='Constant', layer=['LayerScale'], val=1e-5)
+ ]),
+ neck=dict(type='GlobalAveragePooling', dim=1),
+ head=dict(
+ type='EfficientFormerClsHead', in_channels=448, num_classes=1000))
diff --git a/configs/efficientformer/efficientformer-l3_8xb128_in1k.py b/configs/efficientformer/efficientformer-l3_8xb128_in1k.py
new file mode 100644
index 00000000000..e920f785d84
--- /dev/null
+++ b/configs/efficientformer/efficientformer-l3_8xb128_in1k.py
@@ -0,0 +1,24 @@
+_base_ = [
+ '../_base_/datasets/imagenet_bs128_poolformer_small_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+ '../_base_/default_runtime.py',
+]
+
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='EfficientFormer',
+ arch='l3',
+ drop_path_rate=0,
+ init_cfg=[
+ dict(
+ type='TruncNormal',
+ layer=['Conv2d', 'Linear'],
+ std=.02,
+ bias=0.),
+ dict(type='Constant', layer=['GroupNorm'], val=1., bias=0.),
+ dict(type='Constant', layer=['LayerScale'], val=1e-5)
+ ]),
+ neck=dict(type='GlobalAveragePooling', dim=1),
+ head=dict(
+ type='EfficientFormerClsHead', in_channels=512, num_classes=1000))
diff --git a/configs/efficientformer/efficientformer-l7_8xb128_in1k.py b/configs/efficientformer/efficientformer-l7_8xb128_in1k.py
new file mode 100644
index 00000000000..a59e3a7ed5a
--- /dev/null
+++ b/configs/efficientformer/efficientformer-l7_8xb128_in1k.py
@@ -0,0 +1,24 @@
+_base_ = [
+ '../_base_/datasets/imagenet_bs128_poolformer_small_224.py',
+ '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
+ '../_base_/default_runtime.py',
+]
+
+model = dict(
+ type='ImageClassifier',
+ backbone=dict(
+ type='EfficientFormer',
+ arch='l7',
+ drop_path_rate=0,
+ init_cfg=[
+ dict(
+ type='TruncNormal',
+ layer=['Conv2d', 'Linear'],
+ std=.02,
+ bias=0.),
+ dict(type='Constant', layer=['GroupNorm'], val=1., bias=0.),
+ dict(type='Constant', layer=['LayerScale'], val=1e-5)
+ ]),
+ neck=dict(type='GlobalAveragePooling', dim=1),
+ head=dict(
+ type='EfficientFormerClsHead', in_channels=768, num_classes=1000))
diff --git a/configs/efficientformer/metafile.yml b/configs/efficientformer/metafile.yml
new file mode 100644
index 00000000000..33c47865e9f
--- /dev/null
+++ b/configs/efficientformer/metafile.yml
@@ -0,0 +1,67 @@
+Collections:
+ - Name: EfficientFormer
+ Metadata:
+ Training Data: ImageNet-1k
+ Architecture:
+ - Pooling
+ - 1x1 Convolution
+ - LayerScale
+ - MetaFormer
+ Paper:
+ URL: https://arxiv.org/pdf/2206.01191.pdf
+ Title: "EfficientFormer: Vision Transformers at MobileNet Speed"
+ README: configs/efficientformer/README.md
+ Code:
+ Version: v0.24.0
+ URL: https://github.com/open-mmlab/mmclassification/blob/v0.24.0/mmcls/models/backbones/efficientformer.py
+
+Models:
+ - Name: efficientformer-l1_3rdparty_8xb128_in1k
+ Metadata:
+ FLOPs: 1304601088 # 1.3G
+ Parameters: 12278696 # 12M
+ In Collections: EfficientFormer
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 80.46
+ Top 5 Accuracy: 94.99
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l1_3rdparty_in1k_20220803-d66e61df.pth
+ Config: configs/efficientformer/efficientformer-l1_8xb128_in1k.py
+ Converted From:
+ Weights: https://drive.google.com/file/d/11SbX-3cfqTOc247xKYubrAjBiUmr818y/view?usp=sharing
+ Code: https://github.com/snap-research/EfficientFormer
+ - Name: efficientformer-l3_3rdparty_8xb128_in1k
+ Metadata:
+ Training Data: ImageNet-1k
+ FLOPs: 3737045760 # 3.7G
+ Parameters: 31406000 # 31M
+ In Collections: EfficientFormer
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 82.45
+ Top 5 Accuracy: 96.18
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l3_3rdparty_in1k_20220803-dde1c8c5.pth
+ Config: configs/efficientformer/efficientformer-l3_8xb128_in1k.py
+ Converted From:
+ Weights: https://drive.google.com/file/d/1OyyjKKxDyMj-BcfInp4GlDdwLu3hc30m/view?usp=sharing
+ Code: https://github.com/snap-research/EfficientFormer
+ - Name: efficientformer-l7_3rdparty_8xb128_in1k
+ Metadata:
+ FLOPs: 10163951616 # 10.2G
+ Parameters: 82229328 # 82M
+ In Collections: EfficientFormer
+ Results:
+ - Dataset: ImageNet-1k
+ Metrics:
+ Top 1 Accuracy: 83.40
+ Top 5 Accuracy: 96.60
+ Task: Image Classification
+ Weights: https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l7_3rdparty_in1k_20220803-41a552bb.pth
+ Config: configs/efficientformer/efficientformer-l7_8xb128_in1k.py
+ Converted From:
+ Weights: https://drive.google.com/file/d/1cVw-pctJwgvGafeouynqWWCwgkcoFMM5/view?usp=sharing
+ Code: https://github.com/snap-research/EfficientFormer
diff --git a/docs/en/api/models.rst b/docs/en/api/models.rst
index 78701a41fe1..37938e34d95 100644
--- a/docs/en/api/models.rst
+++ b/docs/en/api/models.rst
@@ -87,6 +87,7 @@ Backbones
VAN
VGG
VisionTransformer
+ EfficientFormer
.. _necks:
diff --git a/docs/en/model_zoo.md b/docs/en/model_zoo.md
index 7a9a750b86b..46b42a97e68 100644
--- a/docs/en/model_zoo.md
+++ b/docs/en/model_zoo.md
@@ -145,6 +145,9 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
| MViTv2-small\* | 34.87 | 7.00 | 83.63 | 96.51 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mvit/mvitv2-small_8xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mvit/mvitv2-small_3rdparty_in1k_20220722-986bd741.pth) |
| MViTv2-base\* | 51.47 | 10.20 | 84.34 | 96.86 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mvit/mvitv2-base_8xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mvit/mvitv2-base_3rdparty_in1k_20220722-9c4f0a17.pth) |
| MViTv2-large\* | 217.99 | 42.10 | 85.25 | 97.14 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mvit/mvitv2-large_8xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mvit/mvitv2-large_3rdparty_in1k_20220722-2b57b983.pth) |
+| EfficientFormer-l1\* | 12.19 | 1.30 | 80.46 | 94.99 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientformer/efficientformer-l1_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l1_3rdparty_in1k_20220803-d66e61df.pth) |
+| EfficientFormer-l3\* | 31.41 | 3.93 | 82.45 | 96.18 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientformer/efficientformer-l3_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l3_3rdparty_in1k_20220803-dde1c8c5.pth) |
+| EfficientFormer-l7\* | 82.23 | 10.16 | 83.40 | 96.60 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientformer/efficientformer-l7_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientformer/efficientformer-l7_3rdparty_in1k_20220803-41a552bb.pth) |
*Models with * are converted from other repos, others are trained by ourselves.*
diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py
index f1e772dec11..ad7b8189943 100644
--- a/mmcls/models/backbones/__init__.py
+++ b/mmcls/models/backbones/__init__.py
@@ -6,6 +6,7 @@
from .cspnet import CSPDarkNet, CSPNet, CSPResNet, CSPResNeXt
from .deit import DistilledVisionTransformer
from .densenet import DenseNet
+from .efficientformer import EfficientFormer
from .efficientnet import EfficientNet
from .hrnet import HRNet
from .lenet import LeNet5
@@ -44,5 +45,5 @@
'Res2Net', 'RepVGG', 'Conformer', 'MlpMixer', 'DistilledVisionTransformer',
'PCPVT', 'SVT', 'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c',
'ConvMixer', 'CSPDarkNet', 'CSPResNet', 'CSPResNeXt', 'CSPNet',
- 'RepMLPNet', 'PoolFormer', 'DenseNet', 'VAN', 'MViT'
+ 'RepMLPNet', 'PoolFormer', 'DenseNet', 'VAN', 'MViT', 'EfficientFormer'
]
diff --git a/mmcls/models/backbones/efficientformer.py b/mmcls/models/backbones/efficientformer.py
new file mode 100644
index 00000000000..fa3b14eb6e0
--- /dev/null
+++ b/mmcls/models/backbones/efficientformer.py
@@ -0,0 +1,637 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import itertools
+from typing import Optional, Sequence
+
+import torch
+import torch.nn as nn
+from mmcv.cnn.bricks import (ConvModule, DropPath, build_activation_layer,
+ build_norm_layer)
+from mmcv.runner import BaseModule, ModuleList, Sequential
+
+from ..builder import BACKBONES
+from .base_backbone import BaseBackbone
+from .poolformer import Pooling
+
+
+class AttentionWithBias(BaseModule):
+ """Multi-head Attention Module with attention_bias.
+
+ Args:
+ embed_dims (int): The embedding dimension.
+ num_heads (int): Parallel attention heads. Defaults to 8.
+ key_dim (int): The dimension of q, k. Defaults to 32.
+ attn_ratio (float): The dimension of v equals to
+ ``key_dim * attn_ratio``. Defaults to 4.
+ resolution (int): The height and width of attention_bias.
+ Defaults to 7.
+ init_cfg (dict, optional): The Config for initialization.
+ Defaults to None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads=8,
+ key_dim=32,
+ attn_ratio=4.,
+ resolution=7,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.num_heads = num_heads
+ self.scale = key_dim**-0.5
+ self.attn_ratio = attn_ratio
+ self.key_dim = key_dim
+ self.nh_kd = key_dim * num_heads
+ self.d = int(attn_ratio * key_dim)
+ self.dh = int(attn_ratio * key_dim) * num_heads
+ h = self.dh + self.nh_kd * 2
+ self.qkv = nn.Linear(embed_dims, h)
+ self.proj = nn.Linear(self.dh, embed_dims)
+
+ points = list(itertools.product(range(resolution), range(resolution)))
+ N = len(points)
+ attention_offsets = {}
+ idxs = []
+ for p1 in points:
+ for p2 in points:
+ offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
+ if offset not in attention_offsets:
+ attention_offsets[offset] = len(attention_offsets)
+ idxs.append(attention_offsets[offset])
+ self.attention_biases = nn.Parameter(
+ torch.zeros(num_heads, len(attention_offsets)))
+ self.register_buffer('attention_bias_idxs',
+ torch.LongTensor(idxs).view(N, N))
+
+ @torch.no_grad()
+ def train(self, mode=True):
+ """change the mode of model."""
+ super().train(mode)
+ if mode and hasattr(self, 'ab'):
+ del self.ab
+ else:
+ self.ab = self.attention_biases[:, self.attention_bias_idxs]
+
+ def forward(self, x):
+ """forward function.
+
+ Args:
+ x (tensor): input features with shape of (B, N, C)
+ """
+ B, N, _ = x.shape
+ qkv = self.qkv(x)
+ qkv = qkv.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
+ q, k, v = qkv.split([self.key_dim, self.key_dim, self.d], dim=-1)
+
+ attn = ((q @ k.transpose(-2, -1)) * self.scale +
+ (self.attention_biases[:, self.attention_bias_idxs]
+ if self.training else self.ab))
+ attn = attn.softmax(dim=-1)
+ x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
+ x = self.proj(x)
+ return x
+
+
+class Flat(nn.Module):
+ """Flat the input from (B, C, H, W) to (B, H*W, C)."""
+
+ def __init__(self, ):
+ super().__init__()
+
+ def forward(self, x: torch.Tensor):
+ x = x.flatten(2).transpose(1, 2)
+ return x
+
+
+class LinearMlp(BaseModule):
+ """Mlp implemented with linear.
+
+ The shape of input and output tensor are (B, N, C).
+
+ Args:
+ in_features (int): Dimension of input features.
+ hidden_features (int): Dimension of hidden features.
+ out_features (int): Dimension of output features.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='BN')``.
+ act_cfg (dict): The config dict for activation between pointwise
+ convolution. Defaults to ``dict(type='GELU')``.
+ drop (float): Dropout rate. Defaults to 0.0.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_cfg=dict(type='GELU'),
+ drop=0.,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = build_activation_layer(act_cfg)
+ self.drop1 = nn.Dropout(drop)
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop2 = nn.Dropout(drop)
+
+ def forward(self, x):
+ """
+ Args:
+ x (torch.Tensor): input tensor with shape (B, N, C).
+
+ Returns:
+ torch.Tensor: output tensor with shape (B, N, C).
+ """
+ x = self.drop1(self.act(self.fc1(x)))
+ x = self.drop2(self.fc2(x))
+ return x
+
+
+class ConvMlp(BaseModule):
+ """Mlp implemented with 1*1 convolutions.
+
+ Args:
+ in_features (int): Dimension of input features.
+ hidden_features (int): Dimension of hidden features.
+ out_features (int): Dimension of output features.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults to ``dict(type='BN')``.
+ act_cfg (dict): The config dict for activation between pointwise
+ convolution. Defaults to ``dict(type='GELU')``.
+ drop (float): Dropout rate. Defaults to 0.0.
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='GELU'),
+ drop=0.,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
+ self.act = build_activation_layer(act_cfg)
+ self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
+ self.norm1 = build_norm_layer(norm_cfg, hidden_features)[1]
+ self.norm2 = build_norm_layer(norm_cfg, out_features)[1]
+
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ """
+ Args:
+ x (torch.Tensor): input tensor with shape (B, C, H, W).
+
+ Returns:
+ torch.Tensor: output tensor with shape (B, C, H, W).
+ """
+
+ x = self.act(self.norm1(self.fc1(x)))
+ x = self.drop(x)
+ x = self.norm2(self.fc2(x))
+ x = self.drop(x)
+ return x
+
+
+class LayerScale(nn.Module):
+ """LayerScale layer.
+
+ Args:
+ dim (int): Dimension of input features.
+ inplace (bool): inplace: can optionally do the
+ operation in-place. Default: ``False``
+ data_format (str): The input data format, can be 'channels_last'
+ and 'channels_first', representing (B, C, H, W) and
+ (B, N, C) format data respectively.
+ """
+
+ def __init__(self,
+ dim: int,
+ inplace: bool = False,
+ data_format: str = 'channels_last'):
+ super().__init__()
+ assert data_format in ('channels_last', 'channels_first'), \
+ "'data_format' could only be channels_last or channels_first."
+ self.inplace = inplace
+ self.data_format = data_format
+ self.weight = nn.Parameter(torch.ones(dim) * 1e-5)
+
+ def forward(self, x):
+ if self.data_format == 'channels_first':
+ if self.inplace:
+ return x.mul_(self.weight.view(-1, 1, 1))
+ else:
+ return x * self.weight.view(-1, 1, 1)
+ return x.mul_(self.weight) if self.inplace else x * self.weight
+
+
+class Meta3D(BaseModule):
+ """Meta Former block using 3 dimensions inputs, ``torch.Tensor`` with shape
+ (B, N, C)."""
+
+ def __init__(self,
+ dim,
+ mlp_ratio=4.,
+ norm_cfg=dict(type='LN'),
+ act_cfg=dict(type='GELU'),
+ drop=0.,
+ drop_path=0.,
+ use_layer_scale=True,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.norm1 = build_norm_layer(norm_cfg, dim)[1]
+ self.token_mixer = AttentionWithBias(dim)
+ self.norm2 = build_norm_layer(norm_cfg, dim)[1]
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = LinearMlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_cfg=act_cfg,
+ drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. \
+ else nn.Identity()
+ if use_layer_scale:
+ self.ls1 = LayerScale(dim)
+ self.ls2 = LayerScale(dim)
+ else:
+ self.ls1, self.ls2 = nn.Identity(), nn.Identity()
+
+ def forward(self, x):
+ x = x + self.drop_path(self.ls1(self.token_mixer(self.norm1(x))))
+ x = x + self.drop_path(self.ls2(self.mlp(self.norm2(x))))
+ return x
+
+
+class Meta4D(BaseModule):
+ """Meta Former block using 4 dimensions inputs, ``torch.Tensor`` with shape
+ (B, C, H, W)."""
+
+ def __init__(self,
+ dim,
+ pool_size=3,
+ mlp_ratio=4.,
+ act_cfg=dict(type='GELU'),
+ drop=0.,
+ drop_path=0.,
+ use_layer_scale=True,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+
+ self.token_mixer = Pooling(pool_size=pool_size)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ConvMlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_cfg=act_cfg,
+ drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. \
+ else nn.Identity()
+ if use_layer_scale:
+ self.ls1 = LayerScale(dim, data_format='channels_first')
+ self.ls2 = LayerScale(dim, data_format='channels_first')
+ else:
+ self.ls1, self.ls2 = nn.Identity(), nn.Identity()
+
+ def forward(self, x):
+ x = x + self.drop_path(self.ls1(self.token_mixer(x)))
+ x = x + self.drop_path(self.ls2(self.mlp(x)))
+ return x
+
+
+def basic_blocks(in_channels,
+ out_channels,
+ index,
+ layers,
+ pool_size=3,
+ mlp_ratio=4.,
+ act_cfg=dict(type='GELU'),
+ drop_rate=.0,
+ drop_path_rate=0.,
+ use_layer_scale=True,
+ vit_num=1,
+ has_downsamper=False):
+ """generate EfficientFormer blocks for a stage."""
+ blocks = []
+ if has_downsamper:
+ blocks.append(
+ ConvModule(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=True,
+ norm_cfg=dict(type='BN'),
+ act_cfg=None))
+ if index == 3 and vit_num == layers[index]:
+ blocks.append(Flat())
+ for block_idx in range(layers[index]):
+ block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (
+ sum(layers) - 1)
+ if index == 3 and layers[index] - block_idx <= vit_num:
+ blocks.append(
+ Meta3D(
+ out_channels,
+ mlp_ratio=mlp_ratio,
+ act_cfg=act_cfg,
+ drop=drop_rate,
+ drop_path=block_dpr,
+ use_layer_scale=use_layer_scale,
+ ))
+ else:
+ blocks.append(
+ Meta4D(
+ out_channels,
+ pool_size=pool_size,
+ act_cfg=act_cfg,
+ drop=drop_rate,
+ drop_path=block_dpr,
+ use_layer_scale=use_layer_scale))
+ if index == 3 and layers[index] - block_idx - 1 == vit_num:
+ blocks.append(Flat())
+ blocks = nn.Sequential(*blocks)
+ return blocks
+
+
+@BACKBONES.register_module()
+class EfficientFormer(BaseBackbone):
+ """EfficientFormer.
+
+ A PyTorch implementation of EfficientFormer introduced by:
+ `EfficientFormer: Vision Transformers at MobileNet Speed `_
+
+ Modified from the `official repo
+ `.
+
+ Args:
+ arch (str | dict): The model's architecture. If string, it should be
+ one of architecture in ``EfficientFormer.arch_settings``. And if dict,
+ it should include the following 4 keys:
+
+ - layers (list[int]): Number of blocks at each stage.
+ - embed_dims (list[int]): The number of channels at each stage.
+ - downsamples (list[int]): Has downsample or not in the four stages.
+ - vit_num (int): The num of vit blocks in the last stage.
+
+ Defaults to 'l1'.
+
+ in_channels (int): The num of input channels. Defaults to 3.
+ pool_size (int): The pooling size of ``Meta4D`` blocks. Defaults to 3.
+ mlp_ratios (int): The dimension ratio of multi-head attention mechanism
+ in ``Meta4D`` blocks. Defaults to 3.
+ reshape_last_feat (bool): Whether to reshape the feature map from
+ (B, N, C) to (B, C, H, W) in the last stage, when the ``vit-num``
+ in ``arch`` is not 0. Defaults to False. Usually set to True
+ in downstream tasks.
+ out_indices (Sequence[int]): Output from which stages.
+ Defaults to -1.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Defaults to -1.
+ act_cfg (dict): The config dict for activation between pointwise
+ convolution. Defaults to ``dict(type='GELU')``.
+ drop_rate (float): Dropout rate. Defaults to 0.
+ drop_path_rate (float): Stochastic depth rate. Defaults to 0.
+ use_layer_scale (bool): Whether to use use_layer_scale in MetaFormer
+ block. Defaults to True.
+ init_cfg (dict, optional): Initialization config dict.
+ Defaults to None.
+
+ Example:
+ >>> from mmcls.models import EfficientFormer
+ >>> import torch
+ >>> inputs = torch.rand((1, 3, 224, 224))
+ >>> # build EfficientFormer backbone for classification task
+ >>> model = EfficientFormer(arch="l1")
+ >>> model.eval()
+ >>> level_outputs = model(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 448, 49)
+ >>> # build EfficientFormer backbone for downstream task
+ >>> model = EfficientFormer(
+ >>> arch="l3",
+ >>> out_indices=(0, 1, 2, 3),
+ >>> reshape_last_feat=True)
+ >>> model.eval()
+ >>> level_outputs = model(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 64, 56, 56)
+ (1, 128, 28, 28)
+ (1, 320, 14, 14)
+ (1, 512, 7, 7)
+ """ # noqa: E501
+
+ # --layers: [x,x,x,x], numbers of layers for the four stages
+ # --embed_dims: [x,x,x,x], embedding dims for the four stages
+ # --downsamples: [x,x,x,x], has downsample or not in the four stages
+ # --vit_num:(int), the num of vit blocks in the last stage
+ arch_settings = {
+ 'l1': {
+ 'layers': [3, 2, 6, 4],
+ 'embed_dims': [48, 96, 224, 448],
+ 'downsamples': [False, True, True, True],
+ 'vit_num': 1,
+ },
+ 'l3': {
+ 'layers': [4, 4, 12, 6],
+ 'embed_dims': [64, 128, 320, 512],
+ 'downsamples': [False, True, True, True],
+ 'vit_num': 4,
+ },
+ 'l7': {
+ 'layers': [6, 6, 18, 8],
+ 'embed_dims': [96, 192, 384, 768],
+ 'downsamples': [False, True, True, True],
+ 'vit_num': 8,
+ },
+ }
+
+ def __init__(self,
+ arch='l1',
+ in_channels=3,
+ pool_size=3,
+ mlp_ratios=4,
+ reshape_last_feat=False,
+ out_indices=-1,
+ frozen_stages=-1,
+ act_cfg=dict(type='GELU'),
+ drop_rate=0.,
+ drop_path_rate=0.,
+ use_layer_scale=True,
+ init_cfg=None):
+
+ super().__init__(init_cfg=init_cfg)
+ self.num_extra_tokens = 0 # no cls_token, no dist_token
+
+ if isinstance(arch, str):
+ assert arch in self.arch_settings, \
+ f'Unavailable arch, please choose from ' \
+ f'({set(self.arch_settings)}) or pass a dict.'
+ arch = self.arch_settings[arch]
+ elif isinstance(arch, dict):
+ default_keys = set(self.arch_settings['l1'].keys())
+ assert set(arch.keys()) == default_keys, \
+ f'The arch dict must have {default_keys}, ' \
+ f'but got {list(arch.keys())}.'
+
+ self.layers = arch['layers']
+ self.embed_dims = arch['embed_dims']
+ self.downsamples = arch['downsamples']
+ assert isinstance(self.layers, list) and isinstance(
+ self.embed_dims, list) and isinstance(self.downsamples, list)
+ assert len(self.layers) == len(self.embed_dims) == len(
+ self.downsamples)
+
+ self.vit_num = arch['vit_num']
+ self.reshape_last_feat = reshape_last_feat
+
+ assert self.vit_num >= 0, "'vit_num' must be an integer " \
+ 'greater than or equal to 0.'
+ assert self.vit_num <= self.layers[-1], (
+ "'vit_num' must be an integer smaller than layer number")
+
+ self._make_stem(in_channels, self.embed_dims[0])
+
+ # set the main block in network
+ network = []
+ for i in range(len(self.layers)):
+ if i != 0:
+ in_channels = self.embed_dims[i - 1]
+ else:
+ in_channels = self.embed_dims[i]
+ out_channels = self.embed_dims[i]
+ stage = basic_blocks(
+ in_channels,
+ out_channels,
+ i,
+ self.layers,
+ pool_size=pool_size,
+ mlp_ratio=mlp_ratios,
+ act_cfg=act_cfg,
+ drop_rate=drop_rate,
+ drop_path_rate=drop_path_rate,
+ vit_num=self.vit_num,
+ use_layer_scale=use_layer_scale,
+ has_downsamper=self.downsamples[i])
+ network.append(stage)
+
+ self.network = ModuleList(network)
+
+ if isinstance(out_indices, int):
+ out_indices = [out_indices]
+ assert isinstance(out_indices, Sequence), \
+ f'"out_indices" must by a sequence or int, ' \
+ f'get {type(out_indices)} instead.'
+ for i, index in enumerate(out_indices):
+ if index < 0:
+ out_indices[i] = 4 + index
+ assert out_indices[i] >= 0, f'Invalid out_indices {index}'
+
+ self.out_indices = out_indices
+ for i_layer in self.out_indices:
+ if not self.reshape_last_feat and \
+ i_layer == 3 and self.vit_num > 0:
+ layer = build_norm_layer(
+ dict(type='LN'), self.embed_dims[i_layer])[1]
+ else:
+ # use GN with 1 group as channel-first LN2D
+ layer = build_norm_layer(
+ dict(type='GN', num_groups=1), self.embed_dims[i_layer])[1]
+
+ layer_name = f'norm{i_layer}'
+ self.add_module(layer_name, layer)
+
+ self.frozen_stages = frozen_stages
+ self._freeze_stages()
+
+ def _make_stem(self, in_channels: int, stem_channels: int):
+ """make 2-ConvBNReLu stem layer."""
+ self.patch_embed = Sequential(
+ ConvModule(
+ in_channels,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=True,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ inplace=True),
+ ConvModule(
+ stem_channels // 2,
+ stem_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=True,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ inplace=True))
+
+ def forward_tokens(self, x):
+ outs = []
+ for idx, block in enumerate(self.network):
+ if idx == len(self.network) - 1:
+ N, _, H, W = x.shape
+ if self.downsamples[idx]:
+ H, W = H // 2, W // 2
+ x = block(x)
+ if idx in self.out_indices:
+ norm_layer = getattr(self, f'norm{idx}')
+
+ if idx == len(self.network) - 1 and x.dim() == 3:
+ # when ``vit-num`` > 0 and in the last stage,
+ # if `self.reshape_last_feat`` is True, reshape the
+ # features to `BCHW` format before the final normalization.
+ # if `self.reshape_last_feat`` is False, do
+ # normalization directly and permute the features to `BCN`.
+ if self.reshape_last_feat:
+ x = x.permute((0, 2, 1)).reshape(N, -1, H, W)
+ x_out = norm_layer(x)
+ else:
+ x_out = norm_layer(x).permute((0, 2, 1))
+ else:
+ x_out = norm_layer(x)
+
+ outs.append(x_out.contiguous())
+ return tuple(outs)
+
+ def forward(self, x):
+ # input embedding
+ x = self.patch_embed(x)
+ # through stages
+ x = self.forward_tokens(x)
+ return x
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ for i in range(self.frozen_stages):
+ # Include both block and downsample layer.
+ module = self.network[i]
+ module.eval()
+ for param in module.parameters():
+ param.requires_grad = False
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ norm_layer.eval()
+ for param in norm_layer.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(EfficientFormer, self).train(mode)
+ self._freeze_stages()
diff --git a/mmcls/models/heads/__init__.py b/mmcls/models/heads/__init__.py
index ad520ce4ca5..d7301613094 100644
--- a/mmcls/models/heads/__init__.py
+++ b/mmcls/models/heads/__init__.py
@@ -2,6 +2,7 @@
from .cls_head import ClsHead
from .conformer_head import ConformerHead
from .deit_head import DeiTClsHead
+from .efficientformer_head import EfficientFormerClsHead
from .linear_head import LinearClsHead
from .multi_label_csra_head import CSRAClsHead
from .multi_label_head import MultiLabelClsHead
@@ -12,5 +13,5 @@
__all__ = [
'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'DeiTClsHead',
- 'ConformerHead', 'CSRAClsHead'
+ 'ConformerHead', 'EfficientFormerClsHead', 'CSRAClsHead'
]
diff --git a/mmcls/models/heads/efficientformer_head.py b/mmcls/models/heads/efficientformer_head.py
new file mode 100644
index 00000000000..3127f12e371
--- /dev/null
+++ b/mmcls/models/heads/efficientformer_head.py
@@ -0,0 +1,96 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import HEADS
+from .cls_head import ClsHead
+
+
+@HEADS.register_module()
+class EfficientFormerClsHead(ClsHead):
+ """EfficientFormer classifier head.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ distillation (bool): Whether use a additional distilled head.
+ Defaults to True.
+ init_cfg (dict): The extra initialization configs. Defaults to
+ ``dict(type='Normal', layer='Linear', std=0.01)``.
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ distillation=True,
+ init_cfg=dict(type='Normal', layer='Linear', std=0.01),
+ *args,
+ **kwargs):
+ super(EfficientFormerClsHead, self).__init__(
+ init_cfg=init_cfg, *args, **kwargs)
+ self.in_channels = in_channels
+ self.num_classes = num_classes
+ self.dist = distillation
+
+ if self.num_classes <= 0:
+ raise ValueError(
+ f'num_classes={num_classes} must be a positive integer')
+
+ self.head = nn.Linear(self.in_channels, self.num_classes)
+ if self.dist:
+ self.dist_head = nn.Linear(self.in_channels, self.num_classes)
+
+ def pre_logits(self, x):
+ if isinstance(x, tuple):
+ x = x[-1]
+ return x
+
+ def simple_test(self, x, softmax=True, post_process=True):
+ """Inference without augmentation.
+
+ Args:
+ x (tuple[tuple[tensor, tensor]]): The input features.
+ Multi-stage inputs are acceptable but only the last stage will
+ be used to classify. Every item should be a tuple which
+ includes patch token and cls token. The cls token will be used
+ to classify and the shape of it should be
+ ``(num_samples, in_channels)``.
+ softmax (bool): Whether to softmax the classification score.
+ post_process (bool): Whether to do post processing the
+ inference results. It will convert the output to a list.
+
+ Returns:
+ Tensor | list: The inference results.
+
+ - If no post processing, the output is a tensor with shape
+ ``(num_samples, num_classes)``.
+ - If post processing, the output is a multi-dimentional list of
+ float and the dimensions are ``(num_samples, num_classes)``.
+ """
+ x = self.pre_logits(x)
+ cls_score = self.head(x)
+ if self.dist:
+ cls_score = (cls_score + self.dist_head(x)) / 2
+
+ if softmax:
+ pred = (
+ F.softmax(cls_score, dim=1) if cls_score is not None else None)
+ else:
+ pred = cls_score
+
+ if post_process:
+ return self.post_process(pred)
+ else:
+ return pred
+
+ def forward_train(self, x, gt_label, **kwargs):
+ if self.dist:
+ raise NotImplementedError(
+ "MMClassification doesn't support to train"
+ ' the distilled version EfficientFormer.')
+ else:
+ x = self.pre_logits(x)
+ cls_score = self.head(x)
+ losses = self.loss(cls_score, gt_label, **kwargs)
+ return losses
diff --git a/model-index.yml b/model-index.yml
index a57802a85f0..a48ab85a4cc 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -30,3 +30,4 @@ Import:
- configs/poolformer/metafile.yml
- configs/csra/metafile.yml
- configs/mvit/metafile.yml
+ - configs/efficientformer/metafile.yml
diff --git a/tests/test_models/test_backbones/test_efficientformer.py b/tests/test_models/test_backbones/test_efficientformer.py
new file mode 100644
index 00000000000..01d9daea4b8
--- /dev/null
+++ b/tests/test_models/test_backbones/test_efficientformer.py
@@ -0,0 +1,241 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+from unittest import TestCase
+
+import torch
+from mmcv.cnn import ConvModule
+from torch import nn
+
+from mmcls.models.backbones import EfficientFormer
+from mmcls.models.backbones.efficientformer import (AttentionWithBias, Flat,
+ LayerScale, Meta3D, Meta4D)
+from mmcls.models.backbones.poolformer import Pooling
+
+
+class TestLayerScale(TestCase):
+
+ def test_init(self):
+ with self.assertRaisesRegex(AssertionError, "'data_format' could"):
+ cfg = dict(
+ dim=10,
+ inplace=False,
+ data_format='BNC',
+ )
+ LayerScale(**cfg)
+
+ cfg = dict(dim=10)
+ ls = LayerScale(**cfg)
+ assert torch.equal(ls.weight,
+ torch.ones(10, requires_grad=True) * 1e-5)
+
+ def forward(self):
+ # Test channels_last
+ cfg = dict(dim=256, inplace=False, data_format='channels_last')
+ ls_channels_last = LayerScale(**cfg)
+ x = torch.randn((4, 49, 256))
+ out = ls_channels_last(x)
+ self.assertEqual(tuple(out.size()), (4, 49, 256))
+ assert torch.equal(x * 1e-5, out)
+
+ # Test channels_first
+ cfg = dict(dim=256, inplace=False, data_format='channels_first')
+ ls_channels_first = LayerScale(**cfg)
+ x = torch.randn((4, 256, 7, 7))
+ out = ls_channels_first(x)
+ self.assertEqual(tuple(out.size()), (4, 256, 7, 7))
+ assert torch.equal(x * 1e-5, out)
+
+ # Test inplace True
+ cfg = dict(dim=256, inplace=True, data_format='channels_first')
+ ls_channels_first = LayerScale(**cfg)
+ x = torch.randn((4, 256, 7, 7))
+ out = ls_channels_first(x)
+ self.assertEqual(tuple(out.size()), (4, 256, 7, 7))
+ self.assertIs(x, out)
+
+
+class TestEfficientFormer(TestCase):
+
+ def setUp(self):
+ self.cfg = dict(arch='l1', drop_path_rate=0.1)
+ self.arch = EfficientFormer.arch_settings['l1']
+ self.custom_arch = {
+ 'layers': [1, 1, 1, 4],
+ 'embed_dims': [48, 96, 224, 448],
+ 'downsamples': [False, True, True, True],
+ 'vit_num': 2,
+ }
+ self.custom_cfg = dict(arch=self.custom_arch)
+
+ def test_arch(self):
+ # Test invalid default arch
+ with self.assertRaisesRegex(AssertionError, 'Unavailable arch'):
+ cfg = deepcopy(self.cfg)
+ cfg['arch'] = 'unknown'
+ EfficientFormer(**cfg)
+
+ # Test invalid custom arch
+ with self.assertRaisesRegex(AssertionError, 'must have'):
+ cfg = deepcopy(self.custom_cfg)
+ cfg['arch'].pop('layers')
+ EfficientFormer(**cfg)
+
+ # Test vit_num < 0
+ with self.assertRaisesRegex(AssertionError, "'vit_num' must"):
+ cfg = deepcopy(self.custom_cfg)
+ cfg['arch']['vit_num'] = -1
+ EfficientFormer(**cfg)
+
+ # Test vit_num > last stage layers
+ with self.assertRaisesRegex(AssertionError, "'vit_num' must"):
+ cfg = deepcopy(self.custom_cfg)
+ cfg['arch']['vit_num'] = 10
+ EfficientFormer(**cfg)
+
+ # Test out_ind
+ with self.assertRaisesRegex(AssertionError, '"out_indices" must'):
+ cfg = deepcopy(self.custom_cfg)
+ cfg['out_indices'] = dict
+ EfficientFormer(**cfg)
+
+ # Test custom arch
+ cfg = deepcopy(self.custom_cfg)
+ model = EfficientFormer(**cfg)
+ self.assertEqual(len(model.patch_embed), 2)
+ layers = self.custom_arch['layers']
+ downsamples = self.custom_arch['downsamples']
+ vit_num = self.custom_arch['vit_num']
+
+ for i, stage in enumerate(model.network):
+ if downsamples[i]:
+ self.assertIsInstance(stage[0], ConvModule)
+ self.assertEqual(stage[0].conv.stride, (2, 2))
+ self.assertTrue(hasattr(stage[0].conv, 'bias'))
+ self.assertTrue(isinstance(stage[0].bn, nn.BatchNorm2d))
+
+ if i < len(model.network) - 1:
+ self.assertIsInstance(stage[-1], Meta4D)
+ self.assertIsInstance(stage[-1].token_mixer, Pooling)
+ self.assertEqual(len(stage) - downsamples[i], layers[i])
+ elif vit_num > 0:
+ self.assertIsInstance(stage[-1], Meta3D)
+ self.assertIsInstance(stage[-1].token_mixer, AttentionWithBias)
+ self.assertEqual(len(stage) - downsamples[i] - 1, layers[i])
+ flat_layer_idx = len(stage) - vit_num - downsamples[i]
+ self.assertIsInstance(stage[flat_layer_idx], Flat)
+ count = 0
+ for layer in stage:
+ if isinstance(layer, Meta3D):
+ count += 1
+ self.assertEqual(count, vit_num)
+
+ def test_init_weights(self):
+ # test weight init cfg
+ cfg = deepcopy(self.cfg)
+ cfg['init_cfg'] = [
+ dict(
+ type='Kaiming',
+ layer='Conv2d',
+ mode='fan_in',
+ nonlinearity='linear'),
+ dict(type='Constant', layer=['LayerScale'], val=1e-4)
+ ]
+ model = EfficientFormer(**cfg)
+ ori_weight = model.patch_embed[0].conv.weight.clone().detach()
+ ori_ls_weight = model.network[0][-1].ls1.weight.clone().detach()
+
+ model.init_weights()
+ initialized_weight = model.patch_embed[0].conv.weight
+ initialized_ls_weight = model.network[0][-1].ls1.weight
+ self.assertFalse(torch.allclose(ori_weight, initialized_weight))
+ self.assertFalse(torch.allclose(ori_ls_weight, initialized_ls_weight))
+
+ def test_forward(self):
+ imgs = torch.randn(1, 3, 224, 224)
+
+ # test last stage output
+ cfg = deepcopy(self.cfg)
+ model = EfficientFormer(**cfg)
+ outs = model(imgs)
+ self.assertIsInstance(outs, tuple)
+ self.assertEqual(len(outs), 1)
+ feat = outs[-1]
+ self.assertEqual(feat.shape, (1, 448, 49))
+ assert hasattr(model, 'norm3')
+ assert isinstance(getattr(model, 'norm3'), nn.LayerNorm)
+
+ # test multiple output indices
+ cfg = deepcopy(self.cfg)
+ cfg['out_indices'] = (0, 1, 2, 3)
+ cfg['reshape_last_feat'] = True
+ model = EfficientFormer(**cfg)
+ outs = model(imgs)
+ self.assertIsInstance(outs, tuple)
+ self.assertEqual(len(outs), 4)
+ # Test out features shape
+ for dim, stride, out in zip(self.arch['embed_dims'], [1, 2, 4, 8],
+ outs):
+ self.assertEqual(out.shape, (1, dim, 56 // stride, 56 // stride))
+
+ # Test norm layer
+ for i in range(4):
+ assert hasattr(model, f'norm{i}')
+ stage_norm = getattr(model, f'norm{i}')
+ assert isinstance(stage_norm, nn.GroupNorm)
+ assert stage_norm.num_groups == 1
+
+ # Test vit_num == 0
+ cfg = deepcopy(self.custom_cfg)
+ cfg['arch']['vit_num'] = 0
+ cfg['out_indices'] = (0, 1, 2, 3)
+ model = EfficientFormer(**cfg)
+ for i in range(4):
+ assert hasattr(model, f'norm{i}')
+ stage_norm = getattr(model, f'norm{i}')
+ assert isinstance(stage_norm, nn.GroupNorm)
+ assert stage_norm.num_groups == 1
+
+ def test_structure(self):
+ # test drop_path_rate decay
+ cfg = deepcopy(self.cfg)
+ cfg['drop_path_rate'] = 0.2
+ model = EfficientFormer(**cfg)
+ layers = self.arch['layers']
+ for i, block in enumerate(model.network):
+ expect_prob = 0.2 / (sum(layers) - 1) * i
+ if hasattr(block, 'drop_path'):
+ if expect_prob == 0:
+ self.assertIsInstance(block.drop_path, torch.nn.Identity)
+ else:
+ self.assertAlmostEqual(block.drop_path.drop_prob,
+ expect_prob)
+
+ # test with first stage frozen.
+ cfg = deepcopy(self.cfg)
+ frozen_stages = 1
+ cfg['frozen_stages'] = frozen_stages
+ cfg['out_indices'] = (0, 1, 2, 3)
+ model = EfficientFormer(**cfg)
+ model.init_weights()
+ model.train()
+
+ # the patch_embed and first stage should not require grad.
+ self.assertFalse(model.patch_embed.training)
+ for param in model.patch_embed.parameters():
+ self.assertFalse(param.requires_grad)
+ for i in range(frozen_stages):
+ module = model.network[i]
+ for param in module.parameters():
+ self.assertFalse(param.requires_grad)
+ for param in model.norm0.parameters():
+ self.assertFalse(param.requires_grad)
+
+ # the second stage should require grad.
+ for i in range(frozen_stages + 1, 4):
+ module = model.network[i]
+ for param in module.parameters():
+ self.assertTrue(param.requires_grad)
+ if hasattr(model, f'norm{i}'):
+ norm = getattr(model, f'norm{i}')
+ for param in norm.parameters():
+ self.assertTrue(param.requires_grad)
diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py
index 4ab18c332c4..e0ecdb6b5c2 100644
--- a/tests/test_models/test_heads.py
+++ b/tests/test_models/test_heads.py
@@ -5,7 +5,8 @@
import torch
from mmcls.models.heads import (ClsHead, ConformerHead, CSRAClsHead,
- DeiTClsHead, LinearClsHead, MultiLabelClsHead,
+ DeiTClsHead, EfficientFormerClsHead,
+ LinearClsHead, MultiLabelClsHead,
MultiLabelLinearClsHead, StackedLinearClsHead,
VisionTransformerClsHead)
@@ -319,6 +320,62 @@ def test_deit_head():
DeiTClsHead(-1, 100)
+def test_efficientformer_head():
+ fake_features = (torch.rand(4, 64), )
+ fake_gt_label = torch.randint(0, 10, (4, ))
+
+ # Test without distillation head
+ head = EfficientFormerClsHead(
+ num_classes=10, in_channels=64, distillation=False)
+
+ # test EfficientFormer head forward
+ losses = head.forward_train(fake_features, fake_gt_label)
+ assert losses['loss'].item() > 0
+
+ # test simple_test with post_process
+ pred = head.simple_test(fake_features)
+ assert isinstance(pred, list) and len(pred) == 4
+ with patch('torch.onnx.is_in_onnx_export', return_value=True):
+ pred = head.simple_test(fake_features)
+ assert pred.shape == (4, 10)
+
+ # test simple_test without post_process
+ pred = head.simple_test(fake_features, post_process=False)
+ assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
+ logits = head.simple_test(fake_features, softmax=False, post_process=False)
+ torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
+
+ # test pre_logits
+ features = head.pre_logits(fake_features)
+ assert features is fake_features[0]
+
+ # Test without distillation head
+ head = EfficientFormerClsHead(num_classes=10, in_channels=64)
+ assert hasattr(head, 'head')
+ assert hasattr(head, 'dist_head')
+
+ # Test loss
+ with pytest.raises(NotImplementedError):
+ losses = head.forward_train(fake_features, fake_gt_label)
+
+ # test simple_test with post_process
+ pred = head.simple_test(fake_features)
+ assert isinstance(pred, list) and len(pred) == 4
+ with patch('torch.onnx.is_in_onnx_export', return_value=True):
+ pred = head.simple_test(fake_features)
+ assert pred.shape == (4, 10)
+
+ # test simple_test without post_process
+ pred = head.simple_test(fake_features, post_process=False)
+ assert isinstance(pred, torch.Tensor) and pred.shape == (4, 10)
+ logits = head.simple_test(fake_features, softmax=False, post_process=False)
+ torch.testing.assert_allclose(pred, torch.softmax(logits, dim=1))
+
+ # test pre_logits
+ features = head.pre_logits(fake_features)
+ assert features is fake_features[0]
+
+
@pytest.mark.parametrize(
'feat', [torch.rand(4, 20, 20, 30), (torch.rand(4, 20, 20, 30), )])
def test_csra_head(feat):