-
Notifications
You must be signed in to change notification settings - Fork 206
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add stdc semantic segmentation algorithm
- Loading branch information
Showing
24 changed files
with
1,350 additions
and
8 deletions.
There are no files selected for viewing
198 changes: 198 additions & 0 deletions
198
configs/segmentation/stdc/stdc1_cityscape_8xb6_e1290.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
_base_ = ['configs/base.py'] | ||
|
||
# warning batch_size need >= 2 | ||
# model | ||
norm_cfg = dict(type='BN', requires_grad=True) | ||
model = dict( | ||
type='EncoderDecoder', | ||
backbone=dict( | ||
type='STDCContextPathNet', | ||
backbone_cfg=dict( | ||
type='STDCNet', | ||
stdc_type='STDCNet1', | ||
in_channels=3, | ||
channels=(32, 64, 256, 512, 1024), | ||
bottleneck_type='cat', | ||
num_convs=4, | ||
norm_cfg=norm_cfg, | ||
act_cfg=dict(type='ReLU'), | ||
with_final_conv=False), | ||
last_in_channels=(1024, 512), | ||
out_channels=128, | ||
ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4)), | ||
decode_head=dict( | ||
type='FCNHead', | ||
in_channels=256, | ||
channels=256, | ||
num_convs=1, | ||
num_classes=19, | ||
in_index=3, | ||
concat_input=False, | ||
dropout_ratio=0.1, | ||
norm_cfg=norm_cfg, | ||
align_corners=True, | ||
sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000), | ||
loss_decode=dict( | ||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), | ||
auxiliary_head=[ | ||
dict( | ||
type='FCNHead', | ||
in_channels=128, | ||
channels=64, | ||
num_convs=1, | ||
num_classes=19, | ||
in_index=2, | ||
norm_cfg=norm_cfg, | ||
concat_input=False, | ||
align_corners=False, | ||
sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000), | ||
loss_decode=dict( | ||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), | ||
dict( | ||
type='FCNHead', | ||
in_channels=128, | ||
channels=64, | ||
num_convs=1, | ||
num_classes=19, | ||
in_index=1, | ||
norm_cfg=norm_cfg, | ||
concat_input=False, | ||
align_corners=False, | ||
sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000), | ||
loss_decode=dict( | ||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), | ||
dict( | ||
type='STDCHead', | ||
in_channels=256, | ||
channels=64, | ||
num_convs=1, | ||
num_classes=2, | ||
boundary_threshold=0.1, | ||
in_index=0, | ||
norm_cfg=norm_cfg, | ||
concat_input=False, | ||
align_corners=True, | ||
loss_decode=[ | ||
dict( | ||
type='CrossEntropyLoss', | ||
loss_name='loss_ce', | ||
use_sigmoid=True, | ||
loss_weight=1.0), | ||
dict(type='DiceLoss', loss_name='loss_dice', loss_weight=1.0) | ||
]), | ||
], | ||
train_cfg=dict(), | ||
test_cfg=dict(mode='whole'), | ||
pretrained= | ||
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/stdc/pretrain/stdc1_easycv.pth' | ||
) | ||
|
||
# dataset | ||
CLASSES = [ | ||
'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', | ||
'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', | ||
'truck', 'bus', 'train', 'motorcycle', 'bicycle' | ||
] | ||
img_norm_cfg = dict( | ||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) | ||
crop_size = (512, 1024) | ||
|
||
train_pipeline = [ | ||
dict(type='MMResize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), | ||
dict(type='SegRandomCrop', crop_size=crop_size, cat_max_ratio=0.75), | ||
dict(type='MMRandomFlip', flip_ratio=0.5), | ||
dict(type='MMPhotoMetricDistortion'), | ||
dict(type='MMNormalize', **img_norm_cfg), | ||
dict(type='MMPad', size=crop_size), | ||
dict(type='DefaultFormatBundle'), | ||
dict( | ||
type='Collect', | ||
keys=['img', 'gt_semantic_seg'], | ||
meta_keys=('filename', 'ori_filename', 'ori_shape', 'img_shape', | ||
'pad_shape', 'scale_factor', 'flip', 'flip_direction', | ||
'img_norm_cfg')), | ||
] | ||
|
||
test_pipeline = [ | ||
dict( | ||
type='MMMultiScaleFlipAug', | ||
img_scale=(2048, 1024), | ||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], | ||
flip=False, | ||
transforms=[ | ||
dict(type='MMResize', keep_ratio=True), | ||
dict(type='MMRandomFlip'), | ||
dict(type='MMNormalize', **img_norm_cfg), | ||
dict(type='ImageToTensor', keys=['img']), | ||
dict( | ||
type='Collect', | ||
keys=['img'], | ||
meta_keys=('filename', 'ori_filename', 'ori_shape', | ||
'img_shape', 'pad_shape', 'scale_factor', 'flip', | ||
'flip_direction', 'img_norm_cfg')), | ||
]) | ||
] | ||
dataset_type = 'SegDataset' | ||
data_root = '../Cityscapes/' | ||
|
||
train_img_root = data_root + 'leftImg8bit/train/' | ||
train_label_root = data_root + 'gtFine/train/' | ||
|
||
val_img_root = data_root + 'leftImg8bit/val/' | ||
val_label_root = data_root + 'gtFine/val/' | ||
data = dict( | ||
imgs_per_gpu=6, | ||
workers_per_gpu=4, | ||
persistent_workers=True, | ||
train=dict( | ||
type=dataset_type, | ||
ignore_index=255, | ||
data_source=dict( | ||
type='SegSourceCityscapes', | ||
img_root=train_img_root, | ||
label_root=train_label_root, | ||
classes=CLASSES), | ||
pipeline=train_pipeline), | ||
val=dict( | ||
imgs_per_gpu=1, | ||
ignore_index=255, | ||
type=dataset_type, | ||
data_source=dict( | ||
type='SegSourceCityscapes', | ||
img_root=val_img_root, | ||
label_root=val_label_root, | ||
classes=CLASSES), | ||
pipeline=test_pipeline)) | ||
|
||
# optimizer | ||
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) | ||
optimizer_config = dict() | ||
# learning policy | ||
lr_config = dict( | ||
policy='CosineAnnealing', | ||
min_lr=1e-4, | ||
warmup='linear', | ||
warmup_iters=10, | ||
warmup_ratio=0.0001, | ||
warmup_by_epoch=True, | ||
by_epoch=False) | ||
|
||
# runtime settings | ||
total_epochs = 1290 | ||
checkpoint_config = dict(interval=10) | ||
eval_config = dict(interval=10, gpu_collect=False) | ||
eval_pipelines = [ | ||
dict( | ||
mode='test', | ||
evaluators=[ | ||
dict( | ||
type='SegmentationEvaluator', | ||
classes=CLASSES, | ||
metric_names=['mIoU']) | ||
], | ||
) | ||
] | ||
|
||
# export config | ||
export = dict(export_neck=True) | ||
checkpoint_sync_export = True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
_base_ = ['configs/segmentation/stdc/stdc1_cityscape_8xb6_e1290.py'] | ||
|
||
model = dict( | ||
backbone=dict(backbone_cfg=dict(stdc_type='STDCNet2')), | ||
pretrained= | ||
'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/EasyCV/modelzoo/segmentation/stdc/pretrain/stdc2_easycv.pth' | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
129 changes: 129 additions & 0 deletions
129
easycv/datasets/segmentation/data_sources/cityscapes.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# Copyright (c) Alibaba, Inc. and its affiliates. | ||
import copy | ||
import logging | ||
import os | ||
import subprocess | ||
|
||
import numpy as np | ||
|
||
from easycv.datasets.registry import DATASOURCES | ||
from easycv.file import io | ||
from easycv.file.image import load_image as _load_img | ||
from .raw import SegSourceRaw | ||
|
||
try: | ||
import cityscapesscripts.helpers.labels as CSLabels | ||
except ModuleNotFoundError as e: | ||
res = subprocess.call('pip install cityscapesscripts', shell=True) | ||
if res != 0: | ||
info_string = ( | ||
'\n\nAuto install failed! Please install cityscapesscripts with the following commands :\n' | ||
'\t`pip install cityscapesscripts`\n') | ||
raise ModuleNotFoundError(info_string) | ||
|
||
|
||
def load_seg_map_cityscape(seg_path, reduce_zero_label): | ||
gt_semantic_seg = _load_img(seg_path, mode='P') | ||
gt_semantic_seg_copy = gt_semantic_seg.copy() | ||
for labels in CSLabels.labels: | ||
gt_semantic_seg_copy[gt_semantic_seg == labels.id] = labels.trainId | ||
|
||
return {'gt_semantic_seg': gt_semantic_seg_copy} | ||
|
||
|
||
@DATASOURCES.register_module | ||
class SegSourceCityscapes(SegSourceRaw): | ||
"""Cityscapes datasource | ||
""" | ||
CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', | ||
'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', | ||
'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', | ||
'bicycle') | ||
|
||
PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], | ||
[190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], | ||
[107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], | ||
[255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], | ||
[0, 80, 100], [0, 0, 230], [119, 11, 32]] | ||
|
||
def __init__(self, | ||
img_suffix='_leftImg8bit.png', | ||
label_suffix='_gtFine_labelIds.png', | ||
**kwargs): | ||
super(SegSourceCityscapes, self).__init__( | ||
img_suffix=img_suffix, label_suffix=label_suffix, **kwargs) | ||
|
||
def __getitem__(self, idx): | ||
result_dict = self.samples_list[idx] | ||
load_success = True | ||
try: | ||
# avoid data cache from taking up too much memory | ||
if not self.cache_at_init and not self.cache_on_the_fly: | ||
result_dict = copy.deepcopy(result_dict) | ||
|
||
if not self.cache_at_init: | ||
if result_dict.get('img', None) is None: | ||
img = _load_img(result_dict['filename'], mode='BGR') | ||
result = { | ||
'img': img.astype(np.float32), | ||
'img_shape': img.shape, # h, w, c | ||
'ori_shape': img.shape, | ||
} | ||
result_dict.update(result) | ||
if result_dict.get('gt_semantic_seg', None) is None: | ||
result_dict.update( | ||
load_seg_map_cityscape( | ||
result_dict['seg_filename'], | ||
reduce_zero_label=self.reduce_zero_label)) | ||
if self.cache_on_the_fly: | ||
self.samples_list[idx] = result_dict | ||
result_dict = self.post_process_fn(copy.deepcopy(result_dict)) | ||
self._retry_count = 0 | ||
except Exception as e: | ||
logging.warning(e) | ||
load_success = False | ||
|
||
if not load_success: | ||
logging.warning( | ||
'Something wrong with current sample %s,Try load next sample...' | ||
% result_dict.get('filename', '')) | ||
self._retry_count += 1 | ||
if self._retry_count >= self._max_retry_num: | ||
raise ValueError('All samples failed to load!') | ||
|
||
result_dict = self[(idx + 1) % self.num_samples] | ||
|
||
return result_dict | ||
|
||
def get_source_iterator(self): | ||
|
||
self.img_files = [ | ||
os.path.join(self.img_root, i) | ||
for i in io.listdir(self.img_root, recursive=True) | ||
if i.endswith(self.img_suffix[0]) | ||
] | ||
|
||
self.label_files = [] | ||
for img_path in self.img_files: | ||
self.img_root = os.path.join(self.img_root, '') | ||
img_name = img_path.replace(self.img_root, | ||
'')[:-len(self.img_suffix[0])] | ||
find_label_path = False | ||
for label_format in self.label_suffix: | ||
lable_path = os.path.join(self.label_root, | ||
img_name + label_format) | ||
if io.exists(lable_path): | ||
find_label_path = True | ||
self.label_files.append(lable_path) | ||
break | ||
if not find_label_path: | ||
logging.warning( | ||
'Not find label file %s for img: %s, skip the sample!' % | ||
(lable_path, img_path)) | ||
self.img_files.remove(img_path) | ||
|
||
assert len(self.img_files) == len(self.label_files) | ||
assert len( | ||
self.img_files) > 0, 'No samples found in %s' % self.img_root | ||
|
||
return list(zip(self.img_files, self.label_files)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.