diff --git a/README.md b/README.md
index 6786613ce4..9e08c8611e 100644
--- a/README.md
+++ b/README.md
@@ -93,6 +93,7 @@ python train.py --config configs/quick_start/bisenet_optic_disc_512x512_1k.yml
* [API Tutorial](https://aistudio.baidu.com/aistudio/projectdetail/1339458)
* [Data Preparation](./docs/data_prepare.md)
* [Training Configuration](./configs/)
+* [Loss Usage](./docs/loss_usage.md)
* [API References](./docs/apis)
* [Add New Components](./docs/add_new_model.md)
@@ -111,7 +112,7 @@ If you find our project useful in your research, please consider citing:
```latex
@misc{liu2021paddleseg,
- title={PaddleSeg: A High-Efficient Development Toolkit for Image Segmentation},
+ title={PaddleSeg: A High-Efficient Development Toolkit for Image Segmentation},
author={Yi Liu and Lutao Chu and Guowei Chen and Zewu Wu and Zeyu Chen and Baohua Lai and Yuying Hao},
year={2021},
eprint={2101.06175},
diff --git a/README_CN.md b/README_CN.md
index 949ddbf503..0b6bdeefff 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -95,6 +95,7 @@ python train.py --config configs/quick_start/bisenet_optic_disc_512x512_1k.yml
* [API使用教程](https://aistudio.baidu.com/aistudio/projectdetail/1339458)
* [数据集准备](./docs/data_prepare.md)
* [配置项](./configs/)
+* [Loss使用](./docs/loss_usage.md)
* [API参考](./docs/apis)
* [添加新组件](./docs/add_new_model.md)
@@ -114,7 +115,7 @@ python train.py --config configs/quick_start/bisenet_optic_disc_512x512_1k.yml
```latex
@misc{liu2021paddleseg,
- title={PaddleSeg: A High-Efficient Development Toolkit for Image Segmentation},
+ title={PaddleSeg: A High-Efficient Development Toolkit for Image Segmentation},
author={Yi Liu and Lutao Chu and Guowei Chen and Zewu Wu and Zeyu Chen and Baohua Lai and Yuying Hao},
year={2021},
eprint={2101.06175},
diff --git a/configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_160k_lovasz_softmax.yml b/configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_160k_lovasz_softmax.yml
new file mode 100644
index 0000000000..8bbd59e6ac
--- /dev/null
+++ b/configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_160k_lovasz_softmax.yml
@@ -0,0 +1,35 @@
+_base_: '../_base_/cityscapes.yml'
+
+batch_size: 2
+iters: 160000
+
+model:
+ type: OCRNet
+ backbone:
+ type: HRNet_W18
+ pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
+ backbone_indices: [0]
+
+optimizer:
+ type: sgd
+
+learning_rate:
+ value: 0.01
+ decay:
+ type: poly
+ power: 0.9
+
+
+loss:
+ types:
+ - type: MixedLoss
+ losses:
+ - type: CrossEntropyLoss
+ - type: LovaszSoftmaxLoss
+ coef: [0.8, 0.2]
+ - type: MixedLoss
+ losses:
+ - type: CrossEntropyLoss
+ - type: LovaszSoftmaxLoss
+ coef: [0.8, 0.2]
+ coef: [1, 0.4]
diff --git a/configs/ocrnet/ocrnet_hrnetw18_road_extraction_768x768_15k.yml b/configs/ocrnet/ocrnet_hrnetw18_road_extraction_768x768_15k.yml
new file mode 100644
index 0000000000..d1366a43c1
--- /dev/null
+++ b/configs/ocrnet/ocrnet_hrnetw18_road_extraction_768x768_15k.yml
@@ -0,0 +1,48 @@
+
+
+batch_size: 4
+iters: 15000
+
+train_dataset:
+ type: MiniDeepGlobeRoadExtraction
+ dataset_root: data/MiniDeepGlobeRoadExtraction
+ transforms:
+ - type: ResizeStepScaling
+ min_scale_factor: 0.5
+ max_scale_factor: 2.0
+ scale_step_size: 0.25
+ - type: RandomPaddingCrop
+ crop_size: [768, 768]
+ - type: RandomHorizontalFlip
+ - type: Normalize
+ mode: train
+
+val_dataset:
+ type: MiniDeepGlobeRoadExtraction
+ dataset_root: data/MiniDeepGlobeRoadExtraction
+ transforms:
+ - type: Normalize
+ mode: val
+
+model:
+ type: OCRNet
+ backbone:
+ type: HRNet_W18
+ pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
+ backbone_indices: [0]
+
+optimizer:
+ type: sgd
+
+learning_rate:
+ value: 0.01
+ decay:
+ type: poly
+ power: 0.9
+
+
+loss:
+ types:
+ - type: CrossEntropyLoss
+ - type: CrossEntropyLoss
+ coef: [1, 0.4]
diff --git a/configs/ocrnet/ocrnet_hrnetw18_road_extraction_768x768_15k_lovasz_hinge.yml b/configs/ocrnet/ocrnet_hrnetw18_road_extraction_768x768_15k_lovasz_hinge.yml
new file mode 100644
index 0000000000..72f23e993a
--- /dev/null
+++ b/configs/ocrnet/ocrnet_hrnetw18_road_extraction_768x768_15k_lovasz_hinge.yml
@@ -0,0 +1,55 @@
+
+
+batch_size: 4
+iters: 15000
+
+train_dataset:
+ type: MiniDeepGlobeRoadExtraction
+ dataset_root: data/MiniDeepGlobeRoadExtraction
+ transforms:
+ - type: ResizeStepScaling
+ min_scale_factor: 0.5
+ max_scale_factor: 2.0
+ scale_step_size: 0.25
+ - type: RandomPaddingCrop
+ crop_size: [768, 768]
+ - type: RandomHorizontalFlip
+ - type: Normalize
+ mode: train
+
+val_dataset:
+ type: MiniDeepGlobeRoadExtraction
+ dataset_root: data/MiniDeepGlobeRoadExtraction
+ transforms:
+ - type: Normalize
+ mode: val
+
+model:
+ type: OCRNet
+ backbone:
+ type: HRNet_W18
+ pretrained: https://bj.bcebos.com/paddleseg/dygraph/hrnet_w18_ssld.tar.gz
+ backbone_indices: [0]
+
+optimizer:
+ type: sgd
+
+learning_rate:
+ value: 0.01
+ decay:
+ type: poly
+ power: 0.9
+
+loss:
+ types:
+ - type: MixedLoss
+ losses:
+ - type: CrossEntropyLoss
+ - type: LovaszHingeLoss
+ coef: [1, 0.01]
+ - type: MixedLoss
+ losses:
+ - type: CrossEntropyLoss
+ - type: LovaszHingeLoss
+ coef: [1, 0.01]
+ coef: [1, 0.4]
diff --git a/docs/apis/README.md b/docs/apis/README.md
index e80296af67..ec797bdf8d 100644
--- a/docs/apis/README.md
+++ b/docs/apis/README.md
@@ -5,4 +5,5 @@
* [paddleseg.datasets](./datasets.md)
* [paddleseg.models](./models.md)
* [paddleseg.models.backbones](./backbones.md)
+* [paddleseg.models.losses](./losses.md)
* [paddleseg.transforms](./transforms.md)
diff --git a/docs/apis/losses.md b/docs/apis/losses.md
new file mode 100644
index 0000000000..7f799a8384
--- /dev/null
+++ b/docs/apis/losses.md
@@ -0,0 +1,33 @@
+# [paddleseg.models.losses](../../paddleseg/models/losses)
+
+## LovaszSoftmaxLoss
+> CLASS paddleseg.models.losses.LovaszSoftmaxLoss(ignore_index=255, classes='present')
+
+ Multi-class Lovasz-Softmax loss.
+
+> > Args
+> > > - **ignore_index** (int64): Specifies a target value that is ignored and does not contribute to the input gradient. Default ``255``.
+> > > - **classes** (str|list): 'all' for all, 'present' for classes present in labels, or a list of classes to average.
+
+
+## LovaszHingeLoss
+> CLASS paddleseg.models.losses.LovaszHingeLoss(ignore_index=255)
+
+ Binary Lovasz hinge loss.
+
+> > Args
+> > > - **ignore_index** (int64): Specifies a target value that is ignored and does not contribute to the input gradient. Default ``255``.
+
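+A minimal usage sketch (batch size, crop size and the two-channel logit layout below are illustrative assumptions, not API requirements):
+
+```python
+import paddle
+from paddleseg.models.losses import LovaszHingeLoss
+
+# Two-channel logits (background/foreground) for 2 images with 64x64 crops;
+# a single-channel [N, 1, H, W] logit is also accepted.
+logits = paddle.randn([2, 2, 64, 64])
+# Binary ground-truth masks (0 or 1), int64 as produced by the data pipeline.
+labels = paddle.randint(0, 2, [2, 64, 64])
+
+loss_fn = LovaszHingeLoss()     # ignore_index defaults to 255
+loss = loss_fn(logits, labels)  # scalar Tensor
+```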
+
+## MixedLoss
+> CLASS paddleseg.models.losses.MixedLoss(losses, coef)
+
+ Weighted combination of multiple losses, so that mixed-loss training can be achieved without changing the model code.
+
+> > Args
+> > > - **losses** (list of nn.Layer): A list of loss instances.
+> > > - **coef** (list of float|int): Weighting coefficients of the losses; must have the same length as ``losses``.
+
+> > Returns
+> > > - A callable object of MixedLoss.
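+
+A minimal usage sketch (the 19-class, 32x32 shapes are illustrative assumptions, loosely mimicking Cityscapes):
+
+```python
+import paddle
+from paddleseg.models.losses import CrossEntropyLoss, LovaszSoftmaxLoss, MixedLoss
+
+logits = paddle.randn([4, 19, 32, 32])       # [N, C, H, W] logits
+labels = paddle.randint(0, 19, [4, 32, 32])  # [N, H, W] int64 label map
+
+# 0.8 * cross entropy + 0.2 * lovasz softmax, evaluated on the same logits/labels.
+mixed = MixedLoss(losses=[CrossEntropyLoss(), LovaszSoftmaxLoss()], coef=[0.8, 0.2])
+loss = mixed(logits, labels)
+```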
diff --git a/docs/images/Lovasz_Hinge_Evaluate_mIoU.png b/docs/images/Lovasz_Hinge_Evaluate_mIoU.png
new file mode 100644
index 0000000000..62de580163
Binary files /dev/null and b/docs/images/Lovasz_Hinge_Evaluate_mIoU.png differ
diff --git a/docs/images/Lovasz_Softmax_Evaluate_mIoU.png b/docs/images/Lovasz_Softmax_Evaluate_mIoU.png
new file mode 100644
index 0000000000..8e55dd3856
Binary files /dev/null and b/docs/images/Lovasz_Softmax_Evaluate_mIoU.png differ
diff --git a/docs/images/deepglobe.png b/docs/images/deepglobe.png
new file mode 100644
index 0000000000..0ffc311b5f
Binary files /dev/null and b/docs/images/deepglobe.png differ
diff --git a/docs/loss_usage.md b/docs/loss_usage.md
new file mode 100644
index 0000000000..b091133324
--- /dev/null
+++ b/docs/loss_usage.md
@@ -0,0 +1,4 @@
+# Loss usage
+
+- [Lovasz loss](lovasz_loss.md)
+- To be continued
diff --git a/docs/lovasz_loss.md b/docs/lovasz_loss.md
new file mode 100644
index 0000000000..edf15afdfc
--- /dev/null
+++ b/docs/lovasz_loss.md
@@ -0,0 +1,125 @@
+# Lovasz loss
+In image segmentation tasks, heavily imbalanced class distributions are common, for example in defect detection on industrial products, road extraction, and lesion segmentation. Lovasz loss can be used to address this problem.
+
+Lovasz loss optimizes the mean IoU of a neural network based on the convex Lovasz extension of submodular losses. Depending on the number of classes of the segmentation target, it comes in two variants: lovasz hinge loss for binary segmentation and lovasz softmax loss for multi-class segmentation. The method was published at CVPR 2018; see the [reference](#references) at the end of this document for the underlying theory.
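+
+For reference, this restates the definition from the paper cited at the end of this document (nothing beyond that reference): for a class c, the Jaccard loss and its Lovasz-Softmax surrogate are
+
+```latex
+\Delta_{J_c}(y^*, \tilde{y}) = 1 - \frac{\lvert \{y^* = c\} \cap \{\tilde{y} = c\}\rvert}{\lvert \{y^* = c\} \cup \{\tilde{y} = c\}\rvert},
+\qquad
+\mathrm{loss}(f) = \frac{1}{\lvert C\rvert} \sum_{c \in C} \overline{\Delta_{J_c}}\bigl(m(c)\bigr),
+\qquad
+m_i(c) = \begin{cases} 1 - f_i(c) & \text{if } c = y^*_i \\ f_i(c) & \text{otherwise,} \end{cases}
+```
+
+where f_i(c) is the softmax probability of class c at pixel i and the bar denotes the Lovasz extension of the Jaccard loss.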
+
+
+## How to use Lovasz loss
+This section describes how to train with lovasz loss. Note that training with lovasz loss alone from the start does not always work well; we recommend two training schemes instead:
+- (1) Combine it with cross entropy loss or bce loss (binary cross entropy loss) using weighted coefficients.
+- (2) Train with cross entropy loss or bce loss first, then finetune with lovasz softmax loss or lovasz hinge loss.
+
+Taking scheme (1) as an example, the `MixedLoss` class selects the loss functions used during training, and its `coef` parameter sets the weight of each loss, which makes it easy to tune the combination. Two YAML examples follow, with an equivalent Python sketch after the second one:
+
+```yaml
+loss:
+ types:
+ - type: MixedLoss
+ losses:
+ - type: CrossEntropyLoss
+ - type: LovaszSoftmaxLoss
+ coef: [0.8, 0.2]
+```
+
+```yaml
+loss:
+ types:
+ - type: MixedLoss
+ losses:
+ - type: CrossEntropyLoss
+ - type: LovaszHingeLoss
+ coef: [1, 0.02]
+```
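+
+For reference, a Python sketch of what the first YAML block amounts to, including how the per-output `coef: [1, 0.4]` used by the full OCRNet configs weights the main and auxiliary heads (the 19-class, 64x64 tensor shapes are illustrative assumptions):
+
+```python
+import paddle
+from paddleseg.models.losses import CrossEntropyLoss, LovaszSoftmaxLoss, MixedLoss
+
+# Stand-ins for OCRNet's two outputs (main head first, auxiliary head second).
+logit_main = paddle.randn([2, 19, 64, 64])
+logit_aux = paddle.randn([2, 19, 64, 64])
+label = paddle.randint(0, 19, [2, 64, 64])
+
+mixed = MixedLoss(losses=[CrossEntropyLoss(), LovaszSoftmaxLoss()], coef=[0.8, 0.2])
+
+# One entry of `loss.types` per model output, weighted by the outer `coef`.
+total_loss = 1.0 * mixed(logit_main, label) + 0.4 * mixed(logit_aux, label)
+```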
+
+
+## Lovasz softmax loss experiments
+
+We apply lovasz softmax loss to the classic [Cityscapes](https://www.cityscapes-dataset.com/) dataset. Cityscapes has 19 classes with an imbalanced distribution: classes such as `road` and `building` are very common, while `fence`, `motorcycle` and `wall` are rare. We compare lovasz softmax loss against plain softmax (cross entropy) loss, using the OCRNet model with an HRNet-W18 backbone.
+
+
+* Data preparation
+
+See the [data preparation tutorial](data_prepare.md).
+
+* Training with lovasz loss
+```shell
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -u -m paddle.distributed.launch train.py \
+--config configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_160k_lovasz_softmax.yml \
+--use_vdl --num_workers 3 --do_eval
+```
+
+* Training with cross entropy loss
+```shell
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -u -m paddle.distributed.launch train.py \
+--config configs/ocrnet/ocrnet_hrnetw18_cityscapes_1024x512_160k.yml \
+--use_vdl --num_workers 3 --do_eval
+```
+
+* Results
+
+The mIoU curves of the two runs are shown below.
+
+![Lovasz Softmax Evaluate mIoU](images/Lovasz_Softmax_Evaluate_mIoU.png)
+
+|Loss|best mIoU|
+|-|-|
+|cross entropy loss|80.46%|
+|lovasz softmax loss + cross entropy loss|81.53%|
+
+In the figure, the blue curve is lovasz softmax loss + cross entropy loss and the green curve is cross entropy loss; the combination improves best mIoU by about 1 percentage point.
+
+With lovasz softmax loss added, the accuracy curve stays above the cross entropy baseline for most of training.
+
+## Lovasz hinge loss experiments
+
+We apply lovasz hinge loss to a road extraction task and compare it against cross entropy loss on the MiniDeepGlobeRoadExtraction dataset.
+The dataset comes from the Road Extraction track of the [DeepGlobe CVPR2018 challenge](http://deepglobe.org/); roads cover only 4.5% of the training pixels, so this is a typical class-imbalance scenario. A sample image is shown below:
+
+![DeepGlobe road extraction sample](images/deepglobe.png)
+
+As before, we use the OCRNet model with an HRNet-W18 backbone.
+
+* Dataset
+We randomly sampled 800 images for training and 200 images for validation from the training set of the DeepGlobe Road Extraction track,
+and packaged them as a small road extraction dataset, [MiniDeepGlobeRoadExtraction](https://paddleseg.bj.bcebos.com/dataset/MiniDeepGlobeRoadExtraction.zip).
+Running the training script downloads the dataset automatically; a short sketch of using the dataset class directly is shown below.
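+
+A minimal sketch of using the dataset class directly (when `dataset_root` is not given, the archive is downloaded and unpacked into PaddleSeg's default data directory; the single `Normalize` transform here is just a placeholder):
+
+```python
+import paddleseg.transforms as T
+from paddleseg.datasets import MiniDeepGlobeRoadExtraction
+
+train_ds = MiniDeepGlobeRoadExtraction(transforms=[T.Normalize()], mode='train')
+print(len(train_ds))  # 800 training samples per the class docstring
+```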
+
+* Training with lovasz loss
+```shell
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -u -m paddle.distributed.launch train.py \
+--config configs/ocrnet/ocrnet_hrnetw18_road_extraction_768x768_15k_lovasz_hinge.yml \
+--use_vdl --num_workers 3 --do_eval
+```
+
+* Training with cross entropy loss
+```shell
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -u -m paddle.distributed.launch train.py \
+--config configs/ocrnet/ocrnet_hrnetw18_road_extraction_768x768_15k.yml \
+--use_vdl --num_workers 3 --do_eval
+```
+
+* Results
+
+The mIoU curves of the two runs are shown below.
+
+![Lovasz Hinge Evaluate mIoU](images/Lovasz_Hinge_Evaluate_mIoU.png)
+
+|Loss|best mIoU|
+|-|-|
+|cross entropy loss|78.69%|
+|lovasz hinge loss + cross entropy loss|79.18%|
+
+In the figure, the purple curve is lovasz hinge loss + cross entropy loss and the blue curve is cross entropy loss; the combination improves best mIoU by about 0.5 percentage points.
+
+With lovasz hinge loss added, the accuracy curve is consistently above the cross entropy baseline.
+
+
+
+## References
+[Berman M, Rannen Triki A, Blaschko M B. The lovász-softmax loss: a tractable surrogate for the optimization of the intersection-over-union measure in neural networks[C]//Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018: 4413-4421.](http://openaccess.thecvf.com/content_cvpr_2018/html/Berman_The_LovaSz-Softmax_Loss_CVPR_2018_paper.html)
diff --git a/paddleseg/cvlibs/config.py b/paddleseg/cvlibs/config.py
index 5ab1d29872..742dd4757e 100644
--- a/paddleseg/cvlibs/config.py
+++ b/paddleseg/cvlibs/config.py
@@ -211,7 +211,9 @@ def loss(self) -> dict:
if key == 'types':
self._losses['types'] = []
for item in args['types']:
- item['ignore_index'] = self.train_dataset.ignore_index
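+                    # MixedLoss takes no ignore_index argument; its nested
+                    # losses are constructed with their own (default) values.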
+ if item['type'] != 'MixedLoss':
+ item['ignore_index'] = \
+ self.train_dataset.ignore_index
self._losses['types'].append(self._load_object(item))
else:
self._losses[key] = val
diff --git a/paddleseg/datasets/__init__.py b/paddleseg/datasets/__init__.py
index 31dc494fbd..047fd22a6c 100644
--- a/paddleseg/datasets/__init__.py
+++ b/paddleseg/datasets/__init__.py
@@ -18,3 +18,4 @@
from .ade import ADE20K
from .optic_disc_seg import OpticDiscSeg
from .pascal_context import PascalContext
+from .mini_deep_globe_road_extraction import MiniDeepGlobeRoadExtraction
diff --git a/paddleseg/datasets/mini_deep_globe_road_extraction.py b/paddleseg/datasets/mini_deep_globe_road_extraction.py
new file mode 100644
index 0000000000..6921cc9bff
--- /dev/null
+++ b/paddleseg/datasets/mini_deep_globe_road_extraction.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+from .dataset import Dataset
+from paddleseg.utils.download import download_file_and_uncompress
+from paddleseg.utils import seg_env
+from paddleseg.cvlibs import manager
+from paddleseg.transforms import Compose
+
+URL = "https://paddleseg.bj.bcebos.com/dataset/MiniDeepGlobeRoadExtraction.zip"
+
+
+@manager.DATASETS.add_component
+class MiniDeepGlobeRoadExtraction(Dataset):
+ """
+    MiniDeepGlobeRoadExtraction dataset is extracted from the DeepGlobe CVPR2018 challenge (http://deepglobe.org/).
+
+ There are 800 images in the training set and 200 images in the validation set.
+
+ Args:
+ dataset_root (str, optional): The dataset directory. Default: None.
+ transforms (list, optional): Transforms for image. Default: None.
+ mode (str, optional): Which part of dataset to use. It is one of ('train', 'val'). Default: 'train'.
+ edge (bool, optional): Whether to compute edge while training. Default: False.
+ """
+
+ def __init__(self,
+ dataset_root=None,
+ transforms=None,
+ mode='train',
+ edge=False):
+ self.dataset_root = dataset_root
+ self.transforms = Compose(transforms)
+ mode = mode.lower()
+ self.mode = mode
+ self.file_list = list()
+ self.num_classes = 2
+ self.ignore_index = 255
+ self.edge = edge
+
+ if mode not in ['train', 'val']:
+ raise ValueError(
+ "`mode` should be 'train' or 'val', but got {}.".format(mode))
+
+ if self.transforms is None:
+ raise ValueError("`transforms` is necessary, but it is None.")
+
+ if self.dataset_root is None:
+ self.dataset_root = download_file_and_uncompress(
+ url=URL,
+ savepath=seg_env.DATA_HOME,
+ extrapath=seg_env.DATA_HOME)
+ elif not os.path.exists(self.dataset_root):
+ self.dataset_root = os.path.normpath(self.dataset_root)
+ savepath, extraname = self.dataset_root.rsplit(
+ sep=os.path.sep, maxsplit=1)
+ self.dataset_root = download_file_and_uncompress(
+ url=URL,
+ savepath=savepath,
+ extrapath=savepath,
+ extraname=extraname)
+
+ if mode == 'train':
+ file_path = os.path.join(self.dataset_root, 'train.txt')
+ else:
+ file_path = os.path.join(self.dataset_root, 'val.txt')
+
+ with open(file_path, 'r') as f:
+ for line in f:
+ items = line.strip().split('|')
+                if len(items) != 2:
+                    raise Exception(
+                        "File list format incorrect! It should be"
+                        " image_name|label_name\\n")
+                image_path = os.path.join(self.dataset_root, items[0])
+                grt_path = os.path.join(self.dataset_root, items[1])
+ self.file_list.append([image_path, grt_path])
diff --git a/paddleseg/models/losses/__init__.py b/paddleseg/models/losses/__init__.py
index e9332bdbf0..a0410448e8 100644
--- a/paddleseg/models/losses/__init__.py
+++ b/paddleseg/models/losses/__init__.py
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from .mixed_loss import MixedLoss
from .cross_entropy_loss import CrossEntropyLoss
from .binary_cross_entropy_loss import BCELoss
+from .lovasz_loss import LovaszSoftmaxLoss, LovaszHingeLoss
from .gscnn_dual_task_loss import DualTaskLoss
from .edge_attention_loss import EdgeAttentionLoss
from .bootstrapped_cross_entropy import BootstrappedCrossEntropyLoss
diff --git a/paddleseg/models/losses/lovasz_loss.py b/paddleseg/models/losses/lovasz_loss.py
new file mode 100644
index 0000000000..8c5c117b4a
--- /dev/null
+++ b/paddleseg/models/losses/lovasz_loss.py
@@ -0,0 +1,222 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Lovasz-Softmax and Jaccard hinge loss in PaddlePaddle"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+
+from paddleseg.cvlibs import manager
+
+
+@manager.LOSSES.add_component
+class LovaszSoftmaxLoss(nn.Layer):
+ """
+ Multi-class Lovasz-Softmax loss.
+
+ Args:
+ ignore_index (int64): Specifies a target value that is ignored and does not contribute to the input gradient. Default ``255``.
+ classes (str|list): 'all' for all, 'present' for classes present in labels, or a list of classes to average.
+ """
+
+ def __init__(self, ignore_index=255, classes='present'):
+ super(LovaszSoftmaxLoss, self).__init__()
+ self.ignore_index = ignore_index
+ self.classes = classes
+
+ def forward(self, logits, labels):
+ """
+ Forward computation.
+
+ Args:
+ logits (Tensor): Shape is [N, C, H, W], logits at each prediction (between -\infty and +\infty).
+ labels (Tensor): Shape is [N, 1, H, W] or [N, H, W], ground truth labels (between 0 and C - 1).
+ """
+ probas = F.softmax(logits, axis=1)
+ vprobas, vlabels = flatten_probas(probas, labels, self.ignore_index)
+ loss = lovasz_softmax_flat(vprobas, vlabels, classes=self.classes)
+ return loss
+
+
+@manager.LOSSES.add_component
+class LovaszHingeLoss(nn.Layer):
+ """
+ Binary Lovasz hinge loss.
+
+ Args:
+ ignore_index (int64): Specifies a target value that is ignored and does not contribute to the input gradient. Default ``255``.
+ """
+
+ def __init__(self, ignore_index=255):
+ super(LovaszHingeLoss, self).__init__()
+ self.ignore_index = ignore_index
+
+ def forward(self, logits, labels):
+ """
+ Forward computation.
+
+ Args:
+ logits (Tensor): Shape is [N, 1, H, W] or [N, 2, H, W], logits at each pixel (between -\infty and +\infty).
+ labels (Tensor): Shape is [N, 1, H, W] or [N, H, W], binary ground truth masks (0 or 1).
+ """
+ if logits.shape[1] == 2:
+ logits = binary_channel_to_unary(logits)
+ loss = lovasz_hinge_flat(
+ *flatten_binary_scores(logits, labels, self.ignore_index))
+ return loss
+
+
+def lovasz_grad(gt_sorted):
+ """
+ Computes gradient of the Lovasz extension w.r.t sorted errors.
+ See Alg. 1 in paper.
+ """
+ gts = paddle.sum(gt_sorted)
+ p = len(gt_sorted)
+
+ intersection = gts - paddle.cumsum(gt_sorted, axis=0)
+ union = gts + paddle.cumsum(1 - gt_sorted, axis=0)
+ jaccard = 1.0 - intersection.cast('float32') / union.cast('float32')
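+    # jaccard[k] is the Jaccard loss when the k+1 largest errors are counted as
+    # mispredictions; differencing below gives each position its marginal cost.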
+
+ if p > 1: # cover 1-pixel case
+ jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
+ return jaccard
+
+
+def binary_channel_to_unary(logits, eps=1e-9):
+ """
+ Converts binary channel logits to unary channel logits for lovasz hinge loss.
+ """
+ probas = F.softmax(logits, axis=1)
+ probas = probas[:, 1, :, :]
+    logits = paddle.log((probas + eps) / (1 - probas + eps))  # stabilized logit: log(p / (1 - p))
+ logits = logits.unsqueeze(1)
+ return logits
+
+
+def lovasz_hinge_flat(logits, labels):
+ """
+ Binary Lovasz hinge loss.
+
+ Args:
+ logits (Tensor): Shape is [P], logits at each prediction (between -\infty and +\infty).
+ labels (Tensor): Shape is [P], binary ground truth labels (0 or 1).
+ """
+ if len(labels) == 0:
+ # only void pixels, the gradients should be 0
+ return logits.sum() * 0.
+ signs = 2. * labels - 1.
+ signs.stop_gradient = True
+ errors = 1. - logits * signs
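+    # the argsort core op returns both the sorted values (descending along axis 0)
+    # and the permutation indices in a single call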
+ errors_sorted, perm = paddle.fluid.core.ops.argsort(errors, 'axis', 0,
+ 'descending', True)
+ errors_sorted.stop_gradient = False
+ gt_sorted = paddle.gather(labels, perm)
+ grad = lovasz_grad(gt_sorted)
+ grad.stop_gradient = True
+ loss = paddle.sum(F.relu(errors_sorted) * grad)
+ return loss
+
+
+def flatten_binary_scores(scores, labels, ignore=None):
+ """
+ Flattens predictions in the batch (binary case).
+    Removes labels equal to 'ignore'.
+ """
+ scores = paddle.reshape(scores, [-1])
+ labels = paddle.reshape(labels, [-1])
+ labels.stop_gradient = True
+ if ignore is None:
+ return scores, labels
+ valid = labels != ignore
+ valid_mask = paddle.reshape(valid, (-1, 1))
+ indexs = paddle.nonzero(valid_mask)
+ indexs.stop_gradient = True
+ vscores = paddle.gather(scores, indexs[:, 0])
+ vlabels = paddle.gather(labels, indexs[:, 0])
+ return vscores, vlabels
+
+
+def lovasz_softmax_flat(probas, labels, classes='present'):
+ """
+ Multi-class Lovasz-Softmax loss.
+
+ Args:
+ probas (Tensor): Shape is [P, C], class probabilities at each prediction (between 0 and 1).
+ labels (Tensor): Shape is [P], ground truth labels (between 0 and C - 1).
+ classes (str|list): 'all' for all, 'present' for classes present in labels, or a list of classes to average.
+ """
+ if probas.numel() == 0:
+ # only void pixels, the gradients should be 0
+ return probas * 0.
+ C = probas.shape[1]
+ losses = []
+ classes_to_sum = list(range(C)) if classes in ['all', 'present'
+ ] else classes
+ for c in classes_to_sum:
+ fg = paddle.cast(labels == c, probas.dtype) # foreground for class c
+ if classes == 'present' and fg.sum() == 0:
+ continue
+ fg.stop_gradient = True
+ if C == 1:
+ if len(classes_to_sum) > 1:
+ raise ValueError('Sigmoid output possible only with 1 class')
+ class_pred = probas[:, 0]
+ else:
+ class_pred = probas[:, c]
+ errors = paddle.abs(fg - class_pred)
+ errors_sorted, perm = paddle.fluid.core.ops.argsort(
+ errors, 'axis', 0, 'descending', True)
+ errors_sorted.stop_gradient = False
+
+ fg_sorted = paddle.gather(fg, perm)
+ fg_sorted.stop_gradient = True
+
+ grad = lovasz_grad(fg_sorted)
+ grad.stop_gradient = True
+ loss = paddle.sum(errors_sorted * grad)
+ losses.append(loss)
+
+ if len(classes_to_sum) == 1:
+ return losses[0]
+
+ losses_tensor = paddle.stack(losses)
+ mean_loss = paddle.mean(losses_tensor)
+ return mean_loss
+
+
+def flatten_probas(probas, labels, ignore=None):
+ """
+ Flattens predictions in the batch.
+ """
+ if len(probas.shape) == 3:
+ probas = paddle.unsqueeze(probas, axis=1)
+ C = probas.shape[1]
+ probas = paddle.transpose(probas, [0, 2, 3, 1])
+ probas = paddle.reshape(probas, [-1, C])
+ labels = paddle.reshape(labels, [-1])
+ if ignore is None:
+ return probas, labels
+ valid = labels != ignore
+ valid_mask = paddle.reshape(valid, [-1, 1])
+ indexs = paddle.nonzero(valid_mask)
+ indexs.stop_gradient = True
+ vprobas = paddle.gather(probas, indexs[:, 0])
+ vlabels = paddle.gather(labels, indexs[:, 0])
+ return vprobas, vlabels
diff --git a/paddleseg/models/losses/mixed_loss.py b/paddleseg/models/losses/mixed_loss.py
new file mode 100644
index 0000000000..1a5ea91b52
--- /dev/null
+++ b/paddleseg/models/losses/mixed_loss.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+
+from paddleseg.cvlibs import manager
+
+
+@manager.LOSSES.add_component
+class MixedLoss(nn.Layer):
+ """
+    Weighted combination of multiple losses.
+    Mixed-loss training can be achieved this way without changing the model code.
+
+    Args:
+        losses (list of nn.Layer): A list of loss instances.
+        coef (list of float|int): Weighting coefficients of the losses; must have the same length as `losses`.
+
+ Returns:
+ A callable object of MixedLoss.
+ """
+
+ def __init__(self, losses, coef):
+ super(MixedLoss, self).__init__()
+ if not isinstance(losses, list):
+ raise TypeError('`losses` must be a list!')
+ if not isinstance(coef, list):
+ raise TypeError('`coef` must be a list!')
+ len_losses = len(losses)
+ len_coef = len(coef)
+ if len_losses != len_coef:
+ raise ValueError(
+ 'The length of `losses` should equal to `coef`, but they are {} and {}.'
+ .format(len_losses, len_coef))
+
+ self.losses = losses
+ self.coef = coef
+
+ def forward(self, logits, labels):
+ loss_list = []
+ final_output = 0
+ for i, loss in enumerate(self.losses):
+ output = loss(logits, labels)
+ final_output += output * self.coef[i]
+ return final_output