Skip to content

Commit

Permalink
[Feature] Add RepVGG backbone and checkpoints. (#414)
Browse files Browse the repository at this point in the history
* Add RepVGG code.

* Add se_module as plugin.

* Add the repvggA0 primitive config

* Change repvggA0.py to fit mmcls

* Add RepVGG configs

* Add repvgg_to_mmcls

* Add tools/deployment/convert_repvggblock_param_to_deploy.py

* Change configs/repvgg/README.md

* Streamlining the number of configuration files.

* Fix lints

* Delete plugins

* Delete code about plugin.

* Modify the code for using se module.

* Modify config to fit repvgg with se.

* Change se_cfg to allow loading of pre-training parameters.

* Reduce the complexity of the configuration file.

* Finsh unitest for repvgg.

* Fix bug about se in repvgg_to_mmcls.

* Rename convert_repvggblock_param_to_deploy.py to reparameterize_repvgg.py, and delete setting about device.

* test commit

* test commit

* test commit command

* Modify repvgg.py to make the code more readable.

* Add value=0 in F.pad()

* Add se_cfg to arch_settings.

* Fix bug.

* modeify some attr name and Update unit tests

* rename stage_0 to stem and branch_identity to branch_norm

* update unit tests

* add m.eval in unit tests

* [Enhance] Enhence SE layer to support custom squeeze channels. (#417)

* add enhenced SE

* Update

* rm basechannel

* fix docstring

* Update se_layer.py

fix docstring

* [Docs] Add algorithm readme and update meta yml (#418)

* Add README.md for models without checkpoints.

* Update model-index.yml

* Update metafile.yml of seresnet

* [Enhance] Add `hparams` argument in `AutoAugment` and `RandAugment` and some other improvement. (#398)

* Add hparams argument in `AutoAugment` and `RandAugment`.

And `pad_val` supports sequence instead of tuple only.

* Add unit tests for `AutoAugment` and `hparams` in `RandAugment`.

* Use smaller test image to speed up uni tests.

* Use hparams to simplify RandAugment config in swin-transformer.

* Rename augment config name from `pipeline` to `pipelines`.

* Add some commnet ad docstring.

* [Feature] Support classwise weight in losses (#388)

* Add classwise weight in losses:CE,BCE,softBCE

* Update unit test

* rm some extra code

* rm some extra code

* fix broadcast

* fix broadcast

* update unit tests

* use new_tensor

* fix lint

* [Enhance] Better result visualization (#419)

* Imporve result visualization to support wait time and change the backend
to matplotlib.

* Add unit test for visualization

* Add adaptive dpi function

* Rename `imshow_cls_result` to `imshow_infos`.

* Support str in `imshow_infos`

* Improve docstring.

* Bump version to v0.15.0 (#426)

* [CI] Add PyTorch 1.9 and Python 3.9 build workflow, and remove some CI. (#422)

* Add PyTorch 1.9 build workflow, and remove some CI.

* Add Python 3.9 CI

* Show Python 3.9 support.

* [Enhance] Rename the option `--options` in some tools to `--cfg-options`. (#425)

* [Docs] Fix sphinx version (#429)

* [Docs] Add `CITATION.cff` (#428)

* Add CITATION.cff

* Fix typo in setup.py

* Change author in setup.py

* modeify some attr name and Update unit tests

* rename stage_0 to stem and branch_identity to branch_norm

* update unit tests

* add m.eval in unit tests

* Update unit tests

* refactor

* refactor

* Alignment inference accuracy

* Update configs, readme and metafile

* Update readme

* return tuple and fix metafile

* fix unit test

* rm regnet and classifiers changes

* update auto_aug

* update metafile & readme

* use delattr

* rename cfgs

* Update checkpoint url

* Update readme

* Rename config files.

* Update readme and metafile

* add comment

* Update mmcls/models/backbones/repvgg.py

Co-authored-by: Ma Zerun <mzr1996@163.com>

* Update docstring

* Improve docstring.

* Update unittest_testblock

Co-authored-by: Ezra-Yu <1105212286@qq.com>
Co-authored-by: Ma Zerun <mzr1996@163.com>
  • Loading branch information
3 people authored Sep 29, 2021
1 parent 8b7d38b commit 90496b4
Show file tree
Hide file tree
Showing 37 changed files with 1,459 additions and 1 deletion.
43 changes: 43 additions & 0 deletions configs/_base_/datasets/imagenet_bs64_autoaug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
_base_ = ['./pipelines/auto_aug.py']

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', size=224),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='AutoAugment', policies={{_base_.auto_increasing_policies}}),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', size=(256, -1)),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data = dict(
samples_per_gpu=64,
workers_per_gpu=2,
train=dict(
type=dataset_type,
data_prefix='data/imagenet/train',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_prefix='data/imagenet/val',
ann_file='data/imagenet/meta/val.txt',
pipeline=test_pipeline),
test=dict(
# replace `data/val` with `data/test` for standard test
type=dataset_type,
data_prefix='data/imagenet/val',
ann_file='data/imagenet/meta/val.txt',
pipeline=test_pipeline))
evaluation = dict(interval=1, metric='accuracy')
96 changes: 96 additions & 0 deletions configs/_base_/datasets/pipelines/auto_aug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Policy for ImageNet, refers to
# https://github.com/DeepVoltaire/AutoAugment/blame/master/autoaugment.py
policy_imagenet = [
[
dict(type='Posterize', bits=4, prob=0.4),
dict(type='Rotate', angle=30., prob=0.6)
],
[
dict(type='Solarize', thr=256 / 9 * 4, prob=0.6),
dict(type='AutoContrast', prob=0.6)
],
[dict(type='Equalize', prob=0.8),
dict(type='Equalize', prob=0.6)],
[
dict(type='Posterize', bits=5, prob=0.6),
dict(type='Posterize', bits=5, prob=0.6)
],
[
dict(type='Equalize', prob=0.4),
dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)
],
[
dict(type='Equalize', prob=0.4),
dict(type='Rotate', angle=30 / 9 * 8, prob=0.8)
],
[
dict(type='Solarize', thr=256 / 9 * 6, prob=0.6),
dict(type='Equalize', prob=0.6)
],
[dict(type='Posterize', bits=6, prob=0.8),
dict(type='Equalize', prob=1.)],
[
dict(type='Rotate', angle=10., prob=0.2),
dict(type='Solarize', thr=256 / 9, prob=0.6)
],
[
dict(type='Equalize', prob=0.6),
dict(type='Posterize', bits=5, prob=0.4)
],
[
dict(type='Rotate', angle=30 / 9 * 8, prob=0.8),
dict(type='ColorTransform', magnitude=0., prob=0.4)
],
[
dict(type='Rotate', angle=30., prob=0.4),
dict(type='Equalize', prob=0.6)
],
[dict(type='Equalize', prob=0.0),
dict(type='Equalize', prob=0.8)],
[dict(type='Invert', prob=0.6),
dict(type='Equalize', prob=1.)],
[
dict(type='ColorTransform', magnitude=0.4, prob=0.6),
dict(type='Contrast', magnitude=0.8, prob=1.)
],
[
dict(type='Rotate', angle=30 / 9 * 8, prob=0.8),
dict(type='ColorTransform', magnitude=0.2, prob=1.)
],
[
dict(type='ColorTransform', magnitude=0.8, prob=0.8),
dict(type='Solarize', thr=256 / 9 * 2, prob=0.8)
],
[
dict(type='Sharpness', magnitude=0.7, prob=0.4),
dict(type='Invert', prob=0.6)
],
[
dict(
type='Shear',
magnitude=0.3 / 9 * 5,
prob=0.6,
direction='horizontal'),
dict(type='Equalize', prob=1.)
],
[
dict(type='ColorTransform', magnitude=0., prob=0.4),
dict(type='Equalize', prob=0.6)
],
[
dict(type='Equalize', prob=0.4),
dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)
],
[
dict(type='Solarize', thr=256 / 9 * 4, prob=0.6),
dict(type='AutoContrast', prob=0.6)
],
[dict(type='Invert', prob=0.6),
dict(type='Equalize', prob=1.)],
[
dict(type='ColorTransform', magnitude=0.4, prob=0.6),
dict(type='Contrast', magnitude=0.8, prob=1.)
],
[dict(type='Equalize', prob=0.8),
dict(type='Equalize', prob=0.6)],
]
15 changes: 15 additions & 0 deletions configs/_base_/models/repvgg-A0_in1k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='RepVGG',
arch='A0',
out_indices=(3, ),
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=1280,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
))
23 changes: 23 additions & 0 deletions configs/_base_/models/repvgg-B3_lbs-mixup_in1k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='RepVGG',
arch='B3',
out_indices=(3, ),
),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='LinearClsHead',
num_classes=1000,
in_channels=2560,
loss=dict(
type='LabelSmoothLoss',
loss_weight=1.0,
label_smooth_val=0.1,
mode='classy_vision',
num_classes=1000),
topk=(1, 5),
),
train_cfg=dict(
augments=dict(type='BatchMixup', alpha=0.2, num_classes=1000,
prob=1.)))
11 changes: 11 additions & 0 deletions configs/_base_/schedules/imagenet_bs256_200e_coslr_warmup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
policy='CosineAnnealing',
min_lr=0,
warmup='linear',
warmup_iters=25025,
warmup_ratio=0.25)
runner = dict(type='EpochBasedRunner', max_epochs=200)
48 changes: 48 additions & 0 deletions configs/repvgg/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Repvgg: Making vgg-style convnets great again

## Introduction

<!-- [ALGORITHM] -->

```latex
@inproceedings{ding2021repvgg,
title={Repvgg: Making vgg-style convnets great again},
author={Ding, Xiaohan and Zhang, Xiangyu and Ma, Ningning and Han, Jungong and Ding, Guiguang and Sun, Jian},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={13733--13742},
year={2021}
}
```

## Pretrain model

| Model | Epochs | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
| :---------: | :----: | :-------------------------------: | :-----------------------------: | :-------: | :-------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
| RepVGG-A0 | 120 | 9.11(train) \| 8.31 (deploy) | 1.52 (train) \| 1.36 (deploy) | 72.41 | 90.50 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-A0_4xb64-coslr-120e_in1k.py) \| [config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-A0_deploy_4xb64-coslr-120e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_3rdparty_4xb64-coslr-120e_in1k_20210909-883ab98c.pth) |
| RepVGG-A1 | 120 | 14.09 (train) \| 12.79 (deploy) | 2.64 (train) \| 2.37 (deploy) | 74.47 | 91.85 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-A1_4xb64-coslr-120e_in1k.py) \| [config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-A1_deploy_4xb64-coslr-120e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A1_3rdparty_4xb64-coslr-120e_in1k_20210909-24003a24.pth) |
| RepVGG-A2 | 120 | 28.21 (train) \| 25.5 (deploy) | 5.7 (train) \| 5.12 (deploy) | 76.48 | 93.01 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/masterconfigs/repvgg/repvgg-A2_4xb64-coslr-120e_in1k.py) \|[config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-A2_deploy_4xb64-coslr-120e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A2_3rdparty_4xb64-coslr-120e_in1k_20210909-97d7695a.pth) |
| RepVGG-B0 | 120 | 15.82 (train) \| 14.34 (deploy) | 3.42 (train) \| 3.06 (deploy) | 75.14 | 92.42 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-B0_4xb64-coslr-120e_in1k.py) \|[config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-B0_deploy_4xb64-coslr-120e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B0_3rdparty_4xb64-coslr-120e_in1k_20210909-446375f4.pth) |
| RepVGG-B1 | 120 | 57.42 (train) \| 51.83 (deploy) | 13.16 (train) \| 11.82 (deploy) | 78.37 | 94.11 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-B1_4xb64-coslr-120e_in1k.py) \|[config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-B1_deploy_4xb64-coslr-120e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1_3rdparty_4xb64-coslr-120e_in1k_20210909-750cdf67.pth) |
| RepVGG-B1g2 | 120 | 45.78 (train) \| 41.36 (deploy) | 9.82 (train) \| 8.82 (deploy) | 77.79 | 93.88 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-B1g2_4xb64-coslr-120e_in1k.py) \|[config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-B1g2_deploy_4xb64-coslr-120e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g2_3rdparty_4xb64-coslr-120e_in1k_20210909-344f6422.pth) |
| RepVGG-B1g4 | 120 | 39.97 (train) \| 36.13 (deploy) | 8.15 (train) \| 7.32 (deploy) | 77.58 | 93.84 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-B1g4_4xb64-coslr-120e_in1k.py) \|[config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-B1g4_deploy_4xb64-coslr-120e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g4_3rdparty_4xb64-coslr-120e_in1k_20210909-d4c1a642.pth) |
| RepVGG-B2 | 120 | 89.02 (train) \| 80.32 (deploy) | 20.46 (train) \| 18.39 (deploy) | 78.78 | 94.42 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-B2_4xb64-coslr-120e_in1k.py) \|[config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-B2_deploy_4xb64-coslr-120e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2_3rdparty_4xb64-coslr-120e_in1k_20210909-bd6b937c.pth) |
| RepVGG-B2g4 | 200 | 61.76 (train) \| 55.78 (deploy) | 12.63 (train) \| 11.34 (deploy) | 79.38 | 94.68 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-B2g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py) \|[config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-B2g4_deploy_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-7b7955f0.pth) |
| RepVGG-B3 | 200 | 123.09 (train) \| 110.96 (deploy) | 29.17 (train) \| 26.22 (deploy) | 80.52 | 95.26 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-B3_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py) \|[config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-B3_deploy_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-dda968bf.pth) |
| RepVGG-B3g4 | 200 | 83.83 (train) \| 75.63 (deploy) | 17.9 (train) \| 16.08 (deploy) | 80.22 | 95.10 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-B3g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py) \|[config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-B3g4_deploy_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-4e54846a.pth) |
| RepVGG-D2se | 200 | 133.33 (train) \| 120.39 (deploy) | 36.56 (train) \| 32.85 (deploy) | 81.81 | 95.94 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-D2se_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py) \|[config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-D2se_deploy_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-D2se_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-cf3139b7.pth) |

