diff --git a/README.md b/README.md index 6786613ce4..9e08c8611e 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,7 @@ python train.py --config configs/quick_start/bisenet_optic_disc_512x512_1k.yml * [API Tutorial](https://aistudio.baidu.com/aistudio/projectdetail/1339458) * [Data Preparation](./docs/data_prepare.md) * [Training Configuration](./configs/) +* [Loss Usage](./docs/loss_usage.md) * [API References](./docs/apis) * [Add New Components](./docs/add_new_model.md) @@ -111,7 +112,7 @@ If you find our project useful in your research, please consider citing: ```latex @misc{liu2021paddleseg, - title={PaddleSeg: A High-Efficient Development Toolkit for Image Segmentation}, + title={PaddleSeg: A High-Efficient Development Toolkit for Image Segmentation}, author={Yi Liu and Lutao Chu and Guowei Chen and Zewu Wu and Zeyu Chen and Baohua Lai and Yuying Hao}, year={2021}, eprint={2101.06175}, diff --git a/README_CN.md b/README_CN.md index 949ddbf503..0b6bdeefff 100644 --- a/README_CN.md +++ b/README_CN.md @@ -95,6 +95,7 @@ python train.py --config configs/quick_start/bisenet_optic_disc_512x512_1k.yml * [API使用教程](https://aistudio.baidu.com/aistudio/projectdetail/1339458) * [数据集准备](./docs/data_prepare.md) * [配置项](./configs/) +* [Loss使用](./docs/loss_usage.md) * [API参考](./docs/apis) * [添加新组件](./docs/add_new_model.md) @@ -114,7 +115,7 @@ python train.py --config configs/quick_start/bisenet_optic_disc_512x512_1k.yml ```latex @misc{liu2021paddleseg, - title={PaddleSeg: A High-Efficient Development Toolkit for Image Segmentation}, + title={PaddleSeg: A High-Efficient Development Toolkit for Image Segmentation}, author={Yi Liu and Lutao Chu and Guowei Chen and Zewu Wu and Zeyu Chen and Baohua Lai and Yuying Hao}, year={2021}, eprint={2101.06175}, diff --git a/configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_160k_lovasz_softmax.yml b/configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_160k_lovasz_softmax.yml new file mode 100644 index 0000000000..8bbd59e6ac --- /dev/null +++ b/configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_160k_lovasz_softmax.yml @@ -0,0 +1,35 @@ +_base_: '../_base_/cityscapes.yml' + +batch_size: 2 +iters: 160000 + +model: + type: OCRNet + backbone: + type: HRNet_W18 + pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz + backbone_indices: [0] + +optimizer: + type: sgd + +learning_rate: + value: 0.01 + decay: + type: poly + power: 0.9 + + +loss: + types: + - type: MixedLoss + losses: + - type: CrossEntropyLoss + - type: LovaszSoftmaxLoss + coef: [0.8, 0.2] + - type: MixedLoss + losses: + - type: CrossEntropyLoss + - type: LovaszSoftmaxLoss + coef: [0.8, 0.2] + coef: [1, 0.4] diff --git a/configs/ocrnet/ocrnet_hrnetw18_road_extraction_768x768_15k.yml b/configs/ocrnet/ocrnet_hrnetw18_road_extraction_768x768_15k.yml new file mode 100644 index 0000000000..d1366a43c1 --- /dev/null +++ b/configs/ocrnet/ocrnet_hrnetw18_road_extraction_768x768_15k.yml @@ -0,0 +1,48 @@ + + +batch_size: 4 +iters: 15000 + +train_dataset: + type: MiniDeepGlobeRoadExtraction + dataset_root: data/MiniDeepGlobeRoadExtraction + transforms: + - type: ResizeStepScaling + min_scale_factor: 0.5 + max_scale_factor: 2.0 + scale_step_size: 0.25 + - type: RandomPaddingCrop + crop_size: [768, 768] + - type: RandomHorizontalFlip + - type: Normalize + mode: train + +val_dataset: + type: MiniDeepGlobeRoadExtraction + dataset_root: data/MiniDeepGlobeRoadExtraction + transforms: + - type: Normalize + mode: val + +model: + type: OCRNet + backbone: + type: HRNet_W18 + pretrained: 
https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
+  backbone_indices: [0]
+
+optimizer:
+  type: sgd
+
+learning_rate:
+  value: 0.01
+  decay:
+    type: poly
+    power: 0.9
+
+
+loss:
+  types:
+    - type: CrossEntropyLoss
+    - type: CrossEntropyLoss
+  coef: [1, 0.4]
diff --git a/configs/ocrnet/ocrnet_hrnetw18_road_extraction_768x768_15k_lovasz_hinge.yml b/configs/ocrnet/ocrnet_hrnetw18_road_extraction_768x768_15k_lovasz_hinge.yml
new file mode 100644
index 0000000000..72f23e993a
--- /dev/null
+++ b/configs/ocrnet/ocrnet_hrnetw18_road_extraction_768x768_15k_lovasz_hinge.yml
@@ -0,0 +1,55 @@
+
+
+batch_size: 4
+iters: 15000
+
+train_dataset:
+  type: MiniDeepGlobeRoadExtraction
+  dataset_root: data/MiniDeepGlobeRoadExtraction
+  transforms:
+    - type: ResizeStepScaling
+      min_scale_factor: 0.5
+      max_scale_factor: 2.0
+      scale_step_size: 0.25
+    - type: RandomPaddingCrop
+      crop_size: [768, 768]
+    - type: RandomHorizontalFlip
+    - type: Normalize
+  mode: train
+
+val_dataset:
+  type: MiniDeepGlobeRoadExtraction
+  dataset_root: data/MiniDeepGlobeRoadExtraction
+  transforms:
+    - type: Normalize
+  mode: val
+
+model:
+  type: OCRNet
+  backbone:
+    type: HRNet_W18
+    pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
+  backbone_indices: [0]
+
+optimizer:
+  type: sgd
+
+learning_rate:
+  value: 0.01
+  decay:
+    type: poly
+    power: 0.9
+
+loss:
+  types:
+    - type: MixedLoss
+      losses:
+        - type: CrossEntropyLoss
+        - type: LovaszHingeLoss
+      coef: [1, 0.01]
+    - type: MixedLoss
+      losses:
+        - type: CrossEntropyLoss
+        - type: LovaszHingeLoss
+      coef: [1, 0.01]
+  coef: [1, 0.4]
diff --git a/docs/apis/README.md b/docs/apis/README.md
index e80296af67..ec797bdf8d 100644
--- a/docs/apis/README.md
+++ b/docs/apis/README.md
@@ -5,4 +5,5 @@
 * [paddleseg.datasets](./datasets.md)
 * [paddleseg.models](./models.md)
 * [paddleseg.models.backbones](./backbones.md)
+* [paddleseg.models.losses](./losses.md)
 * [paddleseg.transforms](./transforms.md)
diff --git a/docs/apis/losses.md b/docs/apis/losses.md
new file mode 100644
index 0000000000..7f799a8384
--- /dev/null
+++ b/docs/apis/losses.md
@@ -0,0 +1,33 @@
+# [paddleseg.models.losses](../../paddleseg/models/losses)
+
+## LovaszSoftmaxLoss
+> CLASS paddleseg.models.losses.LovaszSoftmaxLoss(ignore_index=255, classes='present')
+
+    Multi-class Lovasz-Softmax loss.
+
+> > Args
+> > > - **ignore_index** (int64): Specifies a target value that is ignored and does not contribute to the input gradient. Default ``255``.
+> > > - **classes** (str|list): 'all' for all, 'present' for classes present in labels, or a list of classes to average.
+
+
+## LovaszHingeLoss
+> CLASS paddleseg.models.losses.LovaszHingeLoss(ignore_index=255)
+
+    Binary Lovasz hinge loss.
+
+> > Args
+> > > - **ignore_index** (int64): Specifies a target value that is ignored and does not contribute to the input gradient. Default ``255``.
+
+
+## MixedLoss
+> CLASS paddleseg.models.losses.MixedLoss(losses, coef)
+
+    Computes a weighted sum of multiple losses.
+    This allows mixed-loss training to be configured without changing the model code.
+
+> > Args
+> > > - **losses** (list of nn.Layer): A list of loss instances, e.g. ``CrossEntropyLoss()`` and ``LovaszSoftmaxLoss()``.
+> > > - **coef** (list): Weighting coefficients for the corresponding losses; it must have the same length as ``losses``.
+
+> > Returns
+> > > - A callable object of MixedLoss.
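+
+The following is a minimal usage sketch of ``MixedLoss`` together with the Lovasz losses documented above. It assumes PaddlePaddle and a PaddleSeg build that contains these loss classes; the tensor shapes and random values are placeholders for illustration only.
+
+```python
+import paddle
+from paddleseg.models.losses import CrossEntropyLoss, LovaszSoftmaxLoss, MixedLoss
+
+# 0.8 * cross entropy + 0.2 * lovasz softmax, as in the OCRNet configs of this change.
+mixed_loss = MixedLoss(
+    losses=[CrossEntropyLoss(), LovaszSoftmaxLoss()],
+    coef=[0.8, 0.2])
+
+# Dummy prediction and label tensors: 2 images, 19 classes, 64x64 pixels.
+logits = paddle.rand([2, 19, 64, 64])        # [N, C, H, W]
+labels = paddle.randint(0, 19, [2, 64, 64])  # [N, H, W], int64 class ids
+
+loss = mixed_loss(logits, labels)            # scalar Tensor: weighted sum of the two losses
+```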
diff --git a/docs/images/Lovasz_Hinge_Evaluate_mIoU.png b/docs/images/Lovasz_Hinge_Evaluate_mIoU.png
new file mode 100644
index 0000000000..62de580163
Binary files /dev/null and b/docs/images/Lovasz_Hinge_Evaluate_mIoU.png differ
diff --git a/docs/images/Lovasz_Softmax_Evaluate_mIoU.png b/docs/images/Lovasz_Softmax_Evaluate_mIoU.png
new file mode 100644
index 0000000000..8e55dd3856
Binary files /dev/null and b/docs/images/Lovasz_Softmax_Evaluate_mIoU.png differ
diff --git a/docs/images/deepglobe.png b/docs/images/deepglobe.png
new file mode 100644
index 0000000000..0ffc311b5f
Binary files /dev/null and b/docs/images/deepglobe.png differ
diff --git a/docs/loss_usage.md b/docs/loss_usage.md
new file mode 100644
index 0000000000..b091133324
--- /dev/null
+++ b/docs/loss_usage.md
@@ -0,0 +1,4 @@
+# Loss usage
+
+- [Lovasz loss](lovasz_loss.md)
+- To be continued
diff --git a/docs/lovasz_loss.md b/docs/lovasz_loss.md
new file mode 100644
index 0000000000..edf15afdfc
--- /dev/null
+++ b/docs/lovasz_loss.md
@@ -0,0 +1,125 @@
+# Lovasz loss
+Image segmentation tasks often have to deal with imbalanced class distributions, for example in industrial defect detection, road extraction and lesion segmentation. Lovasz loss can be used to mitigate this problem.
+
+Lovasz loss directly optimizes the mean IoU of a neural network, based on the convex Lovasz extension of submodular losses. Depending on the number of target classes it comes in two variants: lovasz hinge loss for binary segmentation and lovasz softmax loss for multi-class segmentation. The method was published at CVPR 2018; see the [reference](#reference) for the underlying theory.
+
+
+## Lovasz loss usage guide
+This section describes how to train with lovasz loss. Note that training with lovasz loss alone does not always work well, so we recommend one of the following two schemes:
+- (1) Use it in a weighted combination with cross entropy loss or bce loss (binary cross-entropy loss).
+- (2) Train with cross entropy loss or bce loss first, then finetune with lovasz softmax loss or lovasz hinge loss.
+
+Taking scheme (1) as an example, the `MixedLoss` class selects the loss functions used during training, and the `coef` parameter sets the weight of each loss, which makes tuning flexible. For example:
+
+```yaml
+loss:
+  types:
+    - type: MixedLoss
+      losses:
+        - type: CrossEntropyLoss
+        - type: LovaszSoftmaxLoss
+      coef: [0.8, 0.2]
+```
+
+```yaml
+loss:
+  types:
+    - type: MixedLoss
+      losses:
+        - type: CrossEntropyLoss
+        - type: LovaszHingeLoss
+      coef: [1, 0.02]
+```
+
+
+## Lovasz softmax loss experiments
+
+We apply lovasz softmax loss to the classic [Cityscapes](https://www.cityscapes-dataset.com/) dataset. Cityscapes has 19 classes that are far from balanced: classes such as `road` and `building` are very common, while `fence`, `motorcycle` and `wall` are rare. We compare lovasz softmax loss against plain cross entropy (softmax) loss, using the OCRNet model with an HRNet-W18 backbone.
+
+
+* Data preparation
+
+See the [data preparation tutorial](data_prepare.md).
+
+* Training with lovasz loss
+```shell
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -u -m paddle.distributed.launch train.py \
+--config configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_160k_lovasz_softmax.yml \
+--use_vdl --num_workers 3 --do_eval
+```
+
+* Training with cross entropy loss
+```shell
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -u -m paddle.distributed.launch train.py \
+--config configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_160k.yml \
+--use_vdl --num_workers 3 --do_eval
+```
+
+* Results
+
+The mIoU curves of the two runs are shown below.
+
+![mIoU curves of lovasz softmax loss + cross entropy loss vs. cross entropy loss on Cityscapes](./images/Lovasz_Softmax_Evaluate_mIoU.png)
+
+|Loss|best mIoU|
+|-|-|
+|cross entropy loss|80.46%|
+|lovasz softmax loss + cross entropy loss|81.53%|
+
+In the figure, the blue curve is lovasz softmax loss + cross entropy loss and the green curve is cross entropy loss; the former is about 1 percentage point higher.
+
+With lovasz softmax loss, the accuracy curve stays above the cross entropy baseline for most of training.
+
+## Lovasz hinge loss experiments
+
+We apply lovasz hinge loss to a road extraction task.
+It is compared against cross entropy loss on the MiniDeepGlobeRoadExtraction dataset.
+The dataset comes from the Road Extraction track of the [DeepGlobe CVPR2018 challenge](http://deepglobe.org/). Road pixels make up only about 4.5% of the training data, so this is a typical class-imbalance scenario; a short script for spot-checking this ratio is given after the sample image below.
+
+![Sample image and road mask from the DeepGlobe road extraction dataset](./images/deepglobe.png)
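+
+The 4.5% road-pixel ratio quoted above can be spot-checked from the ground-truth masks with a short script along the following lines. This is only an illustrative sketch: the glob pattern and the 0/1 mask encoding are assumptions, so adjust them to the actual layout of the extracted dataset.
+
+```python
+import glob
+
+import numpy as np
+from PIL import Image
+
+# Hypothetical glob pattern for the ground-truth masks of the training split.
+mask_paths = glob.glob('data/MiniDeepGlobeRoadExtraction/**/*mask*.png', recursive=True)
+
+road, total = 0, 0
+for path in mask_paths:
+    mask = np.array(Image.open(path))
+    road += int((mask == 1).sum())  # assumes class 1 = road, class 0 = background
+    total += mask.size
+print('road pixel ratio: {:.2%}'.format(road / total))
+```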
+
+Again we use the OCRNet model with an HRNet-W18 backbone.
+
+* Dataset
+We randomly sampled 800 images from the training set of the DeepGlobe Road Extraction track as our training set and 200 images as our validation set,
+which gives a small road extraction dataset, [MiniDeepGlobeRoadExtraction](https://paddleseg.bj.bcebos.com/dataset/MiniDeepGlobeRoadExtraction.zip).
+Running the training script downloads the dataset automatically.
+
+* Training with lovasz loss
+```shell
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -u -m paddle.distributed.launch train.py \
+--config configs/ocrnet/ocrnet_hrnetw18_road_extraction_768x768_15k_lovasz_hinge.yml \
+--use_vdl --num_workers 3 --do_eval
+```
+
+* Training with cross entropy loss
+```shell
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -u -m paddle.distributed.launch train.py \
+--config configs/ocrnet/ocrnet_hrnetw18_road_extraction_768x768_15k.yml \
+--use_vdl --num_workers 3 --do_eval
+```
+
+* Results
+
+The mIoU curves of the two runs are shown below.
+
+![mIoU curves of lovasz hinge loss + cross entropy loss vs. cross entropy loss on MiniDeepGlobeRoadExtraction](./images/Lovasz_Hinge_Evaluate_mIoU.png)
+
+|Loss|best mIoU|
+|-|-|
+|cross entropy loss|78.69%|
+|lovasz hinge loss + cross entropy loss|79.18%|
+
+In the figure, the purple curve is lovasz hinge loss + cross entropy loss and the blue curve is cross entropy loss; the former is about 0.5 percentage points higher.
+
+With lovasz hinge loss, the accuracy curve is consistently above the cross entropy baseline.
+
+
+## Reference
+[Berman M, Rannen Triki A, Blaschko M B. The lovász-softmax loss: a tractable surrogate for the optimization of the intersection-over-union measure in neural networks[C]//Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018: 4413-4421.](http://openaccess.thecvf.com/content_cvpr_2018/html/Berman_The_LovaSz-Softmax_Loss_CVPR_2018_paper.html)
diff --git a/paddleseg/cvlibs/config.py b/paddleseg/cvlibs/config.py
index 5ab1d29872..742dd4757e 100644
--- a/paddleseg/cvlibs/config.py
+++ b/paddleseg/cvlibs/config.py
@@ -211,7 +211,9 @@ def loss(self) -> dict:
                 if key == 'types':
                     self._losses['types'] = []
                     for item in args['types']:
-                        item['ignore_index'] = self.train_dataset.ignore_index
+                        if item['type'] != 'MixedLoss':
+                            item['ignore_index'] = \
+                                self.train_dataset.ignore_index
                         self._losses['types'].append(self._load_object(item))
                 else:
                     self._losses[key] = val
diff --git a/paddleseg/datasets/__init__.py b/paddleseg/datasets/__init__.py
index 31dc494fbd..047fd22a6c 100644
--- a/paddleseg/datasets/__init__.py
+++ b/paddleseg/datasets/__init__.py
@@ -18,3 +18,4 @@
 from .ade import ADE20K
 from .optic_disc_seg import OpticDiscSeg
 from .pascal_context import PascalContext
+from .mini_deep_globe_road_extraction import MiniDeepGlobeRoadExtraction
diff --git a/paddleseg/datasets/mini_deep_globe_road_extraction.py b/paddleseg/datasets/mini_deep_globe_road_extraction.py
new file mode 100644
index 0000000000..6921cc9bff
--- /dev/null
+++ b/paddleseg/datasets/mini_deep_globe_road_extraction.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from .dataset import Dataset
+from paddleseg.utils.download import download_file_and_uncompress
+from paddleseg.utils import seg_env
+from paddleseg.cvlibs import manager
+from paddleseg.transforms import Compose
+
+URL = "https://paddleseg.bj.bcebos.com/dataset/MiniDeepGlobeRoadExtraction.zip"
+
+
+@manager.DATASETS.add_component
+class MiniDeepGlobeRoadExtraction(Dataset):
+    """
+    MiniDeepGlobeRoadExtraction dataset is extracted from the DeepGlobe CVPR2018 challenge (http://deepglobe.org/).
+
+    There are 800 images in the training set and 200 images in the validation set.
+
+    Args:
+        dataset_root (str, optional): The dataset directory. Default: None.
+        transforms (list, optional): Transforms for image. Default: None.
+        mode (str, optional): Which part of dataset to use. It is one of ('train', 'val'). Default: 'train'.
+        edge (bool, optional): Whether to compute edge while training. Default: False.
+ """ + + def __init__(self, + dataset_root=None, + transforms=None, + mode='train', + edge=False): + self.dataset_root = dataset_root + self.transforms = Compose(transforms) + mode = mode.lower() + self.mode = mode + self.file_list = list() + self.num_classes = 2 + self.ignore_index = 255 + self.edge = edge + + if mode not in ['train', 'val']: + raise ValueError( + "`mode` should be 'train' or 'val', but got {}.".format(mode)) + + if self.transforms is None: + raise ValueError("`transforms` is necessary, but it is None.") + + if self.dataset_root is None: + self.dataset_root = download_file_and_uncompress( + url=URL, + savepath=seg_env.DATA_HOME, + extrapath=seg_env.DATA_HOME) + elif not os.path.exists(self.dataset_root): + self.dataset_root = os.path.normpath(self.dataset_root) + savepath, extraname = self.dataset_root.rsplit( + sep=os.path.sep, maxsplit=1) + self.dataset_root = download_file_and_uncompress( + url=URL, + savepath=savepath, + extrapath=savepath, + extraname=extraname) + + if mode == 'train': + file_path = os.path.join(self.dataset_root, 'train.txt') + else: + file_path = os.path.join(self.dataset_root, 'val.txt') + + with open(file_path, 'r') as f: + for line in f: + items = line.strip().split('|') + if len(items) != 2: + if mode == 'train' or mode == 'val': + raise Exception( + "File list format incorrect! It should be" + " image_name|label_name\\n") + image_path = os.path.join(self.dataset_root, items[0]) + grt_path = None + else: + image_path = os.path.join(self.dataset_root, items[0]) + grt_path = os.path.join(self.dataset_root, items[1]) + self.file_list.append([image_path, grt_path]) diff --git a/paddleseg/models/losses/__init__.py b/paddleseg/models/losses/__init__.py index e9332bdbf0..a0410448e8 100644 --- a/paddleseg/models/losses/__init__.py +++ b/paddleseg/models/losses/__init__.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .mixed_loss import MixedLoss from .cross_entropy_loss import CrossEntropyLoss from .binary_cross_entropy_loss import BCELoss +from .lovasz_loss import LovaszSoftmaxLoss, LovaszHingeLoss from .gscnn_dual_task_loss import DualTaskLoss from .edge_attention_loss import EdgeAttentionLoss from .bootstrapped_cross_entropy import BootstrappedCrossEntropyLoss diff --git a/paddleseg/models/losses/lovasz_loss.py b/paddleseg/models/losses/lovasz_loss.py new file mode 100644 index 0000000000..8c5c117b4a --- /dev/null +++ b/paddleseg/models/losses/lovasz_loss.py @@ -0,0 +1,222 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Lovasz-Softmax and Jaccard hinge loss in PaddlePaddle""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import paddle +from paddle import nn +import paddle.nn.functional as F + +from paddleseg.cvlibs import manager + + +@manager.LOSSES.add_component +class LovaszSoftmaxLoss(nn.Layer): + """ + Multi-class Lovasz-Softmax loss. + + Args: + ignore_index (int64): Specifies a target value that is ignored and does not contribute to the input gradient. Default ``255``. + classes (str|list): 'all' for all, 'present' for classes present in labels, or a list of classes to average. + """ + + def __init__(self, ignore_index=255, classes='present'): + super(LovaszSoftmaxLoss, self).__init__() + self.ignore_index = ignore_index + self.classes = classes + + def forward(self, logits, labels): + """ + Forward computation. + + Args: + logits (Tensor): Shape is [N, C, H, W], logits at each prediction (between -\infty and +\infty). + labels (Tensor): Shape is [N, 1, H, W] or [N, H, W], ground truth labels (between 0 and C - 1). + """ + probas = F.softmax(logits, axis=1) + vprobas, vlabels = flatten_probas(probas, labels, self.ignore_index) + loss = lovasz_softmax_flat(vprobas, vlabels, classes=self.classes) + return loss + + +@manager.LOSSES.add_component +class LovaszHingeLoss(nn.Layer): + """ + Binary Lovasz hinge loss. + + Args: + ignore_index (int64): Specifies a target value that is ignored and does not contribute to the input gradient. Default ``255``. + """ + + def __init__(self, ignore_index=255): + super(LovaszHingeLoss, self).__init__() + self.ignore_index = ignore_index + + def forward(self, logits, labels): + """ + Forward computation. + + Args: + logits (Tensor): Shape is [N, 1, H, W] or [N, 2, H, W], logits at each pixel (between -\infty and +\infty). + labels (Tensor): Shape is [N, 1, H, W] or [N, H, W], binary ground truth masks (0 or 1). + """ + if logits.shape[1] == 2: + logits = binary_channel_to_unary(logits) + loss = lovasz_hinge_flat( + *flatten_binary_scores(logits, labels, self.ignore_index)) + return loss + + +def lovasz_grad(gt_sorted): + """ + Computes gradient of the Lovasz extension w.r.t sorted errors. + See Alg. 1 in paper. + """ + gts = paddle.sum(gt_sorted) + p = len(gt_sorted) + + intersection = gts - paddle.cumsum(gt_sorted, axis=0) + union = gts + paddle.cumsum(1 - gt_sorted, axis=0) + jaccard = 1.0 - intersection.cast('float32') / union.cast('float32') + + if p > 1: # cover 1-pixel case + jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] + return jaccard + + +def binary_channel_to_unary(logits, eps=1e-9): + """ + Converts binary channel logits to unary channel logits for lovasz hinge loss. + """ + probas = F.softmax(logits, axis=1) + probas = probas[:, 1, :, :] + logits = paddle.log(probas + eps / (1 - probas + eps)) + logits = logits.unsqueeze(1) + return logits + + +def lovasz_hinge_flat(logits, labels): + """ + Binary Lovasz hinge loss. + + Args: + logits (Tensor): Shape is [P], logits at each prediction (between -\infty and +\infty). + labels (Tensor): Shape is [P], binary ground truth labels (0 or 1). + """ + if len(labels) == 0: + # only void pixels, the gradients should be 0 + return logits.sum() * 0. + signs = 2. * labels - 1. + signs.stop_gradient = True + errors = 1. 
- logits * signs + errors_sorted, perm = paddle.fluid.core.ops.argsort(errors, 'axis', 0, + 'descending', True) + errors_sorted.stop_gradient = False + gt_sorted = paddle.gather(labels, perm) + grad = lovasz_grad(gt_sorted) + grad.stop_gradient = True + loss = paddle.sum(F.relu(errors_sorted) * grad) + return loss + + +def flatten_binary_scores(scores, labels, ignore=None): + """ + Flattens predictions in the batch (binary case). + Remove labels according to 'ignore'. + """ + scores = paddle.reshape(scores, [-1]) + labels = paddle.reshape(labels, [-1]) + labels.stop_gradient = True + if ignore is None: + return scores, labels + valid = labels != ignore + valid_mask = paddle.reshape(valid, (-1, 1)) + indexs = paddle.nonzero(valid_mask) + indexs.stop_gradient = True + vscores = paddle.gather(scores, indexs[:, 0]) + vlabels = paddle.gather(labels, indexs[:, 0]) + return vscores, vlabels + + +def lovasz_softmax_flat(probas, labels, classes='present'): + """ + Multi-class Lovasz-Softmax loss. + + Args: + probas (Tensor): Shape is [P, C], class probabilities at each prediction (between 0 and 1). + labels (Tensor): Shape is [P], ground truth labels (between 0 and C - 1). + classes (str|list): 'all' for all, 'present' for classes present in labels, or a list of classes to average. + """ + if probas.numel() == 0: + # only void pixels, the gradients should be 0 + return probas * 0. + C = probas.shape[1] + losses = [] + classes_to_sum = list(range(C)) if classes in ['all', 'present' + ] else classes + for c in classes_to_sum: + fg = paddle.cast(labels == c, probas.dtype) # foreground for class c + if classes == 'present' and fg.sum() == 0: + continue + fg.stop_gradient = True + if C == 1: + if len(classes_to_sum) > 1: + raise ValueError('Sigmoid output possible only with 1 class') + class_pred = probas[:, 0] + else: + class_pred = probas[:, c] + errors = paddle.abs(fg - class_pred) + errors_sorted, perm = paddle.fluid.core.ops.argsort( + errors, 'axis', 0, 'descending', True) + errors_sorted.stop_gradient = False + + fg_sorted = paddle.gather(fg, perm) + fg_sorted.stop_gradient = True + + grad = lovasz_grad(fg_sorted) + grad.stop_gradient = True + loss = paddle.sum(errors_sorted * grad) + losses.append(loss) + + if len(classes_to_sum) == 1: + return losses[0] + + losses_tensor = paddle.stack(losses) + mean_loss = paddle.mean(losses_tensor) + return mean_loss + + +def flatten_probas(probas, labels, ignore=None): + """ + Flattens predictions in the batch. + """ + if len(probas.shape) == 3: + probas = paddle.unsqueeze(probas, axis=1) + C = probas.shape[1] + probas = paddle.transpose(probas, [0, 2, 3, 1]) + probas = paddle.reshape(probas, [-1, C]) + labels = paddle.reshape(labels, [-1]) + if ignore is None: + return probas, labels + valid = labels != ignore + valid_mask = paddle.reshape(valid, [-1, 1]) + indexs = paddle.nonzero(valid_mask) + indexs.stop_gradient = True + vprobas = paddle.gather(probas, indexs[:, 0]) + vlabels = paddle.gather(labels, indexs[:, 0]) + return vprobas, vlabels diff --git a/paddleseg/models/losses/mixed_loss.py b/paddleseg/models/losses/mixed_loss.py new file mode 100644 index 0000000000..1a5ea91b52 --- /dev/null +++ b/paddleseg/models/losses/mixed_loss.py @@ -0,0 +1,58 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+
+from paddleseg.cvlibs import manager
+
+
+@manager.LOSSES.add_component
+class MixedLoss(nn.Layer):
+    """
+    Computes a weighted sum of multiple losses, so that mixed-loss training
+    can be configured without changing the model code.
+
+    Args:
+        losses (list of nn.Layer): A list of loss instances.
+        coef (list): Weighting coefficients for the corresponding losses; it must have the same length as `losses`.
+
+    Returns:
+        A callable object of MixedLoss.
+    """
+
+    def __init__(self, losses, coef):
+        super(MixedLoss, self).__init__()
+        if not isinstance(losses, list):
+            raise TypeError('`losses` must be a list!')
+        if not isinstance(coef, list):
+            raise TypeError('`coef` must be a list!')
+        len_losses = len(losses)
+        len_coef = len(coef)
+        if len_losses != len_coef:
+            raise ValueError(
+                'The length of `losses` should equal to `coef`, but they are {} and {}.'
+                .format(len_losses, len_coef))
+
+        self.losses = losses
+        self.coef = coef
+
+    def forward(self, logits, labels):
+        # Accumulate the weighted contribution of each loss.
+        final_output = 0
+        for i, loss in enumerate(self.losses):
+            output = loss(logits, labels)
+            final_output += output * self.coef[i]
+        return final_output
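As a quick sanity check of the weighting semantics implemented in `MixedLoss.forward` above, a hypothetical snippet like the following (not part of this patch; the stub loss is invented purely for illustration) shows that the result is simply `sum(coef[i] * losses[i](logits, labels))` and that a non-list `coef` is rejected:

```python
import paddle
from paddle import nn
from paddleseg.models.losses import MixedLoss


class ConstLoss(nn.Layer):
    """Stub loss that ignores its inputs and returns a fixed value."""

    def __init__(self, value):
        super().__init__()
        self.value = value

    def forward(self, logits, labels):
        return paddle.to_tensor(self.value)


mixed = MixedLoss(losses=[ConstLoss(1.0), ConstLoss(2.0)], coef=[0.8, 0.2])
dummy = paddle.zeros([1])
print(float(mixed(dummy, dummy)))  # 0.8 * 1.0 + 0.2 * 2.0 = 1.2

try:
    MixedLoss(losses=[ConstLoss(1.0)], coef=0.5)
except TypeError as err:
    print(err)  # `coef` must be a list!
```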