[features] Support stdc (#284)
* add stdc semantic segmentation algorithm
haiasd authored Feb 16, 2023
1 parent 2fe73ee commit 26cd12a
Showing 24 changed files with 1,350 additions and 8 deletions.
198 changes: 198 additions & 0 deletions configs/segmentation/stdc/stdc1_cityscape_8xb6_e1290.py
@@ -0,0 +1,198 @@
_base_ = ['configs/base.py']

# warning: batch_size must be >= 2
# model
norm_cfg = dict(type='BN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='STDCContextPathNet',
        backbone_cfg=dict(
            type='STDCNet',
            stdc_type='STDCNet1',
            in_channels=3,
            channels=(32, 64, 256, 512, 1024),
            bottleneck_type='cat',
            num_convs=4,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'),
            with_final_conv=False),
        last_in_channels=(1024, 512),
        out_channels=128,
        ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4)),
    decode_head=dict(
        type='FCNHead',
        in_channels=256,
        channels=256,
        num_convs=1,
        num_classes=19,
        in_index=3,
        concat_input=False,
        dropout_ratio=0.1,
        norm_cfg=norm_cfg,
        align_corners=True,
        sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=[
        dict(
            type='FCNHead',
            in_channels=128,
            channels=64,
            num_convs=1,
            num_classes=19,
            in_index=2,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=False,
            sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
        dict(
            type='FCNHead',
            in_channels=128,
            channels=64,
            num_convs=1,
            num_classes=19,
            in_index=1,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=False,
            sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
        dict(
            type='STDCHead',
            in_channels=256,
            channels=64,
            num_convs=1,
            num_classes=2,
            boundary_threshold=0.1,
            in_index=0,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=True,
            loss_decode=[
                dict(
                    type='CrossEntropyLoss',
                    loss_name='loss_ce',
                    use_sigmoid=True,
                    loss_weight=1.0),
                dict(type='DiceLoss', loss_name='loss_dice', loss_weight=1.0)
            ]),
    ],
    train_cfg=dict(),
    test_cfg=dict(mode='whole'),
    pretrained=
    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/stdc/pretrain/stdc1_easycv.pth'
)

# dataset
CLASSES = [
    'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
    'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car',
    'truck', 'bus', 'train', 'motorcycle', 'bicycle'
]
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 1024)

train_pipeline = [
    dict(type='MMResize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
    dict(type='SegRandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='MMRandomFlip', flip_ratio=0.5),
    dict(type='MMPhotoMetricDistortion'),
    dict(type='MMNormalize', **img_norm_cfg),
    dict(type='MMPad', size=crop_size),
    dict(type='DefaultFormatBundle'),
    dict(
        type='Collect',
        keys=['img', 'gt_semantic_seg'],
        meta_keys=('filename', 'ori_filename', 'ori_shape', 'img_shape',
                   'pad_shape', 'scale_factor', 'flip', 'flip_direction',
                   'img_norm_cfg')),
]

test_pipeline = [
    dict(
        type='MMMultiScaleFlipAug',
        img_scale=(2048, 1024),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='MMResize', keep_ratio=True),
            dict(type='MMRandomFlip'),
            dict(type='MMNormalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(
                type='Collect',
                keys=['img'],
                meta_keys=('filename', 'ori_filename', 'ori_shape',
                           'img_shape', 'pad_shape', 'scale_factor', 'flip',
                           'flip_direction', 'img_norm_cfg')),
        ])
]
dataset_type = 'SegDataset'
data_root = '../Cityscapes/'

train_img_root = data_root + 'leftImg8bit/train/'
train_label_root = data_root + 'gtFine/train/'

val_img_root = data_root + 'leftImg8bit/val/'
val_label_root = data_root + 'gtFine/val/'
data = dict(
    imgs_per_gpu=6,
    workers_per_gpu=4,
    persistent_workers=True,
    train=dict(
        type=dataset_type,
        ignore_index=255,
        data_source=dict(
            type='SegSourceCityscapes',
            img_root=train_img_root,
            label_root=train_label_root,
            classes=CLASSES),
        pipeline=train_pipeline),
    val=dict(
        imgs_per_gpu=1,
        ignore_index=255,
        type=dataset_type,
        data_source=dict(
            type='SegSourceCityscapes',
            img_root=val_img_root,
            label_root=val_label_root,
            classes=CLASSES),
        pipeline=test_pipeline))

# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# learning policy
lr_config = dict(
    policy='CosineAnnealing',
    min_lr=1e-4,
    warmup='linear',
    warmup_iters=10,
    warmup_ratio=0.0001,
    warmup_by_epoch=True,
    by_epoch=False)

# runtime settings
total_epochs = 1290
checkpoint_config = dict(interval=10)
eval_config = dict(interval=10, gpu_collect=False)
eval_pipelines = [
    dict(
        mode='test',
        evaluators=[
            dict(
                type='SegmentationEvaluator',
                classes=CLASSES,
                metric_names=['mIoU'])
        ],
    )
]

# export config
export = dict(export_neck=True)
checkpoint_sync_export = True
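
A minimal derived-config sketch (hypothetical file name, e.g. stdc1_cityscape_smoketest.py) showing how the schedule above could be shrunk for a quick local run; it reuses the same `_base_` inheritance mechanism as the stdc2 config below, and every overridden value is illustrative only:

_base_ = ['configs/segmentation/stdc/stdc1_cityscape_8xb6_e1290.py']

# Illustrative overrides for a short smoke test; per the base config's own
# warning, the batch size (imgs_per_gpu) must stay >= 2.
data = dict(imgs_per_gpu=2, workers_per_gpu=2)
total_epochs = 20
checkpoint_config = dict(interval=5)
eval_config = dict(interval=5, gpu_collect=False)
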
7 changes: 7 additions & 0 deletions configs/segmentation/stdc/stdc2_cityscape_8xb6_e1290.py
@@ -0,0 +1,7 @@
_base_ = ['configs/segmentation/stdc/stdc1_cityscape_8xb6_e1290.py']

model = dict(
    backbone=dict(backbone_cfg=dict(stdc_type='STDCNet2')),
    pretrained=
    'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/stdc/pretrain/stdc2_easycv.pth'
)
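
The override above replaces only two leaves of the stdc1 settings, backbone_cfg['stdc_type'] and the pretrained URL; everything else is inherited. Conceptually, `_base_`-style config loading performs a recursive dict merge, roughly like the simplified sketch below (an illustration of the merge semantics, not EasyCV's actual implementation):

def merge_cfg(base, override):
    """Recursively merge `override` into `base` (simplified illustration)."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # nested dicts are merged key by key, so untouched leaves survive
            merged[key] = merge_cfg(merged[key], value)
        else:
            merged[key] = value
    return merged
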
6 changes: 6 additions & 0 deletions docs/source/model_zoo_seg.md
@@ -15,6 +15,12 @@ Pretrained on **Pascal VOC 2012 + Aug**.
| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| upernet_r50 | [upernet_r50_512x512_8xb4_60e_voc12aug](https://github.com/alibaba/EasyCV/tree/master/configs/segmentation/upernet/upernet_r50_512x512_8xb4_60e_voc12aug.py) | 23M/66M | 5.5 | 282.9ms | 76.59 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/upernet_r50/epoch_60.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/upernet_r50/20220706_114712.log.json) |

## STDC
Trained on **Cityscapes**.
| Algorithm | Config | Params<br/>(backbone/total) | Train memory<br/>(GB) | inference time(V100)<br/>(ms/img) | mIoU | Download |
| ---------- | ------------------------------------------------------------ | ------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ |
| STDC1 | [stdc1_cityscape_8xb6_e1290](https://github.com/alibaba/EasyCV/tree/master/configs/segmentation/stdc/stdc1_cityscape_8xb6_e1290.py) | 7.7M/8.5M | 4.5 | 11.9ms | 75.4 | [model](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/stdc/stdc1_cityscapes/epoch_1250.pth) - [log](http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/stdc/stdc1_cityscapes/20230214_173123.log.json) |

## Mask2former

### Instance Segmentation on COCO
3 changes: 2 additions & 1 deletion easycv/datasets/segmentation/data_sources/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .cityscapes import SegSourceCityscapes
from .coco import SegSourceCoco, SegSourceCoco2017
from .coco_stuff import SegSourceCocoStuff10k, SegSourceCocoStuff164k
from .raw import SegSourceRaw
@@ -7,5 +8,5 @@
__all__ = [
    'SegSourceRaw', 'SegSourceVoc2010', 'SegSourceVoc2007', 'SegSourceVoc2012',
    'SegSourceCoco', 'SegSourceCoco2017', 'SegSourceCocoStuff164k',
    'SegSourceCocoStuff10k', 'SegSourceCityscapes'
]
129 changes: 129 additions & 0 deletions easycv/datasets/segmentation/data_sources/cityscapes.py
@@ -0,0 +1,129 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import logging
import os
import subprocess

import numpy as np

from easycv.datasets.registry import DATASOURCES
from easycv.file import io
from easycv.file.image import load_image as _load_img
from .raw import SegSourceRaw

try:
    import cityscapesscripts.helpers.labels as CSLabels
except ModuleNotFoundError:
    res = subprocess.call('pip install cityscapesscripts', shell=True)
    if res != 0:
        info_string = (
            '\n\nAuto install failed! Please install cityscapesscripts with the following command:\n'
            '\t`pip install cityscapesscripts`\n')
        raise ModuleNotFoundError(info_string)
    # retry the import after a successful auto install
    import cityscapesscripts.helpers.labels as CSLabels


def load_seg_map_cityscape(seg_path, reduce_zero_label):
    # map Cityscapes label ids to train ids
    gt_semantic_seg = _load_img(seg_path, mode='P')
    gt_semantic_seg_copy = gt_semantic_seg.copy()
    for labels in CSLabels.labels:
        gt_semantic_seg_copy[gt_semantic_seg == labels.id] = labels.trainId

    return {'gt_semantic_seg': gt_semantic_seg_copy}


@DATASOURCES.register_module
class SegSourceCityscapes(SegSourceRaw):
    """Cityscapes datasource.
    """
    CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
               'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
               'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
               'bicycle')

    PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
               [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0],
               [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60],
               [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100],
               [0, 80, 100], [0, 0, 230], [119, 11, 32]]

    def __init__(self,
                 img_suffix='_leftImg8bit.png',
                 label_suffix='_gtFine_labelIds.png',
                 **kwargs):
        super(SegSourceCityscapes, self).__init__(
            img_suffix=img_suffix, label_suffix=label_suffix, **kwargs)

    def __getitem__(self, idx):
        result_dict = self.samples_list[idx]
        load_success = True
        try:
            # avoid the data cache taking up too much memory
            if not self.cache_at_init and not self.cache_on_the_fly:
                result_dict = copy.deepcopy(result_dict)

            if not self.cache_at_init:
                if result_dict.get('img', None) is None:
                    img = _load_img(result_dict['filename'], mode='BGR')
                    result = {
                        'img': img.astype(np.float32),
                        'img_shape': img.shape,  # h, w, c
                        'ori_shape': img.shape,
                    }
                    result_dict.update(result)
                if result_dict.get('gt_semantic_seg', None) is None:
                    result_dict.update(
                        load_seg_map_cityscape(
                            result_dict['seg_filename'],
                            reduce_zero_label=self.reduce_zero_label))
                if self.cache_on_the_fly:
                    self.samples_list[idx] = result_dict
            result_dict = self.post_process_fn(copy.deepcopy(result_dict))
            self._retry_count = 0
        except Exception as e:
            logging.warning(e)
            load_success = False

        if not load_success:
            logging.warning(
                'Something went wrong with sample %s, trying to load the next sample...'
                % result_dict.get('filename', ''))
            self._retry_count += 1
            if self._retry_count >= self._max_retry_num:
                raise ValueError('All samples failed to load!')

            result_dict = self[(idx + 1) % self.num_samples]

        return result_dict

    def get_source_iterator(self):

        self.img_files = [
            os.path.join(self.img_root, i)
            for i in io.listdir(self.img_root, recursive=True)
            if i.endswith(self.img_suffix[0])
        ]

        self.label_files = []
        # iterate over a copy so that removing unmatched images below is safe
        for img_path in list(self.img_files):
            self.img_root = os.path.join(self.img_root, '')
            img_name = img_path.replace(self.img_root,
                                        '')[:-len(self.img_suffix[0])]
            find_label_path = False
            for label_format in self.label_suffix:
                label_path = os.path.join(self.label_root,
                                          img_name + label_format)
                if io.exists(label_path):
                    find_label_path = True
                    self.label_files.append(label_path)
                    break
            if not find_label_path:
                logging.warning(
                    'Could not find label file %s for img %s, skipping the sample!'
                    % (label_path, img_path))
                self.img_files.remove(img_path)

        assert len(self.img_files) == len(self.label_files)
        assert len(
            self.img_files) > 0, 'No samples found in %s' % self.img_root

        return list(zip(self.img_files, self.label_files))
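
A minimal usage sketch for the new data source, assuming a local Cityscapes checkout laid out as in the stdc1 config above (the paths are placeholders) and that cityscapesscripts is installed; the keyword arguments mirror the data_source dict of stdc1_cityscape_8xb6_e1290.py, and the printed keys assume the default post-processing keeps 'img' and 'gt_semantic_seg':

from easycv.datasets.segmentation.data_sources import SegSourceCityscapes

# Placeholder paths matching the train split used in the config above.
source = SegSourceCityscapes(
    img_root='../Cityscapes/leftImg8bit/train/',
    label_root='../Cityscapes/gtFine/train/',
    classes=SegSourceCityscapes.CLASSES)

sample = source[0]
print(sample['img'].shape, sample['gt_semantic_seg'].shape)
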
1 change: 1 addition & 0 deletions easycv/models/backbones/__init__.py
@@ -22,6 +22,7 @@
from .resnet_jit import ResNetJIT
from .resnext import ResNeXt
from .shuffle_transformer import ShuffleTransformer
from .stdc import STDCContextPathNet, STDCNet
from .swin_transformer import SwinTransformer
from .swin_transformer3d import SwinTransformer3D
from .vision_transformer import VisionTransformer