diff --git a/configs/stdcseg/README.md b/configs/stdcseg/README.md new file mode 100644 index 0000000000..ed66b18f75 --- /dev/null +++ b/configs/stdcseg/README.md @@ -0,0 +1,14 @@ +# Rethinking BiSeNet For Real-time Semantic Segmentation + +## Reference + +> Fan, Mingyuan, et al. "Rethinking BiSeNet For Real-time Semantic Segmentation." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021. + + +## Performance + +### CityScapes + +| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links | +|---|---|---|---|---|---|---|---| +|STDC2-Seg50|STDC1446|1024x512|80000|74.62%|-|-|[backbone提取码:tss7](https://pan.baidu.com/s/16kh3aHTBBX6wfKiIG-y3yA) [model+log提取码:nchx](https://pan.baidu.com/s/1sFHqZWhcl8hFzGCrXu_c7Q) [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/index?id=30a6031fcc7cc09db93b4d33eb21724a) | diff --git a/configs/stdcseg/stdc2_seg_cityscapes_1024x512_80k.yml b/configs/stdcseg/stdc2_seg_cityscapes_1024x512_80k.yml new file mode 100644 index 0000000000..00c7d3189c --- /dev/null +++ b/configs/stdcseg/stdc2_seg_cityscapes_1024x512_80k.yml @@ -0,0 +1,60 @@ +_base_: '../_base_/cityscapes.yml' + +batch_size: 36 +iters: 80000 + +model: + type: STDCSeg + backbone: + type: STDC2 + pretrained: '/home/path/STDCNet1446_76.47.pdiparams' + num_classes: 19 + pretrained: null + +train_dataset: + type: Cityscapes + dataset_root: data/cityscapes + transforms: + - type: ResizeStepScaling + min_scale_factor: 0.125 + max_scale_factor: 1.5 + scale_step_size: 0.125 + - type: RandomPaddingCrop + crop_size: [1024, 512] + - type: RandomHorizontalFlip + - type: RandomDistort + brightness_range: 0.5 + contrast_range: 0.5 + saturation_range: 0.5 + - type: Normalize + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + mode: train + +val_dataset: + type: Cityscapes + dataset_root: data/cityscapes + transforms: + - type: Normalize + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + mode: val + +optimizer: + type: sgd + momentum: 0.9 + weight_decay: 4.0e-5 + +loss: + types: + - type: OhemCrossEntropyLoss + - type: OhemCrossEntropyLoss + - type: OhemCrossEntropyLoss + - type: DetailAggregateLoss + coef: [1, 1, 1, 1] + +lr_scheduler: + type: PolynomialDecay + learning_rate: 0.01 + end_lr: 0 + power: 0.9 \ No newline at end of file diff --git a/paddleseg/models/__init__.py b/paddleseg/models/__init__.py index 128a29a25d..1a5c4a58b2 100644 --- a/paddleseg/models/__init__.py +++ b/paddleseg/models/__init__.py @@ -40,4 +40,5 @@ from .ppseg_lite import * from .mla_transformer import MLATransformer from .portraitnet import PortraitNet +from .stdcseg import STDCSeg from .segformer import SegFormer diff --git a/paddleseg/models/backbones/__init__.py b/paddleseg/models/backbones/__init__.py index ead5a30d7c..78bd55ba23 100644 --- a/paddleseg/models/backbones/__init__.py +++ b/paddleseg/models/backbones/__init__.py @@ -20,3 +20,4 @@ from .swin_transformer import * from .mobilenetv2 import * from .mix_transformer import * +from .stdcnet import * \ No newline at end of file diff --git a/paddleseg/models/backbones/stdcnet.py b/paddleseg/models/backbones/stdcnet.py new file mode 100644 index 0000000000..80dbdc1b63 --- /dev/null +++ b/paddleseg/models/backbones/stdcnet.py @@ -0,0 +1,241 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import paddle +import paddle.nn as nn + +from paddleseg.utils import utils +from paddleseg.cvlibs import manager,param_init +from paddleseg.models.layers.layer_libs import SyncBatchNorm + +__all__ = ["STDC1", "STDC2"] + + +class STDCNet(nn.Layer): + """ + The STDCNet implementation based on PaddlePaddle. + + The original article refers to Meituan + Fan, Mingyuan, et al. "Rethinking BiSeNet For Real-time Semantic Segmentation." + (https://arxiv.org/abs/2104.13188) + + Args: + base(int, optional): base channels. Default: 64. + layers(list, optional): layers numbers list. It determines STDC block numbers of STDCNet's stage3\4\5. Defualt: [4, 5, 3]. + block_num(int,optional): block_num of features block. Default: 4. + type(str,optional): feature fusion method "cat"/"add". Default: "cat". + num_classes(int, optional): class number for image classification. Default: 1000. + dropout(float,optional): dropout ratio. if >0,use dropout ratio. Default: 0.20. + use_conv_last(bool,optional): whether to use the last ConvBNReLU layer . Default: False. + pretrained(str, optional): the path of pretrained model. + """ + + def __init__(self, base=64, + layers=[4, 5, 3], + block_num=4, + type="cat", + num_classes=1000, + dropout=0.20, + use_conv_last=False, + pretrained=None): + super(STDCNet, self).__init__() + if type == "cat": + block = CatBottleneck + elif type == "add": + block = AddBottleneck + self.use_conv_last = use_conv_last + self.features = self._make_layers(base, layers, block_num, block) + self.conv_last = ConvBNRelu(base * 16, max(1024, base * 16), 1, 1) + self.gap = nn.AdaptiveAvgPool2D(1) + self.fc = nn.Linear(max(1024, base * 16), max(1024, base * 16),bias_attr=None) + self.bn = nn.BatchNorm1D(max(1024, base * 16)) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(p=dropout) + self.linear = nn.Linear(max(1024, base * 16), num_classes, bias_attr=None) + + if(layers==[4,5,3]): #stdc1446 + self.x2 = nn.Sequential(self.features[:1]) + self.x4 = nn.Sequential(self.features[1:2]) + self.x8 = nn.Sequential(self.features[2:6]) + self.x16 = nn.Sequential(self.features[6:11]) + self.x32 = nn.Sequential(self.features[11:]) + elif(layers==[2,2,2]):#stdc813 + self.x2 = nn.Sequential(self.features[:1]) + self.x4 = nn.Sequential(self.features[1:2]) + self.x8 = nn.Sequential(self.features[2:4]) + self.x16 = nn.Sequential(self.features[4:6]) + self.x32 = nn.Sequential(self.features[6:]) + else: + raise NotImplementedError("model with layers:{} is not implemented!".format(layers)) + + self.pretrained = pretrained + self.init_weight() + + def forward(self, x): + """ + forward function for feature extract. + """ + feat2 = self.x2(x) + feat4 = self.x4(feat2) + feat8 = self.x8(feat4) + feat16 = self.x16(feat8) + feat32 = self.x32(feat16) + if self.use_conv_last: + feat32 = self.conv_last(feat32) + return feat2, feat4, feat8, feat16, feat32 + + def _make_layers(self, base, layers, block_num, block): + features = [] + features += [ConvBNRelu(3, base // 2, 3, 2)] + features += [ConvBNRelu(base // 2, base, 3, 2)] + + for i, layer in enumerate(layers): + for j in range(layer): + if i == 0 and j == 0: + features.append(block(base, base * 4, block_num, 2)) + elif j == 0: + features.append(block(base * int(math.pow(2, i + 1)), base * int(math.pow(2, i + 2)), block_num, 2)) + else: + features.append(block(base * int(math.pow(2, i + 2)), base * int(math.pow(2, i + 2)), block_num, 1)) + + return nn.Sequential(*features) + + def init_weight(self): + for layer in self.sublayers(): + if isinstance(layer, nn.Conv2D): + param_init.normal_init(layer.weight, std=0.001) + elif isinstance(layer, (nn.BatchNorm, nn.SyncBatchNorm)): + param_init.constant_init(layer.weight, value=1.0) + param_init.constant_init(layer.bias, value=0.0) + if self.pretrained is not None: + utils.load_pretrained_model(self, self.pretrained) + + +class ConvBNRelu(nn.Layer): + def __init__(self, in_planes, out_planes, kernel=3, stride=1): + super(ConvBNRelu, self).__init__() + self.conv = nn.Conv2D(in_planes, out_planes, kernel_size=kernel, stride=stride, padding=kernel // 2,bias_attr=None) + self.bn = SyncBatchNorm(out_planes,data_format='NCHW') + self.relu = nn.ReLU() + + def forward(self, x): + out = self.relu(self.bn(self.conv(x))) + return out + + +class AddBottleneck(nn.Layer): + def __init__(self, in_planes, out_planes, block_num=3, stride=1): + super(AddBottleneck, self).__init__() + assert block_num > 1, print("block number should be larger than 1.") + self.conv_list = nn.LayerList() + self.stride = stride + if stride == 2: + self.avd_layer = nn.Sequential( + nn.Conv2D(out_planes // 2, out_planes // 2, kernel_size=3, stride=2, padding=1, groups=out_planes // 2,bias_attr=None), + nn.BatchNorm2D(out_planes // 2), + ) + self.skip = nn.Sequential( + nn.Conv2D(in_planes, in_planes, kernel_size=3, stride=2, padding=1, groups=in_planes,bias_attr=None), + nn.BatchNorm2D(in_planes), + nn.Conv2D(in_planes, out_planes, kernel_size=1,bias_attr=None), + nn.BatchNorm2D(out_planes), + ) + stride = 1 + + for idx in range(block_num): + if idx == 0: + self.conv_list.append(ConvBNRelu(in_planes, out_planes // 2, kernel=1)) + elif idx == 1 and block_num == 2: + self.conv_list.append(ConvBNRelu(out_planes // 2, out_planes // 2, stride=stride)) + elif idx == 1 and block_num > 2: + self.conv_list.append(ConvBNRelu(out_planes // 2, out_planes // 4, stride=stride)) + elif idx < block_num - 1: + self.conv_list.append( + ConvBNRelu(out_planes // int(math.pow(2, idx)), out_planes // int(math.pow(2, idx + 1)))) + else: + self.conv_list.append(ConvBNRelu(out_planes // int(math.pow(2, idx)), out_planes // int(math.pow(2, idx)))) + + def forward(self, x): + out_list = [] + out = x + for idx, conv in enumerate(self.conv_list): + if idx == 0 and self.stride == 2: + out = self.avd_layer(conv(out)) + else: + out = conv(out) + out_list.append(out) + if self.stride == 2: + x = self.skip(x) + return paddle.concat(out_list, axis=1) + x + + +class CatBottleneck(nn.Layer): + def __init__(self, in_planes, out_planes, block_num=3, stride=1): + super(CatBottleneck, self).__init__() + assert block_num > 1, print("block number should be larger than 1.") + self.conv_list = nn.LayerList() + self.stride = stride + if stride == 2: + self.avd_layer = nn.Sequential( + nn.Conv2D(out_planes // 2, out_planes // 2, kernel_size=3, stride=2, padding=1, groups=out_planes // 2,bias_attr=None + ), + nn.BatchNorm2D(out_planes // 2), + ) + self.skip = nn.AvgPool2D(kernel_size=3, stride=2, padding=1) + stride = 1 + + for idx in range(block_num): + if idx == 0: + self.conv_list.append(ConvBNRelu(in_planes, out_planes // 2, kernel=1)) + elif idx == 1 and block_num == 2: + self.conv_list.append(ConvBNRelu(out_planes // 2, out_planes // 2, stride=stride)) + elif idx == 1 and block_num > 2: + self.conv_list.append(ConvBNRelu(out_planes // 2, out_planes // 4, stride=stride)) + elif idx < block_num - 1: + self.conv_list.append( + ConvBNRelu(out_planes // int(math.pow(2, idx)), out_planes // int(math.pow(2, idx + 1)))) + else: + self.conv_list.append(ConvBNRelu(out_planes // int(math.pow(2, idx)), out_planes // int(math.pow(2, idx)))) + + def forward(self, x): + out_list = [] + out1 = self.conv_list[0](x) + for idx, conv in enumerate(self.conv_list[1:]): + if idx == 0: + if self.stride == 2: + out = conv(self.avd_layer(out1)) + else: + out = conv(out1) + else: + out = conv(out) + out_list.append(out) + + if self.stride == 2: + out1 = self.skip(out1) + out_list.insert(0, out1) + out = paddle.concat(out_list, axis=1) + return out + + +@manager.BACKBONES.add_component +def STDC2(**kwargs): + model = STDCNet(base=64,layers=[4,5,3],**kwargs) + return model + +@manager.BACKBONES.add_component +def STDC1(**kwargs): + model = STDCNet(base=64,layers=[2,2,2],**kwargs) + return model \ No newline at end of file diff --git a/paddleseg/models/losses/__init__.py b/paddleseg/models/losses/__init__.py index 027241e1e4..f5a949cef8 100644 --- a/paddleseg/models/losses/__init__.py +++ b/paddleseg/models/losses/__init__.py @@ -29,3 +29,4 @@ from .focal_loss import FocalLoss from .kl_loss import KLLoss from .rmi_loss import RMILoss +from .detail_aggregate_loss import DetailAggregateLoss \ No newline at end of file diff --git a/paddleseg/models/losses/detail_aggregate_loss.py b/paddleseg/models/losses/detail_aggregate_loss.py new file mode 100644 index 0000000000..22a73f0a3f --- /dev/null +++ b/paddleseg/models/losses/detail_aggregate_loss.py @@ -0,0 +1,116 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddleseg.cvlibs import manager + + +@manager.LOSSES.add_component +class DetailAggregateLoss(nn.Layer): + """ + DetailAggregateLoss's implementation based on PaddlePaddle. + + The original article refers to Meituan + Fan, Mingyuan, et al. "Rethinking BiSeNet For Real-time Semantic Segmentation." + (https://arxiv.org/abs/2104.13188) + + Args: + ignore_index (int64, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. Default ``255``. + + """ + + def __init__(self, ignore_index=255): + super(DetailAggregateLoss, self).__init__() + self.ignore_index = ignore_index + self.laplacian_kernel = paddle.to_tensor([-1, -1, -1, -1, 8, -1, -1, -1, -1], dtype='float32').reshape( + (1, 1, 3, 3)) + self.fuse_kernel = paddle.create_parameter([1, 3, 1, 1], dtype='float32') + + def forward(self, logits, label): + """ + Args: + logits (Tensor): Logit tensor, the data type is float32, float64. Shape is + (N, C), where C is number of classes, and if shape is more than 2D, this + is (N, C, D1, D2,..., Dk), k >= 1. + label (Tensor): Label tensor, the data type is int64. Shape is (N), where each + value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is + (N, D1, D2,..., Dk), k >= 1. + Returns: loss + """ + boundary_targets = F.conv2d(paddle.unsqueeze(label, axis=1).astype('float32'), self.laplacian_kernel, + padding=1) + boundary_targets = paddle.clip(boundary_targets, min=0) + boundary_targets = boundary_targets > 0.1 + boundary_targets = boundary_targets.astype('float32') + + boundary_targets_x2 = F.conv2d(paddle.unsqueeze(label, axis=1).astype('float32'), self.laplacian_kernel, + stride=2, padding=1) + boundary_targets_x2 = paddle.clip(boundary_targets_x2, min=0) + boundary_targets_x4 = F.conv2d(paddle.unsqueeze(label, axis=1).astype('float32'), self.laplacian_kernel, + stride=4, padding=1) + boundary_targets_x4 = paddle.clip(boundary_targets_x4, min=0) + + boundary_targets_x8 = F.conv2d(paddle.unsqueeze(label, axis=1).astype('float32'), self.laplacian_kernel, + stride=8, padding=1) + boundary_targets_x8 = paddle.clip(boundary_targets_x8, min=0) + + boundary_targets_x8_up = F.interpolate(boundary_targets_x8, boundary_targets.shape[2:], mode='nearest') + boundary_targets_x4_up = F.interpolate(boundary_targets_x4, boundary_targets.shape[2:], mode='nearest') + boundary_targets_x2_up = F.interpolate(boundary_targets_x2, boundary_targets.shape[2:], mode='nearest') + + boundary_targets_x2_up = boundary_targets_x2_up > 0.1 + boundary_targets_x2_up = boundary_targets_x2_up.astype('float32') + + boundary_targets_x4_up = boundary_targets_x4_up > 0.1 + boundary_targets_x4_up = boundary_targets_x4_up.astype('float32') + + boundary_targets_x8_up = boundary_targets_x8_up > 0.1 + boundary_targets_x8_up = boundary_targets_x8_up.astype('float32') + + boudary_targets_pyramids = paddle.stack((boundary_targets, boundary_targets_x2_up, boundary_targets_x4_up), + axis=1) + + boudary_targets_pyramids = paddle.squeeze(boudary_targets_pyramids, axis=2) + boudary_targets_pyramid = F.conv2d(boudary_targets_pyramids, self.fuse_kernel) + + boudary_targets_pyramid = boudary_targets_pyramid > 0.1 + boudary_targets_pyramid = boudary_targets_pyramid.astype('float32') + + if logits.shape[-1] != boundary_targets.shape[-1]: + logits = F.interpolate( + logits, boundary_targets.shape[2:], mode='bilinear', align_corners=True) + + bce_loss = F.binary_cross_entropy_with_logits(logits, boudary_targets_pyramid) + dice_loss = self.fixed_dice_loss_func(F.sigmoid(logits), boudary_targets_pyramid) + detail_loss = bce_loss + dice_loss + + label.stop_gradient = True + return detail_loss + + def fixed_dice_loss_func(self, input, target): + """ + simplified diceloss for DetailAggregateLoss. + """ + smooth = 1. + n = input.shape[0] + iflat = paddle.reshape(input, [n, -1]) + tflat = paddle.reshape(target, [n, -1]) + intersection = paddle.sum((iflat * tflat), axis=1) + loss = 1 - ((2. * intersection + smooth) / + (paddle.sum(iflat, axis=1) + paddle.sum(tflat, axis=1) + smooth)) + return paddle.mean(loss) \ No newline at end of file diff --git a/paddleseg/models/stdcseg.py b/paddleseg/models/stdcseg.py new file mode 100644 index 0000000000..8860beea0b --- /dev/null +++ b/paddleseg/models/stdcseg.py @@ -0,0 +1,193 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddleseg import utils +from paddleseg.models import layers +from paddleseg.cvlibs import manager +from paddleseg.utils import utils + + +@manager.MODELS.add_component +class STDCSeg(nn.Layer): + """ + The STDCSeg implementation based on PaddlePaddle. + + The original article refers to Meituan + Fan, Mingyuan, et al. "Rethinking BiSeNet For Real-time Semantic Segmentation." + (https://arxiv.org/abs/2104.13188) + + Args: + num_classes(int,optional): The unique number of target classes. + backbone(nn.Layer): Backbone network, STDCNet1446/STDCNet813. STDCNet1446->STDC2,STDCNet813->STDC813. + use_boundary_8(bool,non-optional): Whether to use detail loss. it should be True accroding to paper for best metric. Default: True. + Actually,if you want to use _boundary_2/_boundary_4/_boundary_16,you should append loss function number of DetailAggregateLoss.It should work properly. + use_conv_last(bool,optional): Determine ContextPath 's inplanes variable according to whether to use bockbone's last conv. Default: False. + pretrained (str, optional): The path or url of pretrained model. Default: None. + """ + def __init__(self, num_classes, backbone, use_boundary_2=False, use_boundary_4=False, + use_boundary_8=True, use_boundary_16=False, use_conv_last=False, pretrained=None): + super(STDCSeg, self).__init__() + + self.use_boundary_2 = use_boundary_2 + self.use_boundary_4 = use_boundary_4 + self.use_boundary_8 = use_boundary_8 + self.use_boundary_16 = use_boundary_16 + self.cp = ContextPath(backbone, use_conv_last=use_conv_last) + self.ffm = FeatureFusionModule(384, 256) + self.conv_out = BiSeNetOutput(256, 256, num_classes) + self.conv_out16 = BiSeNetOutput(128, 64, num_classes) + self.conv_out32 = BiSeNetOutput(128, 64, num_classes) + self.conv_out_sp16 = BiSeNetOutput(512, 64, 1) + self.conv_out_sp8 = BiSeNetOutput(256, 64, 1) + self.conv_out_sp4 = BiSeNetOutput(64, 64, 1) + self.conv_out_sp2 = BiSeNetOutput(32, 64, 1) + self.pretrained = pretrained + self.init_weight() + + def forward(self, x): + H, W = x.shape[2:] + feat_res2, feat_res4, feat_res8, feat_res16, feat_cp8, feat_cp16 = self.cp(x) + + if self.training: + feat_out_sp2 = self.conv_out_sp2(feat_res2) + feat_out_sp4 = self.conv_out_sp4(feat_res4) + feat_out_sp8 = self.conv_out_sp8(feat_res8) + feat_out_sp16 = self.conv_out_sp16(feat_res16) + feat_fuse = self.ffm(feat_res8, feat_cp8) + feat_out = self.conv_out(feat_fuse) + feat_out16 = self.conv_out16(feat_cp8) + feat_out32 = self.conv_out32(feat_cp16) + feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) + feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True) + feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True) + + if self.use_boundary_2 and self.use_boundary_4 and self.use_boundary_8: + return feat_out, feat_out16, feat_out32, feat_out_sp2, feat_out_sp4, feat_out_sp8 + + if (not self.use_boundary_2) and self.use_boundary_4 and self.use_boundary_8: + return feat_out, feat_out16, feat_out32, feat_out_sp4, feat_out_sp8 + + if (not self.use_boundary_2) and (not self.use_boundary_4) and self.use_boundary_8: + return feat_out, feat_out16, feat_out32, feat_out_sp8 + + if (not self.use_boundary_2) and (not self.use_boundary_4) and (not self.use_boundary_8): + return feat_out, feat_out16, feat_out32 + else: + feat_fuse = self.ffm(feat_res8, feat_cp8) + feat_out = self.conv_out(feat_fuse) + feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True) + return [feat_out] + + def init_weight(self): + if self.pretrained is not None: + utils.load_entire_model(self, self.pretrained) + + +class BiSeNetOutput(nn.Layer): + def __init__(self, in_chan, mid_chan, n_classes): + super(BiSeNetOutput, self).__init__() + self.conv = layers.ConvBNReLU(in_chan, mid_chan, kernel_size=3, stride=1, padding=1) + self.conv_out = nn.Conv2D(mid_chan, n_classes, kernel_size=1, bias_attr=None) + + def forward(self, x): + x = self.conv(x) + x = self.conv_out(x) + return x + + +class AttentionRefinementModule(nn.Layer): + def __init__(self, in_chan, out_chan): + super(AttentionRefinementModule, self).__init__() + self.conv = layers.ConvBNReLU(in_chan, out_chan, kernel_size=3, stride=1, padding=1) + self.conv_atten = nn.Conv2D(out_chan, out_chan, kernel_size=1, bias_attr=None) + self.bn_atten = nn.BatchNorm2D(out_chan) + self.sigmoid_atten = nn.Sigmoid() + + def forward(self, x): + feat = self.conv(x) + atten = F.avg_pool2d(feat, feat.shape[2:]) + atten = self.conv_atten(atten) + atten = self.bn_atten(atten) + atten = self.sigmoid_atten(atten) + out = paddle.multiply(feat,atten) + return out + + +class ContextPath(nn.Layer): + def __init__(self, backbone, use_conv_last=False): + super(ContextPath, self).__init__() + self.backbone = backbone + self.arm16 = AttentionRefinementModule(512, 128) + inplanes = 1024 + if use_conv_last: + inplanes = 1024 + self.arm32 = AttentionRefinementModule(inplanes, 128) + self.conv_head32 = layers.ConvBNReLU(128, 128, kernel_size=3, stride=1, padding=1) + self.conv_head16 = layers.ConvBNReLU(128, 128, kernel_size=3, stride=1, padding=1) + self.conv_avg = layers.ConvBNReLU(inplanes, 128, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + feat2, feat4, feat8, feat16, feat32 = self.backbone(x) + H8, W8 = feat8.shape[2:] + H16, W16 = feat16.shape[2:] + H32, W32 = feat32.shape[2:] + avg = F.avg_pool2d(feat32, feat32.shape[2:]) + avg = self.conv_avg(avg) + avg_up = F.interpolate(avg, (H32, W32), mode='nearest') + feat32_arm = self.arm32(feat32) + feat32_sum = feat32_arm + avg_up + feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') + feat32_up = self.conv_head32(feat32_up) + feat16_arm = self.arm16(feat16) + feat16_sum = feat16_arm + feat32_up + feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') + feat16_up = self.conv_head16(feat16_up) + return feat2, feat4, feat8, feat16, feat16_up, feat32_up # x8, x16 + + +class FeatureFusionModule(nn.Layer): + def __init__(self, in_chan, out_chan): + super(FeatureFusionModule, self).__init__() + self.convblk = layers.ConvBNReLU(in_chan, out_chan, kernel_size=1, stride=1, padding=0) + self.conv1 = nn.Conv2D(out_chan, + out_chan // 4, + kernel_size=1, + stride=1, + padding=0, + bias_attr=None) + self.conv2 = nn.Conv2D(out_chan // 4, + out_chan, + kernel_size=1, + stride=1, + padding=0, + bias_attr=None) + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def forward(self, fsp, fcp): + fcat = paddle.concat([fsp, fcp], axis=1) + feat = self.convblk(fcat) + atten = F.avg_pool2d(feat, feat.shape[2:]) + atten = self.conv1(atten) + atten = self.relu(atten) + atten = self.conv2(atten) + atten = self.sigmoid(atten) + feat_atten = paddle.multiply(feat, atten) + feat_out = feat_atten + feat + return feat_out +