## Reparameterize RepVGG

The checkpoints provided are all in `train` form. Use the reparameterize tool to switch them to more efficient `deploy` form, which not only has fewer parameters but also less calculations.

```bash
python ./tools/convert_models/reparameterize_repvgg.py ${CFG_PATH} ${SRC_CKPT_PATH} ${TARGET_CKPT_PATH}
```

`${CFG_PATH}` is the config file, `${SRC_CKPT_PATH}` is the source chenpoint file, `${TARGET_CKPT_PATH}` is the target deploy weight file path.

To use reparameterized repvgg weight, the config file must switch to [the deploy config files](./configs/repvgg/deploy) as below:

```bash
python ./tools/test.py ${RapVGG_Deploy_CFG} ${CHECK_POINT}
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = '../repvgg-A0_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(deploy=True))
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = '../repvgg-A1_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(deploy=True))
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = '../repvgg-A2_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(deploy=True))
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = '../repvgg-B0_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(deploy=True))
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = '../repvgg-B1_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(deploy=True))
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = '../repvgg-B1g2_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(deploy=True))
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = '../repvgg-B1g4_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(deploy=True))
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = '../repvgg-B2_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(deploy=True))
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = '../repvgg-B2g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py'

model = dict(backbone=dict(deploy=True))
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = '../repvgg-B3_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py'

model = dict(backbone=dict(deploy=True))
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = '../repvgg-B3g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py'

model = dict(backbone=dict(deploy=True))
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = '../repvgg-D2se_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py'

model = dict(backbone=dict(deploy=True))
Loading

0 comments on commit 90496b4

Please sign in to comment.