diff --git a/CHANGELOG.md b/CHANGELOG.md
index dbdfc4834699f..b16ef25fa86c1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,6 +30,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Changed `class_reduction` similar to sklearn for classification metrics ([#3322](https://github.com/PyTorchLightning/pytorch-lightning/pull/3322))
 
+- Changed IoU score behavior for classes absent in target and pred ([#3098](https://github.com/PyTorchLightning/pytorch-lightning/pull/3098))
+
+- Changed IoU `remove_bg` bool to `ignore_index` optional int ([#3098](https://github.com/PyTorchLightning/pytorch-lightning/pull/3098))
+
 
 ### Deprecated
 
diff --git a/pytorch_lightning/metrics/classification.py b/pytorch_lightning/metrics/classification.py
index aa14d48ead6ed..57b3cdbb192ee 100644
--- a/pytorch_lightning/metrics/classification.py
+++ b/pytorch_lightning/metrics/classification.py
@@ -748,11 +748,11 @@ def __init__(
             include_background: whether to also compute dice for the background
             nan_score: score to return, if a NaN occurs during computation (denom zero)
             no_fg_score: score to return, if no foreground pixel was found in target
-            reduction: a method to reduce metric score over labels (default: takes the mean)
-                Available reduction methods:
-                - elementwise_mean: takes the mean
-                - none: pass array
-                - sum: add elements
+            reduction: a method to reduce metric score over labels.
+
+                - ``'elementwise_mean'``: takes the mean (default)
+                - ``'sum'``: takes the sum
+                - ``'none'``: no reduction will be applied
             reduce_group: the process group to reduce metric results from DDP
         """
         super().__init__(
@@ -804,26 +804,45 @@ class IoU(TensorMetric):
 
     """
 
-    def __init__(self, remove_bg: bool = False, reduction: str = "elementwise_mean"):
+    def __init__(
+        self,
+        ignore_index: Optional[int] = None,
+        absent_score: float = 0.0,
+        num_classes: Optional[int] = None,
+        reduction: str = "elementwise_mean",
+    ):
         """
         Args:
-            remove_bg: Flag to state whether a background class has been included
-                within input parameters. If true, will remove background class. If
-                false, return IoU over all classes.
-                Assumes that background is '0' class in input tensor
-            reduction: a method to reduce metric score over labels (default: takes the mean)
-                Available reduction methods:
-
-                - elementwise_mean: takes the mean
-                - none: pass array
-                - sum: add elements
+            ignore_index: optional int specifying a target class to ignore. If given, this class index does not
+                contribute to the returned score, regardless of reduction method. Has no effect if given an int that is
+                not in the range [0, num_classes-1], where num_classes is either given or derived from `y_pred` and
+                `y_true`. By default, no index is ignored, and all classes are used.
+            absent_score: score to use for an individual class, if no instances of the class index were present in
+                `y_pred` AND no instances of the class index were present in `y_true`. For example, if we have 3
+                classes, [0, 0] for `y_pred`, and [0, 2] for `y_true`, then class 1 would be assigned the
+                `absent_score`. Default is 0.0.
+            num_classes: Optionally specify the number of classes
+            reduction: a method to reduce metric score over labels.
+
+                - ``'elementwise_mean'``: takes the mean (default)
+                - ``'sum'``: takes the sum
+                - ``'none'``: no reduction will be applied
         """
         super().__init__(name="iou")
-        self.remove_bg = remove_bg
+        self.ignore_index = ignore_index
+        self.absent_score = absent_score
+        self.num_classes = num_classes
         self.reduction = reduction
 
     def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor, sample_weight: Optional[torch.Tensor] = None):
         """
         Actual metric calculation.
         """
-        return iou(y_pred, y_true, remove_bg=self.remove_bg, reduction=self.reduction)
+        return iou(
+            pred=y_pred,
+            target=y_true,
+            ignore_index=self.ignore_index,
+            absent_score=self.absent_score,
+            num_classes=self.num_classes,
+            reduction=self.reduction,
+        )
diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py
index 75c0ab358798a..3a47e7d6ad356 100644
--- a/pytorch_lightning/metrics/functional/classification.py
+++ b/pytorch_lightning/metrics/functional/classification.py
@@ -4,7 +4,7 @@
 import torch
 from torch.nn import functional as F
 
-from pytorch_lightning.metrics.functional.reduction import reduce, class_reduce
+from pytorch_lightning.metrics.functional.reduction import class_reduce, reduce
 from pytorch_lightning.utilities import FLOAT16_EPSILON, rank_zero_warn
 
 
@@ -921,12 +921,11 @@ def dice_score(
         bg: whether to also compute dice for the background
         nan_score: score to return, if a NaN occurs during computation
         no_fg_score: score to return, if no foreground pixel was found in target
-        reduction: a method to reduce metric score over labels (default: takes the mean)
-            Available reduction methods:
+        reduction: a method to reduce metric score over labels.
 
-            - elementwise_mean: takes the mean
-            - none: pass array
-            - sum: add elements
+            - ``'elementwise_mean'``: takes the mean (default)
+            - ``'sum'``: takes the sum
+            - ``'none'``: no reduction will be applied
 
     Return:
         Tensor containing dice score
@@ -963,9 +962,10 @@ def dice_score(
 def iou(
     pred: torch.Tensor,
     target: torch.Tensor,
+    ignore_index: Optional[int] = None,
+    absent_score: float = 0.0,
     num_classes: Optional[int] = None,
-    remove_bg: bool = False,
-    reduction: str = 'elementwise_mean'
+    reduction: str = 'elementwise_mean',
 ) -> torch.Tensor:
     """
     Intersection over union, or Jaccard index calculation.
@@ -973,17 +973,20 @@ def iou(
     Args:
         pred: Tensor containing predictions
         target: Tensor containing targets
+        ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute
+            to the returned score, regardless of reduction method. Has no effect if given an int that is not in the
+            range [0, num_classes-1], where num_classes is either given or derived from pred and target. By default, no
+            index is ignored, and all classes are used.
+        absent_score: score to use for an individual class, if no instances of the class index were present in
+            `pred` AND no instances of the class index were present in `target`. For example, if we have 3 classes,
+            [0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be assigned the `absent_score`. Default is
+            0.0.
         num_classes: Optionally specify the number of classes
-        remove_bg: Flag to state whether a background class has been included
-            within input parameters. If true, will remove background class. If
-            false, return IoU over all classes
-            Assumes that background is '0' class in input tensor
-        reduction: a method to reduce metric score over labels (default: takes the mean)
-            Available reduction methods:
+        reduction: a method to reduce metric score over labels.
 
-            - elementwise_mean: takes the mean
-            - none: pass array
-            - sum: add elements
+            - ``'elementwise_mean'``: takes the mean (default)
+            - ``'sum'``: takes the sum
+            - ``'none'``: no reduction will be applied
 
     Return:
         IoU score : Tensor containing single value if reduction is
@@ -998,12 +1001,39 @@ def iou(
         tensor(0.4914)
 
     """
+    num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes)
+
     tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes)
-    if remove_bg:
-        tps = tps[1:]
-        fps = fps[1:]
-        fns = fns[1:]
-    denom = fps + fns + tps
-    denom[denom == 0] = torch.tensor(FLOAT16_EPSILON).type_as(denom)
-    iou = tps / denom
-    return reduce(iou, reduction=reduction)
+
+    scores = torch.zeros(num_classes, device=pred.device, dtype=torch.float32)
+
+    for class_idx in range(num_classes):
+        if class_idx == ignore_index:
+            continue
+
+        tp = tps[class_idx]
+        fp = fps[class_idx]
+        fn = fns[class_idx]
+        sup = sups[class_idx]
+
+        # If this class is absent in the target (no support) AND absent in the pred (no true or false
+        # positives), then use the absent_score for this class.
+        if sup + tp + fp == 0:
+            scores[class_idx] = absent_score
+            continue
+
+        denom = tp + fp + fn
+        # Note that we do not need to worry about division-by-zero here since we know (sup + tp + fp != 0) from above,
+        # which means ((tp+fn) + tp + fp != 0), which means (2tp + fp + fn != 0). Since all vars are non-negative, we
+        # can conclude (tp + fp + fn > 0), meaning the denominator is non-zero for each class.
+        score = tp.to(torch.float) / denom
+        scores[class_idx] = score
+
+    # Remove the ignored class index from the scores.
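+    # (The loop above skipped this index, so its zero placeholder in `scores` must not leak
+    # into the 'elementwise_mean' or 'sum' reductions.)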
+    if ignore_index is not None and ignore_index >= 0 and ignore_index < num_classes:
+        scores = torch.cat([
+            scores[:ignore_index],
+            scores[ignore_index + 1:],
+        ])
+
+    return reduce(scores, reduction=reduction)
diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py
index f8269384b3477..7466c5c4fe48e 100644
--- a/tests/metrics/functional/test_classification.py
+++ b/tests/metrics/functional/test_classification.py
@@ -4,6 +4,7 @@
 import torch
 from sklearn.metrics import (
     accuracy_score as sk_accuracy,
+    jaccard_score as sk_jaccard_score,
     precision_score as sk_precision,
     recall_score as sk_recall,
     f1_score as sk_f1_score,
@@ -37,6 +38,7 @@
 
 @pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
     pytest.param(sk_accuracy, accuracy, id='accuracy'),
+    pytest.param(partial(sk_jaccard_score, average='macro'), iou, id='iou'),
     pytest.param(partial(sk_precision, average='micro'), precision, id='precision'),
     pytest.param(partial(sk_recall, average='micro'), recall, id='recall'),
     pytest.param(partial(sk_f1_score, average='micro'), f1_score, id='f1_score'),
@@ -360,22 +362,95 @@ def test_dice_score(pred, target, expected):
     assert score == expected
 
 
-@pytest.mark.parametrize(['half_ones', 'reduction', 'remove_bg', 'expected'], [
-    pytest.param(False, 'none', False, torch.Tensor([1, 1, 1])),
-    pytest.param(False, 'elementwise_mean', False, torch.Tensor([1])),
-    pytest.param(False, 'none', True, torch.Tensor([1, 1])),
-    pytest.param(True, 'none', False, torch.Tensor([0.5, 0.5, 0.5])),
-    pytest.param(True, 'elementwise_mean', False, torch.Tensor([0.5])),
-    pytest.param(True, 'none', True, torch.Tensor([0.5, 0.5])),
+@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [
+    pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])),
+    pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])),
+    pytest.param(False, 'none', 0, torch.Tensor([1, 1])),
+    pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])),
+    pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])),
+    pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])),
 ])
-def test_iou(half_ones, reduction, remove_bg, expected):
+def test_iou(half_ones, reduction, ignore_index, expected):
     pred = (torch.arange(120) % 3).view(-1, 1)
     target = (torch.arange(120) % 3).view(-1, 1)
     if half_ones:
         pred[:60] = 1
-    iou_val = iou(pred, target, remove_bg=remove_bg, reduction=reduction)
+    iou_val = iou(
+        pred=pred,
+        target=target,
+        ignore_index=ignore_index,
+        reduction=reduction,
+    )
     assert torch.allclose(iou_val, expected, atol=1e-9)
 
 
+# TODO: When the jaccard_score of the sklearn version we use accepts `zero_division` (see
+# https://github.com/scikit-learn/scikit-learn/pull/17866), consider adding a test here against our
+# `absent_score`.
+@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'absent_score', 'num_classes', 'expected'], [
+    # Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid
+    # scores the function can return ([0., 1.] range, inclusive).
+    # 2 classes, class 0 is correct everywhere, class 1 is absent.
+    pytest.param([0], [0], None, -1., 2, [1., -1.]),
+    pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]),
+    # absent_score is not applied when class 0 is present and is the only class.
+    pytest.param([0], [0], None, -1., 1, [1.]),
+    # 2 classes, class 1 is correct everywhere, class 0 is absent.
+    pytest.param([1], [1], None, -1., 2, [-1., 1.]),
+    pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]),
+    # When index 0 is ignored, class 0 does not get a score (not even the absent_score).
+    pytest.param([1], [1], 0, -1., 2, [1.0]),
+    # 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score.
+    pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]),
+    pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]),
+    # 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score.
+    pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]),
+    pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]),
+    # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class
+    # 2 is absent.
+    pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]),
+    # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class
+    # 2 is absent.
+    pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]),
+    # Sanity checks with absent_score of 1.0.
+    pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]),
+    pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]),
+])
+def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected):
+    iou_val = iou(
+        pred=torch.tensor(pred),
+        target=torch.tensor(target),
+        ignore_index=ignore_index,
+        absent_score=absent_score,
+        num_classes=num_classes,
+        reduction='none',
+    )
+    assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))
+
+
+@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], [
+    # Ignoring an index outside of [0, num_classes-1] should have no effect.
+    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]),
+    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]),
+    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]),
+    # Ignoring a valid index drops only that index from the result.
+    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]),
+    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]),
+    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]),
+    # When reducing to mean or sum, the ignored index does not contribute to the output.
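+    # (With class 0 ignored, the remaining per-class scores are [1/2, 2/3], so the mean is 7/12 and the sum is 7/6.)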
+    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]),
+    pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]),
+])
+def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected):
+    iou_val = iou(
+        pred=torch.tensor(pred),
+        target=torch.tensor(target),
+        ignore_index=ignore_index,
+        num_classes=num_classes,
+        reduction=reduction,
+    )
+    assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))
+
+
 # example data taken from
 # https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
diff --git a/tests/metrics/test_classification.py b/tests/metrics/test_classification.py
index 1be023b2af164..0559f360e10d4 100644
--- a/tests/metrics/test_classification.py
+++ b/tests/metrics/test_classification.py
@@ -226,9 +226,9 @@ def test_dice_coefficient(include_background):
     assert isinstance(dice, torch.Tensor)
 
 
-@pytest.mark.parametrize('remove_bg', [True, False])
-def test_iou(remove_bg):
-    iou = IoU(remove_bg=remove_bg)
+@pytest.mark.parametrize('ignore_index', [0, 1, None])
+def test_iou(ignore_index):
+    iou = IoU(ignore_index=ignore_index)
     assert iou.name == 'iou'
 
     score = iou(torch.randint(0, 1, (10, 25, 25)),
diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py
index b1716485ed692..3c2f5b2fed9aa 100755
--- a/tests/trainer/test_trainer_tricks.py
+++ b/tests/trainer/test_trainer_tricks.py
@@ -63,7 +63,6 @@ def test_overfit_batch_limits(tmpdir):
     full_train_samples = len(train_loader)
     num_train_samples = int(0.11 * full_train_samples)
 
-    # ------------------------------------------------------
     # set VAL and Test loaders
     # ------------------------------------------------------
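For reference, a minimal usage sketch of the API introduced by this patch (not part of the diff; the import paths come from the files modified above, and the expected values mirror the new tests):

```python
import torch

from pytorch_lightning.metrics.classification import IoU
from pytorch_lightning.metrics.functional.classification import iou

pred = torch.tensor([0, 1, 1, 2, 2])
target = torch.tensor([0, 1, 2, 2, 2])

# Per-class scores: class 0 is perfect (1), class 1 matches 1 of 2 predictions (1/2),
# and class 2 overlaps on 2 of its 3-element union (2/3).
print(iou(pred, target, reduction='none'))                  # tensor([1.0000, 0.5000, 0.6667])

# Ignoring a valid class index drops it from the result entirely.
print(iou(pred, target, ignore_index=0, reduction='none'))  # tensor([0.5000, 0.6667])

# A class absent from both pred and target receives absent_score instead of an undefined 0/0.
print(iou(torch.tensor([0, 2]), torch.tensor([0, 2]),
          absent_score=-1., num_classes=3, reduction='none'))  # tensor([ 1., -1.,  1.])

# The class-based metric forwards the same arguments to the functional version.
print(IoU(ignore_index=0)(pred, target))                    # tensor(0.5833), the mean of [1/2, 2/3]
```

Note the design choice: `ignore_index` removes the class from the output rather than scoring it 0, so the ignored class can never skew the `'elementwise_mean'` or `'sum'` reductions.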