Skip to content

Commit

Permalink
[Feature]: CAE Supported (#284)
Browse files Browse the repository at this point in the history
* [Feature]: Add mc

* [Feature]: Add dataset of CAE

* [Feature]: Init version of CAE

* [Feature]: Add mc

* [Fix]: Change beta to (0.9, 0.999)

* [Fix]: New feature

* [Fix]: Decouple the qkv bias

* [Feature]: Decouple qkv bias in MultiheadAttention

* [Feature]: New mask generator

* [Fix]: Fix TransformEncoderLayer bug

* [Feature]: Add MAE CAE linear prob

* [Fix]: Fix config

* [Fix]: Delete redundant mc

* [Fix]: Add init value in mim cls vit

* [Fix]: Fix cae ft config

* [Fix]: Delete repeated init_values

* [Fix]: Change bs from 64 to 128 in CAE ft

* [Fix]: Add mc in cae pt

* [Fix]: Fix momemtum update bug

* [Fix]: Add no weight_decay for gamma

* [Feature]: Add mc for cae pt

* [Fix]: Delete mc

* [Fix]: Delete redundant files

* [Fix]: Fix lint

* [Feature]: Add docstring to algo, backbone, neck and head

* [Fix]: Fix lint

* [Fix]: network

* [Feature]: Add docstrings for network blocks

* [Feature]: Add docstring to ToTensor

* [Feature]: Add docstring to transoform

* [Fix]: Add type hint to BEiTMaskGenerator

* [Fix]: Fix lint

* [Fix]: Add copyright to dalle_e

* [Fix]: Fix BlockwiseMaskGenerator

* [Feature]: Add UT for CAE

* [Fix]: Fix dalle state_dict path not existed bug

* [Fix]: Delete file_client_args related code

* [Fix]: Remove redundant code

* [Refactor]: Add fp16 to the name of cae pre-train config

* [Refactor]: Use FFN from mmcv

* [Refactor]: Change network_blocks to trasformer_blocks

* [Fix]: Fix mask generator name bug

* [Fix]: cae pre-train config bug

* [Fix]: Fix docstring grammar

* [Fix]: Fix mc related code

* [Fix]: Add object parent to transform

* [Fix]: Delete unnecessary modification

* [Fix]: Change blockwisemask generator to simmim mask generator

* [Refactor]: Change cae mae pretrain vit to cae mae vit

* [Refactor]: Change lamb to lambd

* [Fix]: Remove blank line

* [Fix]: Fix lint

* [Fix]: Fix UT

* [Fix]: Delete modification to swin

* [Fix]: Fix lint

* [Feature]: Add README and metafile

* [Feature]: Update index.rst

* [Fix]: Update model_zoo

* [Fix]: Change MAE to CAE in algorithm

* [Fix]: Change SimMIMMaskGenerator to CAEMaskGenerator

* [Fix]: Fix model zoo

* [Fix]: Change to dalle_encoder

* [Feature]: Add download link for dalle

* [Fix]: Fix lint

* [Fix]: Fix UT

* [Fix]: Update metafile

* [Fix]: Change b to base

* [Feature]: Add dalle download link in warning

* [Fix] add arxiv link in readme

Co-authored-by: Jiahao Xie <52497952+Jiahao000@users.noreply.github.com>
  • Loading branch information
YuanLiuuuuuu and Jiahao000 authored Apr 28, 2022
1 parent 955632e commit c532a11
Show file tree
Hide file tree
Showing 37 changed files with 1,988 additions and 29 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ Supported algorithms:
- [x] [MoCo v3 (ICCV'2021)](https://arxiv.org/abs/2104.02057)
- [x] [MAE](https://arxiv.org/abs/2111.06377)
- [x] [SimMIM](https://arxiv.org/abs/2111.09886)
- [x] [CAE](https://arxiv.org/abs/2202.03026)

More algorithms are in our plan.

Expand Down
6 changes: 2 additions & 4 deletions configs/benchmarks/classification/_base_/datasets/imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,15 @@
data_source=dict(
type=data_source,
data_prefix='data/imagenet/train',
ann_file='data/imagenet/meta/train.txt',
),
ann_file='data/imagenet/meta/train.txt'),
pipeline=train_pipeline,
prefetch=prefetch),
val=dict(
type=dataset_type,
data_source=dict(
type=data_source,
data_prefix='data/imagenet/val',
ann_file='data/imagenet/meta/val.txt',
),
ann_file='data/imagenet/meta/val.txt'),
pipeline=test_pipeline,
prefetch=prefetch))
evaluation = dict(interval=10, topk=(1, 5))
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
_base_ = 'vit-base-p16_ft-8xb128-coslr-100e_in1k.py'

# model
model = dict(backbone=dict(use_window=True, init_values=0.1))

# optimizer
optimizer = dict(lr=8e-3)

# learning policy
lr_config = dict(warmup_iters=5)

# dataset
img_norm_cfg = dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
train_pipeline = [
dict(
type='RandomAug',
input_size=224,
color_jitter=0.4,
auto_augment='rand-m9-mstd0.5-inc1',
interpolation='bicubic',
re_prob=0.25,
re_mode='pixel',
re_count=1,
mean=(0.5, 0.5, 0.5),
std=(0.5, 0.5, 0.5))
]
test_pipeline = [
dict(type='Resize', size=256, interpolation=3),
dict(type='CenterCrop', size=224),
dict(type='ToTensor'),
dict(type='Normalize', **img_norm_cfg)
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
samples_per_gpu=128)

find_unused_parameters = True
40 changes: 40 additions & 0 deletions configs/selfsup/_base_/datasets/imagenet_cae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# dataset settings
data_source = 'ImageNet'
dataset_type = 'SingleViewDataset'
img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
train_pipeline = [
dict(type='RandomHorizontalFlip', p=0.5),
dict(
type='RandomResizedCropAndInterpolationWithTwoPic',
size=224,
second_size=112,
interpolation='bicubic',
second_interpolation='lanczos',
scale=(0.08, 1.0)),
]

# prefetch
prefetch = False
if not prefetch:
train_pipeline.extend([dict(type='ToTensor')])

train_pipeline.append(
dict(
type='BEiTMaskGenerator',
input_size=(14, 14),
num_masking_patches=75,
max_num_patches=None,
min_num_patches=16))

# dataset summary
data = dict(
samples_per_gpu=256,
workers_per_gpu=8,
train=dict(
type=dataset_type,
data_source=dict(
type=data_source,
data_prefix='data/imagenet/train',
ann_file='data/imagenet/meta/train.txt'),
pipeline=train_pipeline,
prefetch=prefetch))
17 changes: 17 additions & 0 deletions configs/selfsup/_base_/models/cae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# model settings
model = dict(
type='CAE',
backbone=dict(type='CAEViT', arch='b', patch_size=16, init_values=0.1),
neck=dict(
type='CAENeck',
patch_size=16,
embed_dims=768,
num_heads=12,
regressor_depth=4,
decoder_depth=4,
mlp_ratio=4,
init_values=0.1,
),
head=dict(
type='CAEHead', tokenizer_path='cae_ckpt/dalle_encoder.pth', lambd=2),
base_momentum=0.0)
42 changes: 42 additions & 0 deletions configs/selfsup/cae/RAEDME.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# CAE

> [Context Autoencoder for Self-Supervised Representation Learning](https://arxiv.org/abs/2202.03026)
<!-- [ALGORITHM] -->

## Abstract

We present a novel masked image modeling (MIM) approach, context autoencoder (CAE), for self-supervised learning. We randomly partition the image into two sets: visible patches and masked patches. The CAE architecture consists of: (i) an encoder that takes visible patches as input and outputs their latent representations, (ii) a latent context regressor that predicts the masked patch representations from the visible patch representations that are not updated in this regressor, (iii) a decoder that takes the estimated masked patch representations as input and makes predictions for the masked patches, and (iv) an alignment module that aligns the masked patch representation estimation with the masked patch representations computed from the encoder. In comparison to previous MIM methods that couple the encoding and decoding roles, e.g., using a single module in BEiT, our approach attempts to separate the encoding role (content understanding) from the decoding role (making predictions for masked patches) using different modules, improving the content understanding capability. In addition, our approach makes predictions from the visible patches to the masked patches in the latent representation space that is expected to take on semantics. In addition, we present the explanations about why contrastive pretraining and supervised pretraining perform similarly and why MIM potentially performs better. We demonstrate the effectiveness of our CAE through superior transfer performance in downstream tasks: semantic segmentation, and object detection and instance segmentation.

<div align="center">
<img src="https://user-images.githubusercontent.com/30762564/165459947-6c6ef13c-0593-4765-b44e-6da0a079802a.png" width="40%"/>
</div>


## Prerequisite

Create a new folder ``cae_ckpt`` under the root directory and download the
[weights](https://download.openmmlab.com/mmselfsup/cae/dalle_encoder.pth) for ``dalle`` encoder to that folder
## Models and Benchmarks

Here, we report the results of the model, which is pre-trained on ImageNet-1k
for 300 epochs, the details are below:



| Backbone | Pre-train epoch | Fine-tuning Top-1 | Pre-train Config | Fine-tuning Config | Download |
| :------: | :-------------: | :---------------: | :-------------------------------------------------: | :---------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| ViT-B/16 | 300 | 83.2 | [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/classification/imagenet/vit-base-p16_ft-8xb128-coslr-100e-rpe_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.pth) &#124; [log](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.log.json) |


## Citation

```bibtex
@article{CAE,
title={Context Autoencoder for Self-Supervised Representation Learning},
author={Xiaokang Chen, Mingyu Ding, Xiaodi Wang, Ying Xin, Shentong Mo,
Yunhao Wang, Shumin Han, Ping Luo, Gang Zeng, Jingdong Wang},
journal={ArXiv},
year={2022}
}
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_base_ = 'cae_vit-base-p16_32xb64-fp16-coslr-300e_in1k.py'

# dataset
data = dict(samples_per_gpu=128, workers_per_gpu=8)
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
_base_ = [
'../_base_/models/cae.py',
'../_base_/datasets/imagenet_cae.py',
'../_base_/schedules/adamw_coslr-200e_in1k.py',
'../_base_/default_runtime.py',
]

# dataset
data = dict(samples_per_gpu=64, workers_per_gpu=8)

# optimizer
optimizer = dict(
lr=1.5e-3,
paramwise_options={
'norm': dict(weight_decay=0.),
'bias': dict(weight_decay=0.),
'gamma': dict(weight_decay=0.)
},
betas=(0.9, 0.999))

# learning policy
lr_config = dict(
policy='StepFixCosineAnnealing',
min_lr=1e-5,
warmup='linear',
warmup_iters=10,
warmup_ratio=1e-4,
warmup_by_epoch=True,
by_epoch=False)

# schedule
runner = dict(max_epochs=300)

# clip gradient
optimizer_config = dict(grad_clip=dict(max_norm=3.0))

# mixed precision
fp16 = dict(loss_scale='dynamic')

# runtime
checkpoint_config = dict(interval=1, max_keep_ckpts=2, out_dir='')
persistent_workers = True
log_config = dict(
interval=100, hooks=[
dict(type='TextLoggerHook'),
])

find_unused_parameters = True
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_base_ = 'cae_vit-base-p16_16xb128-fp16-coslr-300e_in1k.py'

# dataset
data = dict(samples_per_gpu=256, workers_per_gpu=8)
27 changes: 27 additions & 0 deletions configs/selfsup/cae/metafile.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
Collections:
- Name: CAE
Metadata:
Training Data: ImageNet-1k
Training Techniques:
- AdamW
Training Resources: 8x A100-80G GPUs
Architecture:
- ViT
Paper:
URL: https://arxiv.org/abs/2202.03026
Title: "Context Autoencoder for Self-Supervised Representation Learning"
README: configs/selfsup/cae/README.md

Models:
- Name: cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k
In Collection: CAE
Metadata:
Epochs: 300
Batch Size: 2048
Results:
- Task: Self-Supervised Image Classification
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.2
Config: configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py
Weights: https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.pth
42 changes: 42 additions & 0 deletions docs/en/algorithms/cae.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# CAE

> [Context Autoencoder for Self-Supervised Representation Learning](https://arxiv.org/abs/2202.03026)
<!-- [ALGORITHM] -->

## Abstract

We present a novel masked image modeling (MIM) approach, context autoencoder (CAE), for self-supervised learning. We randomly partition the image into two sets: visible patches and masked patches. The CAE architecture consists of: (i) an encoder that takes visible patches as input and outputs their latent representations, (ii) a latent context regressor that predicts the masked patch representations from the visible patch representations that are not updated in this regressor, (iii) a decoder that takes the estimated masked patch representations as input and makes predictions for the masked patches, and (iv) an alignment module that aligns the masked patch representation estimation with the masked patch representations computed from the encoder. In comparison to previous MIM methods that couple the encoding and decoding roles, e.g., using a single module in BEiT, our approach attempts to separate the encoding role (content understanding) from the decoding role (making predictions for masked patches) using different modules, improving the content understanding capability. In addition, our approach makes predictions from the visible patches to the masked patches in the latent representation space that is expected to take on semantics. In addition, we present the explanations about why contrastive pretraining and supervised pretraining perform similarly and why MIM potentially performs better. We demonstrate the effectiveness of our CAE through superior transfer performance in downstream tasks: semantic segmentation, and object detection and instance segmentation.

<div align="center">
<img src="https://user-images.githubusercontent.com/30762564/165459947-6c6ef13c-0593-4765-b44e-6da0a079802a.png" width="40%"/>
</div>


## Prerequisite

Create a new folder ``cae_ckpt`` under the root directory and download the
[weights](https://download.openmmlab.com/mmselfsup/cae/dalle_encoder.pth) for ``dalle`` encoder to that folder
## Models and Benchmarks

Here, we report the results of the model, which is pre-trained on ImageNet-1k
for 300 epochs, the details are below:



| Backbone | Pre-train epoch | Fine-tuning Top-1 | Pre-train Config | Fine-tuning Config | Download |
| :------: | :-------------: | :---------------: | :-------------------------------------------------: | :---------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| ViT-B/16 | 300 | 83.2 | [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | [config](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/classification/imagenet/vit-base-p16_ft-8xb128-coslr-100e-rpe_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.pth) &#124; [log](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.log.json) |


## Citation

```bibtex
@article{CAE,
title={Context Autoencoder for Self-Supervised Representation Learning},
author={Xiaokang Chen, Mingyu Ding, Xiaodi Wang, Ying Xin, Shentong Mo,
Yunhao Wang, Shumin Han, Ping Luo, Gang Zeng, Jingdong Wang},
journal={ArXiv},
year={2022}
}
```
1 change: 1 addition & 0 deletions docs/en/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Welcome to MMSelfSup's documentation!
algorithms/mae.md
algorithms/simmim.md
algorithms/barlowtwins.md
algorithms/cae.md


.. toctree::
Expand Down
12 changes: 7 additions & 5 deletions docs/en/model_zoo.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ All models and part of benchmark results are recorded below.
| [BarlowTwins](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/barlowtwins/README.md) | [barlowtwins_resnet50_8xb256-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/barlowtwins/barlowtwins_resnet50_8xb256-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/barlowtwins/barlowtwins_resnet50_8xb256-coslr-300e_in1k_20220419-5ae15f89.pth) &#124; [log](https://download.openmmlab.com/mmselfsup/barlowtwins/barlowtwins_resnet50_8xb256-coslr-300e_in1k_20220413_111555.log.json) |
| [MoCo v3](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/README.md) | [mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224.py) | [model](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220225-e31238dd.pth) &#124; [log](https://download.openmmlab.com/mmselfsup/moco/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224_20220222_160222.log.json) |
| [MAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/README.md) | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k-224_20220223-85be947b.pth) &#124; [log](https://download.openmmlab.com/mmselfsup/mae/mae_vit-base-p16_8xb512-coslr-300e_in1k-224_20220210_140925.log.json) |
| [SimMIM](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/README.md) | [simmim_swin-base_16xb128-coslr-100e_in1k-192](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192.py) | [model](https://download.openmmlab.com/mmselfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192_20220316-1d090125.pth) &#124; [log](https://download.openmmlab.com/mmselfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192_20220316-1d090125.log.json) |
| [SimMIM](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/README.md) | [simmim_swin-base_16xb128-coslr-100e_in1k-192](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192.py) | [model](https://download.openmmlab.com/mmselfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192_20220316-1d090125.pth) &#124; [log](https://download.openmmlab.com/mmselfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192_20220316-1d090125.log.json) |
| [CAE](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/README.md) | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | [model](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.pth) &#124; [log](https://download.openmmlab.com/mmselfsup/cae/cae_vit-base-p16_16xb256-coslr-300e_in1k-224_20220427-4c786349.log.json) |

Remarks:

Expand Down Expand Up @@ -57,10 +58,11 @@ If not specified, we use linear evaluation setting from [MoCo](http://openaccess
| MoCo v3 | [mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mocov3/mocov3_vit-small-p16_32xb128-fp16-coslr-300e_in1k-224.py) | MoCo v3 paper setting | 73.19 |

### ImageNet Fine-tuning
| Algorithm | Config | Remarks | Top-1 (%) |
| --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | --------- |
| MAE | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | | 83.1 |
| SimMIM | [simmim_swin-base_16xb128-coslr-100e_in1k-192](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192.py) | | 82.9 |
| Algorithm | Config | Remarks | Top-1 (%) |
| --------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- | --------- |
| MAE | [mae_vit-base-p16_8xb512-coslr-400e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py) | | 83.1 |
| SimMIM | [simmim_swin-base_16xb128-coslr-100e_in1k-192](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/simmim/simmim_swin-base_16xb128-coslr-100e_in1k-192.py) | | 82.9 |
| CAE | [cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/cae/cae_vit-base-p16_8xb256-fp16-coslr-300e_in1k.py) | | 83.2 |

### COCO17 Object Detection and Instance Segmentation

Expand Down
Loading

0 comments on commit c532a11

Please sign in to comment.