Merge 8379188 into 447a398
xiexinch committed Mar 15, 2023
2 parents 447a398 + 8379188 commit 8c95ffc
Showing 20 changed files with 1,646 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -159,6 +159,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
- [x] [K-Net (NeurIPS'2021)](configs/knet)
- [x] [MaskFormer (NeurIPS'2021)](configs/maskformer)
- [x] [Mask2Former (CVPR'2022)](configs/mask2former)
- [x] [PIDNet (ArXiv'2022)](configs/pidnet)

</details>

1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -140,6 +140,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [K-Net (NeurIPS'2021)](configs/knet)
- [x] [MaskFormer (NeurIPS'2021)](configs/maskformer)
- [x] [Mask2Former (CVPR'2022)](configs/mask2former)
- [x] [PIDNet (ArXiv'2022)](configs/pidnet)

</details>

50 changes: 50 additions & 0 deletions configs/pidnet/README.md
@@ -0,0 +1,50 @@
# PIDNet

> [PIDNet: A Real-time Semantic Segmentation Network Inspired from PID Controller](https://arxiv.org/pdf/2206.02066.pdf)

## Introduction

<!-- [ALGORITHM] -->

<a href="https://github.com/XuJiacong/PIDNet">Official Repo</a>

<a href="https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/backbones/pidnet.py">Code Snippet</a>

## Abstract

<!-- [ABSTRACT] -->

Two-branch network architecture has shown its efficiency and effectiveness for real-time semantic segmentation tasks. However, direct fusion of low-level details and high-level semantics leads to a phenomenon in which the detailed features are easily overwhelmed by surrounding contextual information, termed overshoot in this paper, which limits the accuracy improvement of existing two-branch models. In this paper, we bridge a connection between Convolutional Neural Network (CNN) and Proportional-Integral-Derivative (PID) controller and reveal that the two-branch network is nothing but a Proportional-Integral (PI) controller, which inherently suffers from a similar overshoot issue. To alleviate this issue, we propose a novel three-branch network architecture: PIDNet, which possesses three branches to parse the detailed, context and boundary information (derivative of semantics), respectively, and employs boundary attention to guide the fusion of the detailed and context branches in the final stage. The family of PIDNets achieves the best trade-off between inference speed and accuracy, and their test accuracy surpasses all existing models with similar inference speed on the Cityscapes, CamVid and COCO-Stuff datasets. In particular, PIDNet-S achieves 78.6% mIoU with an inference speed of 93.2 FPS on the Cityscapes test set and 80.1% mIoU with a speed of 153.7 FPS on the CamVid test set.

<!-- [IMAGE] -->

<div align=center>
<img src="https://raw.githubusercontent.com/XuJiacong/PIDNet/main/figs/pidnet.jpg" width="800"/>
</div>

## Results and models

### Cityscapes

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ------- | -------- | -------------- | ----- | ------------- | ----------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| PIDNet | PIDNet-S | 1024x1024 | 120000 | 3.38 | 80.82 | 78.74 | 80.87 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes/pidnet-s_2xb6-120k_1024x1024-cityscapes_20230302_191700-bb8e3bcc.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes/pidnet-s_2xb6-120k_1024x1024-cityscapes_20230302_191700.json) |
| PIDNet | PIDNet-M | 1024x1024 | 120000 | 5.14 | 71.98 | 80.22 | 82.05 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes/pidnet-m_2xb6-120k_1024x1024-cityscapes_20230301_143452-f9bcdbf3.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes/pidnet-m_2xb6-120k_1024x1024-cityscapes_20230301_143452.json) |
| PIDNet | PIDNet-L | 1024x1024 | 120000 | 5.83 | 60.06 | 80.89 | 82.37 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes/pidnet-l_2xb6-120k_1024x1024-cityscapes_20230303_114514-0783ca6b.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes/pidnet-l_2xb6-120k_1024x1024-cityscapes_20230303_114514.json) |
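
The `Inf time (fps)` column can be cross-checked against the per-image latencies recorded in `configs/pidnet/pidnet.yml` in this commit (V100, PyTorch, FP32, batch size 1). The sketch below is only an illustration of that conversion; the helper name `ms_to_fps` is hypothetical and not part of MMSegmentation. Because the latencies are rounded to two decimals, the result may differ from the table by a few hundredths.

```python
# Latencies in ms per image, copied from configs/pidnet/pidnet.yml.
latency_ms = {"PIDNet-S": 12.37, "PIDNet-M": 13.89, "PIDNet-L": 16.65}


def ms_to_fps(ms_per_image: float) -> float:
    """Convert a per-image latency in milliseconds to frames per second.

    Hypothetical helper for illustration only.
    """
    return 1000.0 / ms_per_image


fps = {name: round(ms_to_fps(ms), 2) for name, ms in latency_ms.items()}

for name, value in fps.items():
    print(f"{name}: {value:.2f} fps")
```

The ordering matches the table: the smaller the variant, the higher the throughput.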

## Notes

The pretrained weights in config files are converted from [the official repo](https://github.com/XuJiacong/PIDNet#models).
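
The M and L configs in this directory inherit from the S config through `_base_` and override only a few fields, including the pretrained checkpoint. As a rough sketch of how such an override resolves, the snippet below applies a recursive dictionary merge in plain Python. It assumes that nested dicts are merged key by key, which is how `_base_` inheritance behaves in the common case; `merge_config` is a hypothetical helper, not the MMEngine implementation, and the checkpoint strings are placeholders rather than the real URLs.

```python
import copy


def merge_config(base: dict, override: dict) -> dict:
    """Recursively merge `override` into a copy of `base` (illustrative only)."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Nested dicts are merged key by key instead of being replaced.
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Abridged model settings from the S config (placeholder checkpoint name).
s_model = dict(
    backbone=dict(
        type='PIDNet',
        channels=32,
        ppm_channels=96,
        num_stem_blocks=2,
        num_branch_blocks=3,
        init_cfg=dict(type='Pretrained', checkpoint='<pidnet-s checkpoint>')),
    decode_head=dict(type='PIDHead', in_channels=128, channels=128))

# Overrides from the M config (placeholder checkpoint name).
m_override = dict(
    backbone=dict(
        channels=64, init_cfg=dict(checkpoint='<pidnet-m checkpoint>')),
    decode_head=dict(in_channels=256))

m_model = merge_config(s_model, m_override)
print(m_model['backbone'])
print(m_model['decode_head'])
```

Under this assumption the M variant keeps `ppm_channels=96` and `init_cfg.type='Pretrained'` from the S config while widening the backbone to 64 channels and the head input to 256.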

## Citation

```bibtex
@misc{xu2022pidnet,
title={PIDNet: A Real-time Semantic Segmentation Network Inspired from PID Controller},
author={Jiacong Xu and Zixiang Xiong and Shankar P. Bhattacharyya},
year={2022},
eprint={2206.02066},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
10 changes: 10 additions & 0 deletions configs/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes.py
@@ -0,0 +1,10 @@
_base_ = './pidnet-s_2xb6-120k_1024x1024-cityscapes.py'
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-l_imagenet1k_20230306-67889109.pth'  # noqa
model = dict(
    backbone=dict(
        channels=64,
        ppm_channels=112,
        num_stem_blocks=3,
        num_branch_blocks=4,
        init_cfg=dict(checkpoint=checkpoint_file)),
    decode_head=dict(in_channels=256, channels=256))
5 changes: 5 additions & 0 deletions configs/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes.py
@@ -0,0 +1,5 @@
_base_ = './pidnet-s_2xb6-120k_1024x1024-cityscapes.py'
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-m_imagenet1k_20230306-39893c52.pth'  # noqa
model = dict(
    backbone=dict(channels=64, init_cfg=dict(checkpoint=checkpoint_file)),
    decode_head=dict(in_channels=256))
113 changes: 113 additions & 0 deletions configs/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes.py
@@ -0,0 +1,113 @@
_base_ = [
    '../_base_/datasets/cityscapes_1024x1024.py',
    '../_base_/default_runtime.py'
]

# The class_weight is borrowed from https://github.com/openseg-group/OCNet.pytorch/issues/14 # noqa
# Licensed under the MIT License
class_weight = [
    0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786,
    1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529,
    1.0507
]
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-s_imagenet1k_20230306-715e6273.pth'  # noqa
crop_size = (1024, 1024)
data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255,
    size=crop_size)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    backbone=dict(
        type='PIDNet',
        in_channels=3,
        channels=32,
        ppm_channels=96,
        num_stem_blocks=2,
        num_branch_blocks=3,
        align_corners=False,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='ReLU', inplace=True),
        init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)),
    decode_head=dict(
        type='PIDHead',
        in_channels=128,
        channels=128,
        num_classes=19,
        norm_cfg=norm_cfg,
        act_cfg=dict(type='ReLU', inplace=True),
        align_corners=True,
        loss_decode=[
            dict(
                type='CrossEntropyLoss',
                use_sigmoid=False,
                class_weight=class_weight,
                loss_weight=0.4),
            dict(
                type='OhemCrossEntropy',
                thres=0.9,
                min_kept=131072,
                class_weight=class_weight,
                loss_weight=1.0),
            dict(type='BoundaryLoss', loss_weight=20.0),
            dict(
                type='OhemCrossEntropy',
                thres=0.9,
                min_kept=131072,
                class_weight=class_weight,
                loss_weight=1.0)
        ]),
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(
        type='RandomResize',
        scale=(2048, 1024),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='GenerateEdge', edge_width=4),
    dict(type='PackSegInputs')
]
train_dataloader = dict(batch_size=6, dataset=dict(pipeline=train_pipeline))

iters = 120000
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
    dict(
        type='PolyLR',
        eta_min=0,
        power=0.9,
        begin=0,
        end=iters,
        by_epoch=False)
]
# training schedule for 120k
train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=iters, val_interval=iters // 10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook', by_epoch=False, interval=iters // 10),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))

randomness = dict(seed=304)
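
The `PolyLR` entry in the config above decays the SGD learning rate from `lr=0.01` towards `eta_min=0` over `iters = 120000` iterations with `power=0.9`. The sketch below evaluates the usual closed form of polynomial decay, `(base_lr - eta_min) * (1 - t / T) ** power + eta_min`, with those values. It is an assumption that the scheduler follows this closed form exactly (a real implementation may update the rate incrementally), and `poly_lr` is a hypothetical helper, not an MMEngine function.

```python
def poly_lr(iteration: int,
            base_lr: float = 0.01,
            eta_min: float = 0.0,
            power: float = 0.9,
            max_iters: int = 120_000) -> float:
    """Closed-form polynomial learning-rate decay (illustrative only)."""
    # Clamp so the rate stays at eta_min once the schedule has ended.
    progress = min(iteration, max_iters) / max_iters
    return (base_lr - eta_min) * (1.0 - progress) ** power + eta_min


# Print the rate at each validation point (every iters // 10 iterations).
for it in range(0, 120_001, 12_000):
    print(f"iter {it:>6}: lr = {poly_lr(it):.6f}")
```

With `power` below 1 the curve is slightly concave, so the rate at the halfway point is a little above half of the base rate.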
81 changes: 81 additions & 0 deletions configs/pidnet/pidnet.yml
@@ -0,0 +1,81 @@
Collections:
- Name: PIDNet
  Metadata:
    Training Data:
    - Cityscapes
  Paper:
    URL: https://arxiv.org/pdf/2206.02066.pdf
    Title: 'PIDNet: A Real-time Semantic Segmentation Network Inspired from PID Controller'
  README: configs/pidnet/README.md
  Code:
    URL: https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/backbones/pidnet.py
    Version: dev-1.x
  Converted From:
    Code: https://github.com/XuJiacong/PIDNet
Models:
- Name: pidnet-s_2xb6-120k_1024x1024-cityscapes
  In Collection: PIDNet
  Metadata:
    backbone: PIDNet-S
    crop size: (1024,1024)
    lr schd: 120000
    inference time (ms/im):
    - value: 12.37
      hardware: V100
      backend: PyTorch
      batch size: 1
      mode: FP32
      resolution: (1024,1024)
    Training Memory (GB): 3.38
  Results:
  - Task: Semantic Segmentation
    Dataset: Cityscapes
    Metrics:
      mIoU: 78.74
      mIoU(ms+flip): 80.87
  Config: configs/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes.py
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes/pidnet-s_2xb6-120k_1024x1024-cityscapes_20230302_191700-bb8e3bcc.pth
- Name: pidnet-m_2xb6-120k_1024x1024-cityscapes
  In Collection: PIDNet
  Metadata:
    backbone: PIDNet-M
    crop size: (1024,1024)
    lr schd: 120000
    inference time (ms/im):
    - value: 13.89
      hardware: V100
      backend: PyTorch
      batch size: 1
      mode: FP32
      resolution: (1024,1024)
    Training Memory (GB): 5.14
  Results:
  - Task: Semantic Segmentation
    Dataset: Cityscapes
    Metrics:
      mIoU: 80.22
      mIoU(ms+flip): 82.05
  Config: configs/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes.py
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes/pidnet-m_2xb6-120k_1024x1024-cityscapes_20230301_143452-f9bcdbf3.pth
- Name: pidnet-l_2xb6-120k_1024x1024-cityscapes
  In Collection: PIDNet
  Metadata:
    backbone: PIDNet-L
    crop size: (1024,1024)
    lr schd: 120000
    inference time (ms/im):
    - value: 16.65
      hardware: V100
      backend: PyTorch
      batch size: 1
      mode: FP32
      resolution: (1024,1024)
    Training Memory (GB): 5.83
  Results:
  - Task: Semantic Segmentation
    Dataset: Cityscapes
    Metrics:
      mIoU: 80.89
      mIoU(ms+flip): 82.37
  Config: configs/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes.py
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes/pidnet-l_2xb6-120k_1024x1024-cityscapes_20230303_114514-0783ca6b.pth
3 changes: 2 additions & 1 deletion mmseg/models/backbones/__init__.py
@@ -11,6 +11,7 @@
from .mit import MixVisionTransformer
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
from .pidnet import PIDNet
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1c, ResNetV1d
from .resnext import ResNeXt
@@ -26,5 +27,5 @@
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
-    'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE'
+    'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet'
]