From 616fe82c7d793c2ca58b5eff3d152b4a10833a24 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 7 Oct 2020 20:26:30 +0200 Subject: [PATCH] latest restore func metrics (#3949) * latest restore * latest restore --- .../metrics/functional/__init__.py | 3 + .../metrics/functional/classification.py | 360 +++++++++++------- pytorch_lightning/metrics/functional/nlp.py | 23 +- .../metrics/functional/reduction.py | 41 ++ .../metrics/functional/regression.py | 157 ++++---- .../metrics/functional/self_supervised.py | 46 +++ .../metrics/functional/test_classification.py | 226 +++++++++-- tests/metrics/functional/test_reduction.py | 17 +- tests/metrics/functional/test_regression.py | 71 +++- .../functional/test_self_supervised.py | 35 ++ 10 files changed, 712 insertions(+), 267 deletions(-) create mode 100644 pytorch_lightning/metrics/functional/self_supervised.py create mode 100644 tests/metrics/functional/test_self_supervised.py diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 926803b5045e1..02928c803f19d 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -29,3 +29,6 @@ rmsle, ssim ) +from pytorch_lightning.metrics.functional.self_supervised import ( + embedding_similarity +) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index b6acf05a6401b..6a6189df816e0 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -1,12 +1,11 @@ -from collections import Sequence from functools import wraps -from typing import Optional, Tuple, Callable +from typing import Callable, Optional, Sequence, Tuple import torch +from pytorch_lightning.metrics.functional.reduction import class_reduce, reduce from torch.nn import functional as F -from pytorch_lightning.metrics.functional.reduction import reduce -from pytorch_lightning.utilities import rank_zero_warn, FLOAT16_EPSILON +from pytorch_lightning.utilities import rank_zero_warn def to_onehot( @@ -88,8 +87,10 @@ def get_num_classes( if num_classes is None: num_classes = num_all_classes elif num_classes != num_all_classes: - rank_zero_warn(f'You have set {num_classes} number of classes if different from' - f' predicted ({num_pred_classes}) and target ({num_target_classes}) number of classes') + rank_zero_warn(f'You have set {num_classes} number of classes which is' + f' different from predicted ({num_pred_classes}) and' + f' target ({num_target_classes}) number of classes', + RuntimeWarning) return num_classes @@ -138,10 +139,10 @@ def stat_scores_multiple_classes( target: torch.Tensor, num_classes: Optional[int] = None, argmax_dim: int = 1, + reduction: str = 'none', ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ - Calls the stat_scores function iteratively for all classes, thus - calculating the number of true postive, false postive, true negative + Calculates the number of true positive, false positive, true negative and false negative for each class Args: @@ -150,6 +151,12 @@ def stat_scores_multiple_classes( num_classes: number of classes if known argmax_dim: if pred is a tensor of probabilities, this indicates the axis the argmax transformation will be applied over + reduction: a method to reduce metric score over labels (default: none) + Available reduction methods: + + - elementwise_mean: takes the mean + - none: pass array + - sum: add elements Return: True Positive, False Positive, True Negative, False Negative, Support @@ -169,29 +176,73 @@ def stat_scores_multiple_classes( tensor([1., 0., 0., 0.]) >>> sups tensor([1., 0., 1., 1.]) - """ - num_classes = get_num_classes(pred=pred, target=target, - num_classes=num_classes) + """ if pred.ndim == target.ndim + 1: pred = to_categorical(pred, argmax_dim=argmax_dim) - tps = torch.zeros((num_classes,), device=pred.device) - fps = torch.zeros((num_classes,), device=pred.device) - tns = torch.zeros((num_classes,), device=pred.device) - fns = torch.zeros((num_classes,), device=pred.device) - sups = torch.zeros((num_classes,), device=pred.device) - for c in range(num_classes): - tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred, target=target, class_index=c) + num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) - return tps, fps, tns, fns, sups + if pred.dtype != torch.bool: + pred = pred.clamp_max(max=num_classes) + if target.dtype != torch.bool: + target = target.clamp_max(max=num_classes) + + possible_reductions = ('none', 'sum', 'elementwise_mean') + if reduction not in possible_reductions: + raise ValueError("reduction type %s not supported" % reduction) + + if reduction == 'none': + pred = pred.view((-1, )).long() + target = target.view((-1, )).long() + + tps = torch.zeros((num_classes + 1,), device=pred.device) + fps = torch.zeros((num_classes + 1,), device=pred.device) + tns = torch.zeros((num_classes + 1,), device=pred.device) + fns = torch.zeros((num_classes + 1,), device=pred.device) + sups = torch.zeros((num_classes + 1,), device=pred.device) + + match_true = (pred == target).float() + match_false = 1 - match_true + + tps.scatter_add_(0, pred, match_true) + fps.scatter_add_(0, pred, match_false) + fns.scatter_add_(0, target, match_false) + tns = pred.size(0) - (tps + fps + fns) + sups.scatter_add_(0, target, torch.ones_like(match_true)) + + tps = tps[:num_classes] + fps = fps[:num_classes] + tns = tns[:num_classes] + fns = fns[:num_classes] + sups = sups[:num_classes] + + elif reduction == 'sum' or reduction == 'elementwise_mean': + count_match_true = (pred == target).sum().float() + oob_tp, oob_fp, oob_tn, oob_fn, oob_sup = stat_scores(pred, target, num_classes, argmax_dim) + + tps = count_match_true - oob_tp + fps = pred.nelement() - count_match_true - oob_fp + fns = pred.nelement() - count_match_true - oob_fn + tns = pred.nelement() * (num_classes + 1) - (tps + fps + fns + oob_tn) + sups = pred.nelement() - oob_sup.float() + + if reduction == 'elementwise_mean': + tps /= num_classes + fps /= num_classes + fns /= num_classes + tns /= num_classes + sups /= num_classes + + return tps.float(), fps.float(), tns.float(), fns.float(), sups.float() def accuracy( pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None, - reduction='elementwise_mean', + class_reduction: str = 'micro', + return_state: bool = False ) -> torch.Tensor: """ Computes the accuracy classification score @@ -200,15 +251,16 @@ def accuracy( pred: predicted labels target: ground truth labels num_classes: number of classes - reduction: a method for reducing accuracies over labels (default: takes the mean) - Available reduction methods: - - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements - + class_reduction: method to reduce metric score over labels + + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. + - ``'none'``: returns calculated metric per class + return_state: returns a internal state that can be ddp reduced + before doing the final calculation Return: - A Tensor with the classification score. + A Tensor with the accuracy score. Example: @@ -220,20 +272,26 @@ def accuracy( """ tps, fps, tns, fns, sups = stat_scores_multiple_classes( pred=pred, target=target, num_classes=num_classes) - - if not (target > 0).any() and num_classes is None: - raise RuntimeError("cannot infer num_classes when target is all zero") - - if reduction in ('elementwise_mean', 'sum'): - return reduce(sum(tps) / sum(sups), reduction=reduction) - if reduction == 'none': - return reduce(tps / sups, reduction=reduction) + if return_state: + return {'tps': tps, 'sups': sups} + return class_reduce(tps, sups, sups, class_reduction=class_reduction) + + +def _confmat_normalize(cm): + """ Normalization function for confusion matrix """ + cm = cm / cm.sum(-1, keepdim=True) + nan_elements = cm[torch.isnan(cm)].nelement() + if nan_elements != 0: + cm[torch.isnan(cm)] = 0 + rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.') + return cm def confusion_matrix( pred: torch.Tensor, target: torch.Tensor, normalize: bool = False, + num_classes: Optional[int] = None ) -> torch.Tensor: """ Computes the confusion matrix C where each entry C_{i,j} is the number of observations @@ -243,6 +301,7 @@ def confusion_matrix( pred: estimated targets target: ground truth labels normalize: normalizes confusion matrix + num_classes: number of classes Return: Tensor, confusion matrix C [num_classes, num_classes ] @@ -257,15 +316,15 @@ def confusion_matrix( [0., 0., 1., 0.], [0., 0., 0., 1.]]) """ - num_classes = get_num_classes(pred, target, None) + num_classes = get_num_classes(pred, target, num_classes) - unique_labels = target.view(-1) * num_classes + pred.view(-1) + unique_labels = (target.view(-1) * num_classes + pred.view(-1)).to(torch.int) bins = torch.bincount(unique_labels, minlength=num_classes ** 2) cm = bins.reshape(num_classes, num_classes).squeeze().float() if normalize: - cm = cm / cm.sum(-1) + cm = _confmat_normalize(cm) return cm @@ -274,7 +333,9 @@ def precision_recall( pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None, - reduction: str = 'elementwise_mean', + class_reduction: str = 'micro', + return_support: bool = False, + return_state: bool = False ) -> Tuple[torch.Tensor, torch.Tensor]: """ Computes precision and recall for different thresholds @@ -283,12 +344,16 @@ def precision_recall( pred: estimated probabilities target: ground-truth labels num_classes: number of classes - reduction: method for reducing precision-recall values (default: takes the mean) - Available reduction methods: + class_reduction: method to reduce metric score over labels - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. + - ``'none'``: returns calculated metric per class + + return_support: returns the support for each class, need for fbeta/f1 calculations + return_state: returns a internal state that can be ddp reduced + before doing the final calculation Return: Tensor with precision and recall @@ -296,26 +361,19 @@ def precision_recall( Example: >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) - >>> precision_recall(x, y) - (tensor(0.7500), tensor(0.6250)) + >>> y = torch.tensor([0, 2, 2, 2]) + >>> precision_recall(x, y, class_reduction='macro') + (tensor(0.5000), tensor(0.3333)) """ tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target, num_classes=num_classes) - tps = tps.to(torch.float) - fps = fps.to(torch.float) - fns = fns.to(torch.float) - - precision = tps / (tps + fps) - recall = tps / (tps + fns) - - # solution by justus, see https://discuss.pytorch.org/t/how-to-set-nan-in-tensor-to-0/3918/9 - precision[precision != precision] = 0 - recall[recall != recall] = 0 - - precision = reduce(precision, reduction=reduction) - recall = reduce(recall, reduction=reduction) + precision = class_reduce(tps, tps + fps, sups, class_reduction=class_reduction) + recall = class_reduce(tps, tps + fns, sups, class_reduction=class_reduction) + if return_state: + return {'tps': tps, 'fps': fps, 'fns': fns, 'sups': sups} + if return_support: + return precision, recall, sups return precision, recall @@ -323,7 +381,7 @@ def precision( pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None, - reduction: str = 'elementwise_mean', + class_reduction: str = 'micro', ) -> torch.Tensor: """ Computes precision score. @@ -332,12 +390,12 @@ def precision( pred: estimated probabilities target: ground-truth labels num_classes: number of classes - reduction: method for reducing precision values (default: takes the mean) - Available reduction methods: + class_reduction: method to reduce metric score over labels - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. + - ``'none'``: returns calculated metric per class Return: Tensor with precision. @@ -351,14 +409,14 @@ def precision( """ return precision_recall(pred=pred, target=target, - num_classes=num_classes, reduction=reduction)[0] + num_classes=num_classes, class_reduction=class_reduction)[0] def recall( pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None, - reduction: str = 'elementwise_mean', + class_reduction: str = 'micro', ) -> torch.Tensor: """ Computes recall score. @@ -367,12 +425,12 @@ def recall( pred: estimated probabilities target: ground-truth labels num_classes: number of classes - reduction: method for reducing recall values (default: takes the mean) - Available reduction methods: + class_reduction: method to reduce metric score over labels - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. + - ``'none'``: returns calculated metric per class Return: Tensor with recall. @@ -382,10 +440,10 @@ def recall( >>> x = torch.tensor([0, 1, 2, 3]) >>> y = torch.tensor([0, 1, 2, 2]) >>> recall(x, y) - tensor(0.6250) + tensor(0.7500) """ return precision_recall(pred=pred, target=target, - num_classes=num_classes, reduction=reduction)[1] + num_classes=num_classes, class_reduction=class_reduction)[1] def fbeta_score( @@ -393,7 +451,7 @@ def fbeta_score( target: torch.Tensor, beta: float, num_classes: Optional[int] = None, - reduction: str = 'elementwise_mean', + class_reduction: str = 'micro', ) -> torch.Tensor: """ Computes the F-beta score which is a weighted harmonic mean of precision and recall. @@ -408,12 +466,12 @@ def fbeta_score( beta = 0: only precision beta -> inf: only recall num_classes: number of classes - reduction: method for reducing F-score (default: takes the mean) - Available reduction methods: + class_reduction: method to reduce metric score over labels - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements. + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. + - ``'none'``: returns calculated metric per class Return: Tensor with the value of F-score. It is a value between 0-1. @@ -423,27 +481,27 @@ def fbeta_score( >>> x = torch.tensor([0, 1, 2, 3]) >>> y = torch.tensor([0, 1, 2, 2]) >>> fbeta_score(x, y, 0.2) - tensor(0.7407) + tensor(0.7500) """ - prec, rec = precision_recall(pred=pred, target=target, - num_classes=num_classes, - reduction='none') - - nom = (1 + beta ** 2) * prec * rec + # We need to differentiate at which point to do class reduction + intermidiate_reduction = 'none' if class_reduction != "micro" else 'micro' + + prec, rec, sups = precision_recall(pred=pred, target=target, + num_classes=num_classes, + class_reduction=intermidiate_reduction, + return_support=True) + num = (1 + beta ** 2) * prec * rec denom = ((beta ** 2) * prec + rec) - fbeta = nom / denom - - # drop NaN after zero division - fbeta[fbeta != fbeta] = 0 - - return reduce(fbeta, reduction=reduction) + if intermidiate_reduction == 'micro': + return torch.sum(num) / torch.sum(denom) + return class_reduce(num, denom, sups, class_reduction=class_reduction) def f1_score( pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None, - reduction='elementwise_mean', + class_reduction: str = 'micro', ) -> torch.Tensor: """ Computes the F1-score (a.k.a F-measure), which is the harmonic mean of the precision and recall. @@ -453,12 +511,12 @@ def f1_score( pred: estimated probabilities target: ground-truth labels num_classes: number of classes - reduction: method for reducing F1-score (default: takes the mean) - Available reduction methods: + class_reduction: method to reduce metric score over labels - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements. + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. + - ``'none'``: returns calculated metric per class Return: Tensor containing F1-score @@ -468,10 +526,10 @@ def f1_score( >>> x = torch.tensor([0, 1, 2, 3]) >>> y = torch.tensor([0, 1, 2, 2]) >>> f1_score(x, y) - tensor(0.6667) + tensor(0.7500) """ return fbeta_score(pred=pred, target=target, beta=1., - num_classes=num_classes, reduction=reduction) + num_classes=num_classes, class_reduction=class_reduction) def _binary_clf_curve( @@ -539,12 +597,12 @@ def roc( Example: >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) + >>> y = torch.tensor([0, 1, 1, 1]) >>> fpr, tpr, thresholds = roc(x, y) >>> fpr - tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]) + tensor([0., 0., 0., 0., 1.]) >>> tpr - tensor([0., 0., 0., 1., 1.]) + tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000]) >>> thresholds tensor([4, 3, 2, 1, 0]) @@ -637,12 +695,12 @@ def precision_recall_curve( Example: >>> pred = torch.tensor([0, 1, 2, 3]) - >>> target = torch.tensor([0, 1, 2, 2]) + >>> target = torch.tensor([0, 1, 1, 0]) >>> precision, recall, thresholds = precision_recall_curve(pred, target) >>> precision - tensor([0.3333, 0.0000, 0.0000, 1.0000]) + tensor([0.6667, 0.5000, 0.0000, 1.0000]) >>> recall - tensor([1., 0., 0., 0.]) + tensor([1.0000, 0.5000, 0.0000, 0.0000]) >>> thresholds tensor([1, 2, 3]) @@ -813,10 +871,14 @@ def auroc( Example: >>> x = torch.tensor([0, 1, 2, 3]) - >>> y = torch.tensor([0, 1, 2, 2]) + >>> y = torch.tensor([0, 1, 1, 0]) >>> auroc(x, y) - tensor(0.3333) + tensor(0.5000) """ + if any(target > 1): + raise ValueError('AUROC metric is meant for binary classification, but' + ' target tensor contains value different from 0 and 1.' + ' Multiclass is currently not supported.') @auc_decorator(reorder=True) def _auroc(pred, target, sample_weight, pos_label): @@ -876,12 +938,11 @@ def dice_score( bg: whether to also compute dice for the background nan_score: score to return, if a NaN occurs during computation no_fg_score: score to return, if no foreground pixel was found in target - reduction: a method for reducing accuracies over labels (default: takes the mean) - Available reduction methods: + reduction: a method to reduce metric score over labels. - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied Return: Tensor containing dice score @@ -918,9 +979,10 @@ def dice_score( def iou( pred: torch.Tensor, target: torch.Tensor, + ignore_index: Optional[int] = None, + absent_score: float = 0.0, num_classes: Optional[int] = None, - remove_bg: bool = False, - reduction: str = 'elementwise_mean' + reduction: str = 'elementwise_mean', ) -> torch.Tensor: """ Intersection over union, or Jaccard index calculation. @@ -928,17 +990,20 @@ def iou( Args: pred: Tensor containing predictions target: Tensor containing targets + ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. Has no effect if given an int that is not in the + range [0, num_classes-1], where num_classes is either given or derived from pred and target. By default, no + index is ignored, and all classes are used. + absent_score: score to use for an individual class, if no instances of the class index were present in + `pred` AND no instances of the class index were present in `target`. For example, if we have 3 classes, + [0, 0] for `pred`, and [0, 2] for `target`, then class 1 would be assigned the `absent_score`. Default is + 0.0. num_classes: Optionally specify the number of classes - remove_bg: Flag to state whether a background class has been included - within input parameters. If true, will remove background class. If - false, return IoU over all classes - Assumes that background is '0' class in input tensor - reduction: a method for reducing IoU over labels (default: takes the mean) - Available reduction methods: + reduction: a method to reduce metric score over labels. - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied Return: IoU score : Tensor containing single value if reduction is @@ -953,12 +1018,39 @@ def iou( tensor(0.4914) """ + num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) + tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes) - if remove_bg: - tps = tps[1:] - fps = fps[1:] - fns = fns[1:] - denom = fps + fns + tps - denom[denom == 0] = torch.tensor(FLOAT16_EPSILON).type_as(denom) - iou = tps / denom - return reduce(iou, reduction=reduction) + + scores = torch.zeros(num_classes, device=pred.device, dtype=torch.float32) + + for class_idx in range(num_classes): + if class_idx == ignore_index: + continue + + tp = tps[class_idx] + fp = fps[class_idx] + fn = fns[class_idx] + sup = sups[class_idx] + + # If this class is absent in the target (no support) AND absent in the pred (no true or false + # positives), then use the absent_score for this class. + if sup + tp + fp == 0: + scores[class_idx] = absent_score + continue + + denom = tp + fp + fn + # Note that we do not need to worry about division-by-zero here since we know (sup + tp + fp != 0) from above, + # which means ((tp+fn) + tp + fp != 0), which means (2tp + fp + fn != 0). Since all vars are non-negative, we + # can conclude (tp + fp + fn > 0), meaning the denominator is non-zero for each class. + score = tp.to(torch.float) / denom + scores[class_idx] = score + + # Remove the ignored class index from the scores. + if ignore_index is not None and ignore_index >= 0 and ignore_index < num_classes: + scores = torch.cat([ + scores[:ignore_index], + scores[ignore_index + 1:], + ]) + + return reduce(scores, reduction=reduction) diff --git a/pytorch_lightning/metrics/functional/nlp.py b/pytorch_lightning/metrics/functional/nlp.py index 22645bb5494b6..85c33642704cd 100644 --- a/pytorch_lightning/metrics/functional/nlp.py +++ b/pytorch_lightning/metrics/functional/nlp.py @@ -4,13 +4,14 @@ # Date: 2020-07-18 # Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score from collections import Counter -from typing import Sequence, List +from typing import List, Sequence import torch def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter: - """Counting how many times each word appears in a given text with ngram + """ + Counting how many times each word appears in a given text with ngram Args: ngram_input_list: A list of translated text or reference texts @@ -24,16 +25,20 @@ def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter: for i in range(1, n_gram + 1): for j in range(len(ngram_input_list) - i + 1): - ngram_key = tuple(ngram_input_list[j : i + j]) + ngram_key = tuple(ngram_input_list[j:(i + j)]) ngram_counter[ngram_key] += 1 return ngram_counter def bleu_score( - translate_corpus: Sequence[str], reference_corpus: Sequence[str], n_gram: int = 4, smooth: bool = False + translate_corpus: Sequence[str], + reference_corpus: Sequence[str], + n_gram: int = 4, + smooth: bool = False ) -> torch.Tensor: - """Calculate BLEU score of machine translated text with one or more references. + """ + Calculate BLEU score of machine translated text with one or more references Args: translate_corpus: An iterable of machine translated corpus @@ -42,7 +47,7 @@ def bleu_score( smooth: Whether or not to apply smoothing – Lin et al. 2004 Return: - A Tensor with BLEU Score + Tensor with BLEU Score Example: @@ -50,6 +55,7 @@ def bleu_score( >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] >>> bleu_score(translate_corpus, reference_corpus) tensor(0.7598) + """ assert len(translate_corpus) == len(reference_corpus) @@ -58,6 +64,7 @@ def bleu_score( precision_scores = torch.zeros(n_gram) c = 0.0 r = 0.0 + for (translation, references) in zip(translate_corpus, reference_corpus): c += len(translation) ref_len_list = [len(ref) for ref in references] @@ -65,10 +72,12 @@ def bleu_score( r += ref_len_list[ref_len_diff.index(min(ref_len_diff))] translation_counter = _count_ngram(translation, n_gram) reference_counter = Counter() + for ref in references: reference_counter |= _count_ngram(ref, n_gram) ngram_counter_clip = translation_counter & reference_counter + for counter_clip in ngram_counter_clip: numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip] @@ -77,6 +86,7 @@ def bleu_score( trans_len = torch.tensor(c) ref_len = torch.tensor(r) + if min(numerator) == 0.0: return torch.tensor(0.0) @@ -84,6 +94,7 @@ def bleu_score( precision_scores = torch.add(numerator, torch.ones(n_gram)) / torch.add(denominator, torch.ones(n_gram)) else: precision_scores = numerator / denominator + log_precision_scores = torch.tensor([1.0 / n_gram] * n_gram) * torch.log(precision_scores) geometric_mean = torch.exp(torch.sum(log_precision_scores)) brevity_penalty = torch.tensor(1.0) if c > r else torch.exp(1 - (ref_len / trans_len)) diff --git a/pytorch_lightning/metrics/functional/reduction.py b/pytorch_lightning/metrics/functional/reduction.py index b9be8ca7daeb5..d0618abd65b96 100644 --- a/pytorch_lightning/metrics/functional/reduction.py +++ b/pytorch_lightning/metrics/functional/reduction.py @@ -22,3 +22,44 @@ def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: if reduction == 'sum': return torch.sum(to_reduce) raise ValueError('Reduction parameter unknown.') + + +def class_reduce(num: torch.Tensor, + denom: torch.Tensor, + weights: torch.Tensor, + class_reduction: str = 'none') -> torch.Tensor: + """ + Function used to reduce classification metrics of the form `num / denom * weights`. + For example for calculating standard accuracy the num would be number of + true positives per class, denom would be the support per class, and weights + would be a tensor of 1s + + Args: + num: numerator tensor + decom: denominator tensor + weights: weights for each class + class_reduction: reduction method for multiclass problems + + - ``'micro'``: calculate metrics globally (default) + - ``'macro'``: calculate metrics for each label, and find their unweighted mean. + - ``'weighted'``: calculate metrics for each label, and find their weighted mean. + - ``'none'``: returns calculated metric per class + + """ + valid_reduction = ('micro', 'macro', 'weighted', 'none') + if class_reduction == 'micro': + return torch.sum(num) / torch.sum(denom) + + # For the rest we need to take care of instances where the denom can be 0 + # for some classes which will produce nans for that class + fraction = num / denom + fraction[fraction != fraction] = 0 + if class_reduction == 'macro': + return torch.mean(fraction) + elif class_reduction == 'weighted': + return torch.sum(fraction * (weights / torch.sum(weights))) + elif class_reduction == 'none': + return fraction + + raise ValueError(f'Reduction parameter {class_reduction} unknown.' + f' Choose between one of these: {valid_reduction}') diff --git a/pytorch_lightning/metrics/functional/regression.py b/pytorch_lightning/metrics/functional/regression.py index 6ad5ee6cfbec9..b7e360b9c196d 100644 --- a/pytorch_lightning/metrics/functional/regression.py +++ b/pytorch_lightning/metrics/functional/regression.py @@ -1,15 +1,15 @@ from typing import Sequence import torch -from torch.nn import functional as F - from pytorch_lightning.metrics.functional.reduction import reduce +from torch.nn import functional as F def mse( pred: torch.Tensor, target: torch.Tensor, - reduction: str = 'elementwise_mean' + reduction: str = 'elementwise_mean', + return_state: bool = False ) -> torch.Tensor: """ Computes mean squared error @@ -17,12 +17,13 @@ def mse( Args: pred: estimated labels target: ground truth labels - reduction: method for reducing mse (default: takes the mean) - Available reduction methods: + reduction: a method to reduce metric score over labels. - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + return_state: returns a internal state that can be ddp reduced + before doing the final calculation Return: Tensor with MSE @@ -36,6 +37,8 @@ def mse( """ mse = F.mse_loss(pred, target, reduction='none') + if return_state: + return {'squared_error': mse.sum(), 'n_observations': torch.tensor(mse.numel())} mse = reduce(mse, reduction=reduction) return mse @@ -43,7 +46,8 @@ def mse( def rmse( pred: torch.Tensor, target: torch.Tensor, - reduction: str = 'elementwise_mean' + reduction: str = 'elementwise_mean', + return_state: bool = False ) -> torch.Tensor: """ Computes root mean squared error @@ -51,12 +55,13 @@ def rmse( Args: pred: estimated labels target: ground truth labels - reduction: method for reducing rmse (default: takes the mean) - Available reduction methods: + reduction: a method to reduce metric score over labels. - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + return_state: returns a internal state that can be ddp reduced + before doing the final calculation Return: Tensor with RMSE @@ -68,14 +73,18 @@ def rmse( tensor(0.5000) """ - rmse = torch.sqrt(mse(pred, target, reduction=reduction)) - return rmse + mean_squared_error = mse(pred, target, reduction=reduction) + if return_state: + return {'squared_error': mean_squared_error.sum(), + 'n_observations': torch.tensor(mean_squared_error.numel())} + return torch.sqrt(mean_squared_error) def mae( pred: torch.Tensor, target: torch.Tensor, - reduction: str = 'elementwise_mean' + reduction: str = 'elementwise_mean', + return_state: bool = False ) -> torch.Tensor: """ Computes mean absolute error @@ -83,12 +92,13 @@ def mae( Args: pred: estimated labels target: ground truth labels - reduction: method for reducing mae (default: takes the mean) - Available reduction methods: + reduction: a method to reduce metric score over labels. - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + return_state: returns a internal state that can be ddp reduced + before doing the final calculation Return: Tensor with MAE @@ -102,6 +112,8 @@ def mae( """ mae = F.l1_loss(pred, target, reduction='none') + if return_state: + return {'absolute_error': mae.sum(), 'n_observations': torch.tensor(mae.numel())} mae = reduce(mae, reduction=reduction) return mae @@ -117,12 +129,11 @@ def rmsle( Args: pred: estimated labels target: ground truth labels - reduction: method for reducing rmsle (default: takes the mean) - Available reduction methods: + reduction: a method to reduce metric score over labels. - - elementwise_mean: takes the mean - - none: pass array - - sum: add elements + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied Return: Tensor with RMSLE @@ -132,10 +143,10 @@ def rmsle( >>> x = torch.tensor([0., 1, 2, 3]) >>> y = torch.tensor([0., 1, 2, 2]) >>> rmsle(x, y) - tensor(0.0207) + tensor(0.1438) """ - rmsle = mse(torch.log(pred + 1), torch.log(target + 1), reduction=reduction) + rmsle = rmse(torch.log(pred + 1), torch.log(target + 1), reduction=reduction) return rmsle @@ -144,7 +155,8 @@ def psnr( target: torch.Tensor, data_range: float = None, base: float = 10.0, - reduction: str = 'elementwise_mean' + reduction: str = 'elementwise_mean', + return_state: bool = False ) -> torch.Tensor: """ Computes the peak signal-to-noise ratio @@ -154,12 +166,13 @@ def psnr( target: groun truth signal data_range: the range of the data. If None, it is determined from the data (max - min) base: a base of a logarithm to use (default: 10) - reduction: method for reducing psnr (default: takes the mean) - Available reduction methods: + reduction: a method to reduce metric score over labels. - - elementwise_mean: takes the mean - - none: pass array - - sum add elements + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + return_state: returns a internal state that can be ddp reduced + before doing the final calculation Return: Tensor with PSNR score @@ -172,12 +185,16 @@ def psnr( tensor(2.5527) """ - if data_range is None: - data_range = max(target.max() - target.min(), pred.max() - pred.min()) + data_range = target.max() - target.min() else: data_range = torch.tensor(float(data_range)) + if return_state: + return {'data_range': data_range, + 'sum_squared_error': F.mse_loss(pred, target, reduction='none').sum(), + 'n_obs': torch.tensor(target.numel())} + mse_score = mse(pred.view(-1), target.view(-1), reduction=reduction) psnr_base_e = 2 * torch.log(data_range) - torch.log(mse_score) psnr = psnr_base_e * (10 / torch.log(torch.tensor(base))) @@ -185,16 +202,19 @@ def psnr( def _gaussian_kernel(channel, kernel_size, sigma, device): - def gaussian(kernel_size, sigma, device): + def _gaussian(kernel_size, sigma, device): gauss = torch.arange( - start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=torch.float32, device=device + start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, + step=1, + dtype=torch.float32, + device=device ) gauss = torch.exp(-gauss.pow(2) / (2 * pow(sigma, 2))) return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size) - gaussian_kernel_x = gaussian(kernel_size[0], sigma[0], device) - gaussian_kernel_y = gaussian(kernel_size[1], sigma[1], device) - kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size) + gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], device) + gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], device) + kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) return kernel.expand(channel, 1, kernel_size[0], kernel_size[1]) @@ -213,32 +233,31 @@ def ssim( Computes Structual Similarity Index Measure Args: - pred: Estimated image - target: Ground truth image - kernel_size: Size of the gaussian kernel. Default: (11, 11) - sigma: Standard deviation of the gaussian kernel. Default: (1.5, 1.5) - reduction: A method for reducing ssim over all elements in the ``pred`` tensor. Default: ``elementwise_mean`` + pred: estimated image + target: ground truth image + kernel_size: size of the gaussian kernel (default: (11, 11)) + sigma: Standard deviation of the gaussian kernel (default: (1.5, 1.5)) + reduction: a method to reduce metric score over labels. - Available reduction methods: - - elementwise_mean: takes the mean - - none: pass away - - sum: add elements + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied data_range: Range of the image. If ``None``, it is determined from the image (max - min) k1: Parameter of SSIM. Default: 0.01 k2: Parameter of SSIM. Default: 0.03 - Returns: - A Tensor with SSIM + Return: + Tensor with SSIM score Example: >>> pred = torch.rand([16, 1, 16, 16]) - >>> target = pred * 1.25 + >>> target = pred * 0.75 >>> ssim(pred, target) - tensor(0.9520) - """ + tensor(0.9219) + """ if pred.dtype != target.dtype: raise TypeError( "Expected `pred` and `target` to have the same data type." @@ -278,16 +297,24 @@ def ssim( channel = pred.size(1) kernel = _gaussian_kernel(channel, kernel_size, sigma, device) - mu_pred = F.conv2d(pred, kernel, groups=channel) - mu_target = F.conv2d(target, kernel, groups=channel) - - mu_pred_sq = mu_pred.pow(2) - mu_target_sq = mu_target.pow(2) - mu_pred_target = mu_pred * mu_target - sigma_pred_sq = F.conv2d(pred * pred, kernel, groups=channel) - mu_pred_sq - sigma_target_sq = F.conv2d(target * target, kernel, groups=channel) - mu_target_sq - sigma_pred_target = F.conv2d(pred * target, kernel, groups=channel) - mu_pred_target + # Concatenate + # pred for mu_pred + # target for mu_target + # pred * pred for sigma_pred + # target * target for sigma_target + # pred * target for sigma_pred_target + input_list = torch.cat([pred, target, pred * pred, target * target, pred * target]) # (5 * B, C, H, W) + outputs = F.conv2d(input_list, kernel, groups=channel) + output_list = [outputs[x * pred.size(0): (x + 1) * pred.size(0)] for x in range(len(outputs))] + + mu_pred_sq = output_list[0].pow(2) + mu_target_sq = output_list[1].pow(2) + mu_pred_target = output_list[0] * output_list[1] + + sigma_pred_sq = output_list[2] - mu_pred_sq + sigma_target_sq = output_list[3] - mu_target_sq + sigma_pred_target = output_list[4] - mu_pred_target UPPER = 2 * sigma_pred_target + C2 LOWER = sigma_pred_sq + sigma_target_sq + C2 diff --git a/pytorch_lightning/metrics/functional/self_supervised.py b/pytorch_lightning/metrics/functional/self_supervised.py new file mode 100644 index 0000000000000..c8c7e83166723 --- /dev/null +++ b/pytorch_lightning/metrics/functional/self_supervised.py @@ -0,0 +1,46 @@ +import torch + + +def embedding_similarity( + batch: torch.Tensor, + similarity: str = 'cosine', + reduction: str = 'none', + zero_diagonal: bool = True +) -> torch.Tensor: + """ + Computes representation similarity + + Example: + + >>> embeddings = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.], [4., 5., 6., 7.]]) + >>> embedding_similarity(embeddings) + tensor([[0.0000, 1.0000, 0.9759], + [1.0000, 0.0000, 0.9759], + [0.9759, 0.9759, 0.0000]]) + + Args: + batch: (batch, dim) + similarity: 'dot' or 'cosine' + reduction: 'none', 'sum', 'mean' (all along dim -1) + zero_diagonal: if True, the diagonals are set to zero + + Return: + A square matrix (batch, batch) with the similarity scores between all elements + If sum or mean are used, then returns (b, 1) with the reduced value for each row + """ + if similarity == 'cosine': + norm = torch.norm(batch, p=2, dim=1) + batch = batch / norm.unsqueeze(1) + + sqr_mtx = batch.mm(batch.transpose(1, 0)) + + if zero_diagonal: + sqr_mtx = sqr_mtx.fill_diagonal_(0) + + if reduction == 'mean': + sqr_mtx = sqr_mtx.mean(dim=-1) + + if reduction == 'sum': + sqr_mtx = sqr_mtx.sum(dim=-1) + + return sqr_mtx diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index c9e1f0892f6e7..9afdf84fa8770 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -4,11 +4,15 @@ import torch from sklearn.metrics import ( accuracy_score as sk_accuracy, + jaccard_score as sk_jaccard_score, precision_score as sk_precision, recall_score as sk_recall, f1_score as sk_f1_score, fbeta_score as sk_fbeta_score, confusion_matrix as sk_confusion_matrix, + roc_curve as sk_roc_curve, + roc_auc_score as sk_roc_auc_score, + precision_recall_curve as sk_precision_recall_curve ) from pytorch_lightning import seed_everything @@ -35,28 +39,65 @@ ) -@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ - pytest.param(sk_accuracy, accuracy, id='accuracy'), - pytest.param(partial(sk_precision, average='macro'), precision, id='precision'), - pytest.param(partial(sk_recall, average='macro'), recall, id='recall'), - pytest.param(partial(sk_f1_score, average='macro'), f1_score, id='f1_score'), - pytest.param(partial(sk_fbeta_score, average='macro', beta=2), partial(fbeta_score, beta=2), id='fbeta_score'), - pytest.param(sk_confusion_matrix, confusion_matrix, id='confusion_matrix') +@pytest.mark.parametrize(['sklearn_metric', 'torch_metric', 'only_binary'], [ + pytest.param(sk_accuracy, accuracy, False, id='accuracy'), + pytest.param(partial(sk_jaccard_score, average='macro'), iou, False, id='iou'), + pytest.param(partial(sk_precision, average='micro'), precision, False, id='precision'), + pytest.param(partial(sk_recall, average='micro'), recall, False, id='recall'), + pytest.param(partial(sk_f1_score, average='micro'), f1_score, False, id='f1_score'), + pytest.param(partial(sk_fbeta_score, average='micro', beta=2), + partial(fbeta_score, beta=2), False, id='fbeta_score'), + pytest.param(sk_confusion_matrix, confusion_matrix, False, id='confusion_matrix'), + pytest.param(sk_roc_curve, roc, True, id='roc'), + pytest.param(sk_precision_recall_curve, precision_recall_curve, True, id='precision_recall_curve'), + pytest.param(sk_roc_auc_score, auroc, True, id='auroc') ]) -def test_against_sklearn(sklearn_metric, torch_metric): - """Compare PL metrics to sklearn version.""" +def test_against_sklearn(sklearn_metric, torch_metric, only_binary): + """Compare PL metrics to sklearn version. """ device = 'cuda' if torch.cuda.is_available() else 'cpu' - # iterate over different label counts in predictions and target - for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]: + # for metrics with only_binary=False, we try out different combinations of number + # of labels in pred and target (also test binary) + # for metrics with only_binary=True, target is always binary and pred will be + # (unnormalized) class probabilities + class_comb = [(5, 2)] if only_binary else [(10, 10), (5, 10), (10, 5), (2, 2)] + for n_cls_pred, n_cls_target in class_comb: pred = torch.randint(n_cls_pred, (300,), device=device) target = torch.randint(n_cls_target, (300,), device=device) sk_score = sklearn_metric(target.cpu().detach().numpy(), pred.cpu().detach().numpy()) - sk_score = torch.tensor(sk_score, dtype=torch.float, device=device) pl_score = torch_metric(pred, target) - assert torch.allclose(sk_score, pl_score) + + # if multi output + if isinstance(sk_score, tuple): + sk_score = [torch.tensor(sk_s.copy(), dtype=torch.float, device=device) for sk_s in sk_score] + for sk_s, pl_s in zip(sk_score, pl_score): + assert torch.allclose(sk_s, pl_s.float()) + else: + sk_score = torch.tensor(sk_score, dtype=torch.float, device=device) + assert torch.allclose(sk_score, pl_score) + + +@pytest.mark.parametrize('class_reduction', ['micro', 'macro', 'weighted']) +@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ + pytest.param(sk_precision, precision, id='precision'), + pytest.param(sk_recall, recall, id='recall'), + pytest.param(sk_f1_score, f1_score, id='f1_score'), + pytest.param(partial(sk_fbeta_score, beta=2), partial(fbeta_score, beta=2), id='fbeta_score') +]) +def test_different_reduction_against_sklearn(class_reduction, sklearn_metric, torch_metric): + """ Test metrics where the class_reduction parameter have a correponding + value in sklearn """ + device = 'cuda' if torch.cuda.is_available() else 'cpu' + pred = torch.randint(10, (300,), device=device) + target = torch.randint(10, (300,), device=device) + sk_score = sklearn_metric(target.cpu().detach().numpy(), + pred.cpu().detach().numpy(), + average=class_reduction) + sk_score = torch.tensor(sk_score, dtype=torch.float, device=device) + pl_score = torch_metric(pred, target, class_reduction=class_reduction) + assert torch.allclose(sk_score, pl_score) def test_onehot(): @@ -121,15 +162,19 @@ def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expect assert sup.item() == expected_support -@pytest.mark.parametrize(['pred', 'target', 'expected_tp', 'expected_fp', +@pytest.mark.parametrize(['pred', 'target', 'reduction', 'expected_tp', 'expected_fp', 'expected_tn', 'expected_fn', 'expected_support'], [ - pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), + pytest.param(torch.tensor([0., 2., 4., 4.]), torch.tensor([0., 4., 3., 4.]), 'none', + [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]), + pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'none', [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]), - pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), - [1, 0, 0, 0, 1], [0, 0, 1, 0, 1], [3, 4, 3, 3, 1], [0, 0, 0, 1, 1], [1, 0, 0, 1, 2]) + pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'sum', + torch.tensor(2), torch.tensor(2), torch.tensor(14), torch.tensor(2), torch.tensor(4)), + pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'elementwise_mean', + torch.tensor(0.4), torch.tensor(0.4), torch.tensor(2.8), torch.tensor(0.4), torch.tensor(0.8)) ]) -def test_stat_scores_multiclass(pred, target, expected_tp, expected_fp, expected_tn, expected_fn, expected_support): - tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target) +def test_stat_scores_multiclass(pred, target, reduction, expected_tp, expected_fp, expected_tn, expected_fn, expected_support): + tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction=reduction) assert torch.allclose(torch.tensor(expected_tp).to(tp), tp) assert torch.allclose(torch.tensor(expected_fp).to(fp), fp) @@ -143,14 +188,18 @@ def test_multilabel_accuracy(): y1 = torch.tensor([[0, 1, 1], [1, 0, 1]]) y2 = torch.tensor([[0, 0, 1], [1, 0, 1]]) - assert torch.allclose(accuracy(y1, y2, reduction='none'), torch.tensor([2 / 3, 1.])) - assert torch.allclose(accuracy(y1, y1, reduction='none'), torch.tensor([1., 1.])) - assert torch.allclose(accuracy(y2, y2, reduction='none'), torch.tensor([1., 1.])) - assert torch.allclose(accuracy(y2, torch.logical_not(y2), reduction='none'), torch.tensor([0., 0.])) - assert torch.allclose(accuracy(y1, torch.logical_not(y1), reduction='none'), torch.tensor([0., 0.])) + assert torch.allclose(accuracy(y1, y2, class_reduction='none'), torch.tensor([2 / 3, 1.])) + assert torch.allclose(accuracy(y1, y1, class_reduction='none'), torch.tensor([1., 1.])) + assert torch.allclose(accuracy(y2, y2, class_reduction='none'), torch.tensor([1., 1.])) + assert torch.allclose(accuracy(y2, torch.logical_not(y2), class_reduction='none'), torch.tensor([0., 0.])) + assert torch.allclose(accuracy(y1, torch.logical_not(y1), class_reduction='none'), torch.tensor([0., 0.])) - with pytest.raises(RuntimeError): - accuracy(y2, torch.zeros_like(y2), reduction='none') + # num_classes does not match extracted number from input we expect a warning + with pytest.warns(RuntimeWarning, + match=r'You have set .* number of classes which is' + r' different from predicted (.*) and' + r' target (.*) number of classes'): + _ = accuracy(y2, torch.zeros_like(y2), num_classes=3) def test_accuracy(): @@ -178,14 +227,29 @@ def test_confusion_matrix(): cm = confusion_matrix(pred, target, normalize=True) assert torch.allclose(cm, torch.tensor([[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]])) + target = torch.LongTensor([0, 0, 0, 0, 0]) + pred = target.clone() + cm = confusion_matrix(pred, target, normalize=False, num_classes=3) + assert torch.allclose(cm, torch.tensor([[5., 0., 0.], [0., 0., 0.], [0., 0., 0.]])) + + # Example taken from https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html + target = torch.LongTensor([0] * 13 + [1] * 16 + [2] * 9) + pred = torch.LongTensor([0] * 13 + [1] * 10 + [2] * 15) + cm = confusion_matrix(pred, target, normalize=False, num_classes=3) + assert torch.allclose(cm, torch.tensor([[13., 0., 0.], [0., 10., 6.], [0., 0., 9.]])) + to_compare = cm / torch.tensor([[13.], [16.], [9.]]) + + cm = confusion_matrix(pred, target, normalize=True, num_classes=3) + assert torch.allclose(cm, to_compare) + @pytest.mark.parametrize(['pred', 'target', 'expected_prec', 'expected_rec'], [ pytest.param(torch.tensor([1., 0., 1., 0.]), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]), pytest.param(to_onehot(torch.tensor([1., 0., 1., 0.])), torch.tensor([0., 1., 1., 0.]), [0.5, 0.5], [0.5, 0.5]) ]) def test_precision_recall(pred, target, expected_prec, expected_rec): - prec = precision(pred, target, reduction='none') - rec = recall(pred, target, reduction='none') + prec = precision(pred, target, class_reduction='none') + rec = recall(pred, target, class_reduction='none') assert torch.allclose(torch.tensor(expected_prec).to(prec), prec) assert torch.allclose(torch.tensor(expected_rec).to(rec), rec) @@ -197,10 +261,10 @@ def test_precision_recall(pred, target, expected_prec, expected_rec): pytest.param([1., 0., 1., 0.], [0., 1., 1., 0.], 2, [0.5, 0.5]), ]) def test_fbeta_score(pred, target, beta, exp_score): - score = fbeta_score(torch.tensor(pred), torch.tensor(target), beta, reduction='none') + score = fbeta_score(torch.tensor(pred), torch.tensor(target), beta, class_reduction='none') assert torch.allclose(score, torch.tensor(exp_score)) - score = fbeta_score(to_onehot(torch.tensor(pred)), torch.tensor(target), beta, reduction='none') + score = fbeta_score(to_onehot(torch.tensor(pred)), torch.tensor(target), beta, class_reduction='none') assert torch.allclose(score, torch.tensor(exp_score)) @@ -210,10 +274,10 @@ def test_fbeta_score(pred, target, beta, exp_score): pytest.param([1., 0., 1., 0.], [1., 0., 1., 0.], [1.0, 1.0]), ]) def test_f1_score(pred, target, exp_score): - score = f1_score(torch.tensor(pred), torch.tensor(target), reduction='none') + score = f1_score(torch.tensor(pred), torch.tensor(target), class_reduction='none') assert torch.allclose(score, torch.tensor(exp_score)) - score = f1_score(to_onehot(torch.tensor(pred)), torch.tensor(target), reduction='none') + score = f1_score(to_onehot(torch.tensor(pred)), torch.tensor(target), class_reduction='none') assert torch.allclose(score, torch.tensor(exp_score)) @@ -320,22 +384,102 @@ def test_dice_score(pred, target, expected): assert score == expected -@pytest.mark.parametrize(['half_ones', 'reduction', 'remove_bg', 'expected'], [ - pytest.param(False, 'none', False, torch.Tensor([1, 1, 1])), - pytest.param(False, 'elementwise_mean', False, torch.Tensor([1])), - pytest.param(False, 'none', True, torch.Tensor([1, 1])), - pytest.param(True, 'none', False, torch.Tensor([0.5, 0.5, 0.5])), - pytest.param(True, 'elementwise_mean', False, torch.Tensor([0.5])), - pytest.param(True, 'none', True, torch.Tensor([0.5, 0.5])), +@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [ + pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])), + pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])), + pytest.param(False, 'none', 0, torch.Tensor([1, 1])), + pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])), + pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])), + pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])), ]) -def test_iou(half_ones, reduction, remove_bg, expected): +def test_iou(half_ones, reduction, ignore_index, expected): pred = (torch.arange(120) % 3).view(-1, 1) target = (torch.arange(120) % 3).view(-1, 1) if half_ones: pred[:60] = 1 - iou_val = iou(pred, target, remove_bg=remove_bg, reduction=reduction) + iou_val = iou( + pred=pred, + target=target, + ignore_index=ignore_index, + reduction=reduction, + ) assert torch.allclose(iou_val, expected, atol=1e-9) +@pytest.mark.parametrize('metric', [auroc]) +def test_error_on_multiclass_input(metric): + """ check that these metrics raise an error if they are used for multiclass problems """ + pred = torch.randint(0, 10, (100, )) + target = torch.randint(0, 10, (100, )) + with pytest.raises(ValueError, match="AUROC metric is meant for binary classification"): + _ = metric(pred, target) + + +# TODO: When the jaccard_score of the sklearn version we use accepts `zero_division` (see +# https://github.com/scikit-learn/scikit-learn/pull/17866), consider adding a test here against our +# `absent_score`. +@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'absent_score', 'num_classes', 'expected'], [ + # Note that -1 is used as the absent_score in almost all tests here to distinguish it from the range of valid + # scores the function can return ([0., 1.] range, inclusive). + # 2 classes, class 0 is correct everywhere, class 1 is absent. + pytest.param([0], [0], None, -1., 2, [1., -1.]), + pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]), + # absent_score not applied if only class 0 is present and it's the only class. + pytest.param([0], [0], None, -1., 1, [1.]), + # 2 classes, class 1 is correct everywhere, class 0 is absent. + pytest.param([1], [1], None, -1., 2, [-1., 1.]), + pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]), + # When 0 index ignored, class 0 does not get a score (not even the absent_score). + pytest.param([1], [1], 0, -1., 2, [1.0]), + # 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get absent_score. + pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]), + pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]), + # 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get absent_score. + pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]), + pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]), + # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get absent_score), class + # 2 is absent. + pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]), + # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get absent_score), class + # 2 is absent. + pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]), + # Sanity checks with absent_score of 1.0. + pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]), + pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]), +]) +def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes, expected): + iou_val = iou( + pred=torch.tensor(pred), + target=torch.tensor(target), + ignore_index=ignore_index, + absent_score=absent_score, + num_classes=num_classes, + reduction='none', + ) + assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val)) + + # example data taken from # https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py +@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], [ + # Ignoring an index outside of [0, num_classes-1] should have no effect. + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]), + # Ignoring a valid index drops only that index from the result. + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]), + # When reducing to mean or sum, the ignored index does not contribute to the output. + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]), +]) +def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected): + iou_val = iou( + pred=torch.tensor(pred), + target=torch.tensor(target), + ignore_index=ignore_index, + num_classes=num_classes, + reduction=reduction, + ) + assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val)) diff --git a/tests/metrics/functional/test_reduction.py b/tests/metrics/functional/test_reduction.py index 71d2b6f7735e1..aec54c1806715 100644 --- a/tests/metrics/functional/test_reduction.py +++ b/tests/metrics/functional/test_reduction.py @@ -1,7 +1,7 @@ import pytest import torch -from pytorch_lightning.metrics.functional.reduction import reduce +from pytorch_lightning.metrics.functional.reduction import reduce, class_reduce def test_reduce(): @@ -13,3 +13,18 @@ def test_reduce(): with pytest.raises(ValueError): reduce(start_tensor, 'error_reduction') + + +def test_class_reduce(): + num = torch.randint(1, 10, (100,)).float() + denom = torch.randint(10, 20, (100,)).float() + weights = torch.randint(1, 100, (100,)).float() + + assert torch.allclose(class_reduce(num, denom, weights, 'micro'), + torch.sum(num) / torch.sum(denom)) + assert torch.allclose(class_reduce(num, denom, weights, 'macro'), + torch.mean(num / denom)) + assert torch.allclose(class_reduce(num, denom, weights, 'weighted'), + torch.sum(num / denom * (weights / torch.sum(weights)))) + assert torch.allclose(class_reduce(num, denom, weights, 'none'), + num / denom) diff --git a/tests/metrics/functional/test_regression.py b/tests/metrics/functional/test_regression.py index 6aae9027bf3dd..49a79f9424f13 100644 --- a/tests/metrics/functional/test_regression.py +++ b/tests/metrics/functional/test_regression.py @@ -1,8 +1,17 @@ import numpy as np import pytest import torch -from skimage.metrics import peak_signal_noise_ratio as ski_psnr -from skimage.metrics import structural_similarity as ski_ssim +from functools import partial +from math import sqrt +from skimage.metrics import ( + peak_signal_noise_ratio as ski_psnr, + structural_similarity as ski_ssim +) +from sklearn.metrics import ( + mean_absolute_error as mae_sk, + mean_squared_error as mse_sk, + mean_squared_log_error as msle_sk +) from pytorch_lightning.metrics.functional import ( mae, @@ -14,6 +23,27 @@ ) +@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [ + pytest.param(mae_sk, mae, id='mean_absolute_error'), + pytest.param(mse_sk, mse, id='mean_squared_error'), + pytest.param(partial(mse_sk, squared=False), rmse, id='root_mean_squared_error'), + pytest.param(lambda x, y: sqrt(msle_sk(x, y)), rmsle, id='root_mean_squared_log_error') +]) +def test_against_sklearn(sklearn_metric, torch_metric): + """Compare PL metrics to sklearn version.""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + # iterate over different label counts in predictions and target + pred = torch.rand(300, device=device) + target = torch.rand(300, device=device) + + sk_score = sklearn_metric(target.cpu().detach().numpy(), + pred.cpu().detach().numpy()) + sk_score = torch.tensor(sk_score, dtype=torch.float, device=device) + pl_score = torch_metric(pred, target) + assert torch.allclose(sk_score, pl_score) + + @pytest.mark.parametrize(['pred', 'target', 'expected'], [ pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.25), pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 3.0), @@ -45,8 +75,8 @@ def test_mae(pred, target, expected): @pytest.mark.parametrize(['pred', 'target', 'expected'], [ pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0), - pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.0207), - pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 0.2841), + pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.1438), + pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 0.5330), ]) def test_rmsle(pred, target, expected): score = rmsle(torch.tensor(pred), torch.tensor(target)) @@ -60,7 +90,7 @@ def test_rmsle(pred, target, expected): ]) def test_psnr_with_skimage(pred, target): score = psnr(pred=torch.tensor(pred), - target=torch.tensor(target)) + target=torch.tensor(target), data_range=3) sk_score = ski_psnr(np.array(pred), np.array(target), data_range=3) assert torch.allclose(score, torch.tensor(sk_score, dtype=torch.float), atol=1e-3) @@ -97,24 +127,25 @@ def test_psnr_against_sklearn(sklearn_metric, torch_metric): assert torch.allclose(sk_score, pl_score) -@pytest.mark.parametrize(['size', 'channel', 'plus', 'multichannel'], [ - pytest.param(16, 1, 0.125, False), - pytest.param(32, 1, 0.25, False), - pytest.param(48, 3, 0.5, True), - pytest.param(64, 4, 0.75, True), - pytest.param(128, 5, 1, True) +@pytest.mark.parametrize(['size', 'channel', 'coef', 'multichannel'], [ + pytest.param(16, 1, 0.9, False), + pytest.param(32, 3, 0.8, True), + pytest.param(48, 4, 0.7, True), + pytest.param(64, 5, 0.6, True) ]) -def test_ssim(size, channel, plus, multichannel): +def test_ssim(size, channel, coef, multichannel): device = "cuda" if torch.cuda.is_available() else "cpu" - pred = torch.rand(1, channel, size, size, device=device) - target = pred + plus - ssim_idx = ssim(pred, target) - np_pred = np.random.rand(size, size, channel) + pred = torch.rand(size, channel, size, size, device=device) + target = pred * coef + ssim_idx = ssim(pred, target, data_range=1.0) + np_pred = pred.permute(0, 2, 3, 1).cpu().numpy() if multichannel is False: - np_pred = np_pred[:, :, 0] - np_target = np.add(np_pred, plus) - sk_ssim_idx = ski_ssim(np_pred, np_target, win_size=11, multichannel=multichannel, gaussian_weights=True) - assert torch.allclose(ssim_idx, torch.tensor(sk_ssim_idx, dtype=torch.float, device=device), atol=1e-2, rtol=1e-2) + np_pred = np_pred[:, :, :, 0] + np_target = np.multiply(np_pred, coef) + sk_ssim_idx = ski_ssim( + np_pred, np_target, win_size=11, multichannel=multichannel, gaussian_weights=True, data_range=1.0 + ) + assert torch.allclose(ssim_idx, torch.tensor(sk_ssim_idx, dtype=torch.float, device=device), atol=1e-4) ssim_idx = ssim(pred, pred) assert torch.allclose(ssim_idx, torch.tensor(1.0, device=device)) diff --git a/tests/metrics/functional/test_self_supervised.py b/tests/metrics/functional/test_self_supervised.py new file mode 100644 index 0000000000000..1ef3b43f77b62 --- /dev/null +++ b/tests/metrics/functional/test_self_supervised.py @@ -0,0 +1,35 @@ +import pytest +import torch +from sklearn.metrics import pairwise + +from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity + + +@pytest.mark.parametrize('similarity', ['cosine', 'dot']) +@pytest.mark.parametrize('reduction', ['none', 'mean', 'sum']) +def test_against_sklearn(similarity, reduction): + """Compare PL metrics to sklearn version.""" + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + batch = torch.randn(5, 10, device=device) # 100 samples in 10 dimensions + + pl_dist = embedding_similarity(batch, similarity=similarity, + reduction=reduction, zero_diagonal=False) + + def sklearn_embedding_distance(batch, similarity, reduction): + + metric_func = {'cosine': pairwise.cosine_similarity, + 'dot': pairwise.linear_kernel}[similarity] + + dist = metric_func(batch, batch) + if reduction == 'mean': + return dist.mean(axis=-1) + if reduction == 'sum': + return dist.sum(axis=-1) + return dist + + sk_dist = sklearn_embedding_distance(batch.cpu().detach().numpy(), + similarity=similarity, reduction=reduction) + sk_dist = torch.tensor(sk_dist, dtype=torch.float, device=device) + + assert torch.allclose(sk_dist, pl_dist)