Skip to content

Commit

Permalink
Merge 83da311 into dd47cef
Browse files Browse the repository at this point in the history
  • Loading branch information
MengzhangLI committed Mar 15, 2023
2 parents dd47cef + 83da311 commit d42a4e8
Show file tree
Hide file tree
Showing 15 changed files with 1,176 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
- [x] [MAE (CVPR'2022)](configs/mae)
- [x] [PoolFormer (CVPR'2022)](configs/poolformer)
- [x] [SegNeXt (NeurIPS'2022)](configs/segnext)

</details>

Expand Down
1 change: 1 addition & 0 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
- [x] [MAE (CVPR'2022)](configs/mae)
- [x] [PoolFormer (CVPR'2022)](configs/poolformer)
- [x] [SegNeXt (NeurIPS'2022)](configs/segnext)

</details>

Expand Down
63 changes: 63 additions & 0 deletions configs/segnext/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# SegNeXt

> [SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation](https://arxiv.org/abs/2209.08575)
## Introduction

<!-- [ALGORITHM] -->

<a href="https://github.com/visual-attention-network/segnext">Official Repo</a>

<a href="https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/backbones/mscan.py#L328">Code Snippet</a>

## Abstract

<!-- [ABSTRACT] -->

We present SegNeXt, a simple convolutional network architecture for semantic segmentation. Recent transformer-based models have dominated the field of semantic segmentation due to the efficiency of self-attention in encoding spatial information. In this paper, we show that convolutional attention is a more efficient and effective way to encode contextual information than the self-attention mechanism in transformers. By re-examining the characteristics owned by successful segmentation models, we discover several key components leading to the performance improvement of segmentation models. This motivates us to design a novel convolutional attention network that uses cheap convolutional operations. Without bells and whistles, our SegNeXt significantly improves the performance of previous state-of-the-art methods on popular benchmarks, including ADE20K, Cityscapes, COCO-Stuff, Pascal VOC, Pascal Context, and iSAID. Notably, SegNeXt outperforms EfficientNet-L2 w/ NAS-FPN and achieves 90.6% mIoU on the Pascal VOC 2012 test leaderboard using only 1/10 parameters of it. On average, SegNeXt achieves about 2.0% mIoU improvements compared to the state-of-the-art methods on the ADE20K datasets with the same or fewer computations. Code is available at [this https URL](https://github.com/uyzhang/JSeg) (Jittor) and [this https URL](https://github.com/Visual-Attention-Network/SegNeXt) (Pytorch).

<!-- [IMAGE] -->

<div align=center>
<img src="https://user-images.githubusercontent.com/24582831/215688018-5d4c8366-7793-4fdf-9397-960a09fac951.png" width="70%"/>
</div>

## Results and models

### ADE20K

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------- | -------- | --------- | ------- | -------- | -------------- | ----- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| SegNeXt | MSCAN-T | 512x512 | 160000 | 17.88 | 52.38 | 41.50 | 42.59 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/segnext/segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244-05bd8466.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244.log.json) |
| SegNeXt | MSCAN-S | 512x512 | 160000 | 21.47 | 42.27 | 44.16 | 45.81 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/segnext/segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k_20230214_113014-43013668.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k_20230214_113014.log.json) |
| SegNeXt | MSCAN-B | 512x512 | 160000 | 31.03 | 35.15 | 48.03 | 49.68 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/segnext/segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k_20230209_172053-b6f6c70c.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k_20230209_172053.log.json) |
| SegNeXt | MSCAN-L | 512x512 | 160000 | 43.32 | 22.91 | 50.99 | 52.10 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/segnext/segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k_20230209_172055-19b14b63.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k_20230209_172055.log.json) |

Note:

- When we integrated SegNeXt into MMSegmentation, we modified some layers' names to make them more precise and concise without changing the model architecture. Therefore, the keys of pre-trained weights are different from the [original weights](https://cloud.tsinghua.edu.cn/d/c15b25a6745946618462/), but don't worry about these changes. we have converted them and uploaded the checkpoints, you might find URL of pre-trained checkpoints in config files and can use them directly for training.

- The total batch size is 16. We trained for SegNeXt with a single GPU as the performance degrades significantly when using`SyncBN` (mainly in `OverlapPatchEmbed` modules of `MSCAN`) of PyTorch 1.9.

- There will be subtle differences when model testing as Non-negative Matrix Factorization (NMF) in `LightHamHead` will be initialized randomly. To control this randomness, please set the random seed when model testing. You can modify [`./tools/test.py`](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/tools/test.py) like:

```python
def main():
from mmengine.runner import seg_random_seed
random_seed = xxx # set random seed recorded in training log
set_random_seed(random_seed, deterministic=False)
...
```

- This model performance is sensitive to the seed values used, please refer to the log file for the specific settings of the seed. If you choose a different seed, the results might differ from the table results. Take SegNeXt Large for example, its results range from 49.60 to 51.0.

## Citation

```bibtex
@article{guo2022segnext,
title={SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation},
author={Guo, Meng-Hao and Lu, Cheng-Ze and Hou, Qibin and Liu, Zhengning and Cheng, Ming-Ming and Hu, Shi-Min},
journal={arXiv preprint arXiv:2209.08575},
year={2022}
}
```
103 changes: 103 additions & 0 deletions configs/segnext/segnext.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
Collections:
- Name: SegNeXt
Metadata:
Training Data:
- ADE20K
Paper:
URL: https://arxiv.org/abs/2209.08575
Title: 'SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation'
README: configs/segnext/README.md
Code:
URL: https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/backbones/mscan.py#L328
Version: dev-1.x
Converted From:
Code: https://github.com/visual-attention-network/segnext
Models:
- Name: segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512
In Collection: SegNeXt
Metadata:
backbone: MSCAN-T
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 19.09
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 17.88
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 41.5
mIoU(ms+flip): 42.59
Config: configs/segnext/segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244-05bd8466.pth
- Name: segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512
In Collection: SegNeXt
Metadata:
backbone: MSCAN-S
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 23.66
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 21.47
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 44.16
mIoU(ms+flip): 45.81
Config: configs/segnext/segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k_20230214_113014-43013668.pth
- Name: segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512
In Collection: SegNeXt
Metadata:
backbone: MSCAN-B
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 28.45
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 31.03
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 48.03
mIoU(ms+flip): 49.68
Config: configs/segnext/segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k_20230209_172053-b6f6c70c.pth
- Name: segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512
In Collection: SegNeXt
Metadata:
backbone: MSCAN-L
crop size: (512,512)
lr schd: 160000
inference time (ms/im):
- value: 43.65
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (512,512)
Training Memory (GB): 43.32
Results:
- Task: Semantic Segmentation
Dataset: ADE20K
Metrics:
mIoU: 50.99
mIoU(ms+flip): 52.1
Config: configs/segnext/segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k_20230209_172055-19b14b63.pth
28 changes: 28 additions & 0 deletions configs/segnext/segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
_base_ = './segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py'

# model settings
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_b_20230227-3ab7d230.pth' # noqa
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
type='EncoderDecoder',
backbone=dict(
embed_dims=[64, 128, 320, 512],
depths=[3, 3, 12, 3],
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
drop_path_rate=0.1,
norm_cfg=dict(type='BN', requires_grad=True)),
decode_head=dict(
type='LightHamHead',
in_channels=[128, 320, 512],
in_index=[1, 2, 3],
channels=512,
ham_channels=512,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=ham_norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
27 changes: 27 additions & 0 deletions configs/segnext/segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
_base_ = './segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py'
# model settings
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_l_20230227-cef260d4.pth' # noqa
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
type='EncoderDecoder',
backbone=dict(
embed_dims=[64, 128, 320, 512],
depths=[3, 5, 27, 3],
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
drop_path_rate=0.3,
norm_cfg=dict(type='BN', requires_grad=True)),
decode_head=dict(
type='LightHamHead',
in_channels=[128, 320, 512],
in_index=[1, 2, 3],
channels=1024,
ham_channels=1024,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=ham_norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
27 changes: 27 additions & 0 deletions configs/segnext/segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
_base_ = './segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py'
# model settings
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_s_20230227-f33ccdf2.pth' # noqa
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
model = dict(
type='EncoderDecoder',
backbone=dict(
embed_dims=[64, 128, 320, 512],
depths=[2, 2, 4, 2],
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
norm_cfg=dict(type='BN', requires_grad=True)),
decode_head=dict(
type='LightHamHead',
in_channels=[128, 320, 512],
in_index=[1, 2, 3],
channels=256,
ham_channels=256,
ham_kwargs=dict(MD_R=16),
dropout_ratio=0.1,
num_classes=150,
norm_cfg=ham_norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
84 changes: 84 additions & 0 deletions configs/segnext/segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
_base_ = [
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py',
'../_base_/datasets/ade20k.py'
]
# model settings
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_t_20230227-119e8c9f.pth' # noqa
ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
crop_size = (512, 512)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255,
size=(512, 512),
test_cfg=dict(size_divisor=32))
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
pretrained=None,
backbone=dict(
type='MSCAN',
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
embed_dims=[32, 64, 160, 256],
mlp_ratios=[8, 8, 4, 4],
drop_rate=0.0,
drop_path_rate=0.1,
depths=[3, 3, 5, 2],
attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='BN', requires_grad=True)),
decode_head=dict(
type='LightHamHead',
in_channels=[64, 160, 256],
in_index=[1, 2, 3],
channels=256,
ham_channels=256,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=ham_norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
ham_kwargs=dict(
MD_S=1,
MD_R=16,
train_steps=6,
eval_steps=7,
inv_t=100,
rand_init=True)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

# dataset settings
train_dataloader = dict(batch_size=16)

# optimizer
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(
type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
paramwise_cfg=dict(
custom_keys={
'pos_block': dict(decay_mult=0.),
'norm': dict(decay_mult=0.),
'head': dict(lr_mult=10.)
}))

param_scheduler = [
dict(
type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
dict(
type='PolyLR',
power=1.0,
begin=1500,
end=160000,
eta_min=0.0,
by_epoch=False,
)
]
3 changes: 2 additions & 1 deletion mmseg/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .mit import MixVisionTransformer
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
from .mscan import MSCAN
from .pidnet import PIDNet
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1c, ResNetV1d
Expand All @@ -27,5 +28,5 @@
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet'
'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN'
]
Loading

0 comments on commit d42a4e8

Please sign in to comment.