diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index bca93416e7b58..6a7d8e5bb9330 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -126,8 +126,8 @@ Example implementation: from pytorch_lightning.metrics import Metric class MyAccuracy(Metric): - def __init__(self, ddp_sync_on_step=False): - super().__init__(ddp_sync_on_step=ddp_sync_on_step) + def __init__(self, dist_sync_on_step=False): + super().__init__(dist_sync_on_step=dist_sync_on_step) self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") @@ -157,6 +157,24 @@ Accuracy .. autoclass:: pytorch_lightning.metrics.classification.Accuracy :noindex: +Precision +^^^^^^^^^ + +.. autoclass:: pytorch_lightning.metrics.classification.Precision + :noindex: + +Recall +^^^^^^ + +.. autoclass:: pytorch_lightning.metrics.classification.Recall + :noindex: + +Fbeta +^^^^^ + +.. autoclass:: pytorch_lightning.metrics.classification.Fbeta + :noindex: + Regression Metrics ------------------ diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 6a20c6a0b1771..263ee52b421ee 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -1,6 +1,12 @@ from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.metrics.classification.accuracy import Accuracy +from pytorch_lightning.metrics.classification import ( + Accuracy, + Precision, + Recall, + Fbeta +) + from pytorch_lightning.metrics.regression import ( MeanSquaredError, MeanAbsoluteError, diff --git a/pytorch_lightning/metrics/classification/__init__.py b/pytorch_lightning/metrics/classification/__init__.py index 45e66603b2465..bcac66b40ef93 100644 --- a/pytorch_lightning/metrics/classification/__init__.py +++ b/pytorch_lightning/metrics/classification/__init__.py @@ -1 +1,3 @@ from pytorch_lightning.metrics.classification.accuracy import Accuracy +from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall +from pytorch_lightning.metrics.classification.f_beta import Fbeta diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 50751aa73be51..617af31d12f02 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -31,7 +31,7 @@ class Accuracy(Metric): Threshold value for binary or multi-label logits. default: 0.5 compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True - ddp_sync_on_step: + dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` before returning the value at the step. 
default: False process_group: @@ -52,12 +52,12 @@ def __init__( self, threshold: float = 0.5, compute_on_step: bool = True, - ddp_sync_on_step: bool = False, + dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): super().__init__( compute_on_step=compute_on_step, - ddp_sync_on_step=ddp_sync_on_step, + dist_sync_on_step=dist_sync_on_step, process_group=process_group, ) @@ -79,7 +79,6 @@ def _input_format(self, preds: torch.Tensor, target: torch.Tensor): if len(preds.shape) == len(target.shape) and preds.dtype == torch.float: # binary or multilabel probablities preds = (preds >= self.threshold).long() - return preds, target def update(self, preds: torch.Tensor, target: torch.Tensor): diff --git a/pytorch_lightning/metrics/classification/f_beta.py b/pytorch_lightning/metrics/classification/f_beta.py new file mode 100644 index 0000000000000..6a0daa386cf06 --- /dev/null +++ b/pytorch_lightning/metrics/classification/f_beta.py @@ -0,0 +1,119 @@ +import math +import functools +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional, Union +from collections.abc import Mapping, Sequence +from collections import namedtuple + +import torch +from torch import nn +from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.classification.precision_recall import _input_format +from pytorch_lightning.metrics.utils import METRIC_EPS + + +class Fbeta(Metric): + """ + Computes f_beta metric. + + Works with binary, multiclass, and multilabel data. + Accepts logits from a model output or integer class values in prediction. + Works with multi-dimensional preds and target. + + Forward accepts + + - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes + - ``target`` (long tensor): ``(N, ...)`` + + If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument. + This is the case for binary and multi-label logits. + + If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + + Args: + num_classes: Number of classes in the dataset. + beta: Beta coefficient in the F measure. + threshold: + Threshold value for binary or multi-label logits. default: 0.5 + + average: + * `'micro'` computes metric globally + * `'macro'` computes metric for each class and then takes the mean + + multilabel: If predictions are from multilabel classification. + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. 
default: None (which selects the entire world) + + Example: + + >>> from pytorch_lightning.metrics import Fbeta + >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) + >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) + >>> f_beta = Fbeta(num_classes=3, beta=0.5) + >>> f_beta(preds, target) + tensor(0.3333) + + """ + def __init__( + self, + num_classes: int = 1, + beta: float = 1., + threshold: float = 0.5, + average: str = 'micro', + multilabel: bool = False, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + + self.num_classes = num_classes + self.beta = beta + self.threshold = threshold + self.average = average + self.multilabel = multilabel + + assert self.average in ('micro', 'macro'), \ + "average passed to the function must be either `micro` or `macro`" + + self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") + self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") + self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + preds, target = _input_format(self.num_classes, preds, target, self.threshold, self.multilabel) + + self.true_positives += torch.sum(preds * target, dim=1) + self.predicted_positives += torch.sum(preds, dim=1) + self.actual_positives += torch.sum(target, dim=1) + + def compute(self): + """ + Computes accuracy over state. + """ + if self.average == 'micro': + precision = self.true_positives.sum().float() / (self.predicted_positives.sum() + METRIC_EPS) + recall = self.true_positives.sum().float() / (self.actual_positives.sum() + METRIC_EPS) + + return (1 + self.beta ** 2) * (precision * recall) / (self.beta ** 2 * precision + recall) + elif self.average == 'macro': + precision = self.true_positives.float() / (self.predicted_positives + METRIC_EPS) + recall = self.true_positives.float() / (self.actual_positives + METRIC_EPS) + + return ((1 + self.beta ** 2) * (precision * recall) / (self.beta ** 2 * precision + recall)).mean() diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py new file mode 100644 index 0000000000000..6a52317b86f9f --- /dev/null +++ b/pytorch_lightning/metrics/classification/precision_recall.py @@ -0,0 +1,224 @@ +import math +import functools +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional, Union +from collections.abc import Mapping, Sequence +from collections import namedtuple + +import torch +from torch import nn +from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.utils import to_onehot, METRIC_EPS + + +def _input_format(num_classes: int, preds: torch.Tensor, target: torch.Tensor, threshold=0.5, multilabel=False): + if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1): + raise ValueError( + "preds and target must have same number of dimensions, or one additional dimension for preds" + ) + + if len(preds.shape) == len(target.shape) + 1: + # multi class probabilites + preds = torch.argmax(preds, dim=1) + + if len(preds.shape) == len(target.shape) and preds.dtype == torch.long and 
num_classes > 1 and not multilabel: + # multi-class + preds = to_onehot(preds, num_classes=num_classes) + target = to_onehot(target, num_classes=num_classes) + + elif len(preds.shape) == len(target.shape) and preds.dtype == torch.float: + # binary or multilabel probablities + preds = (preds >= threshold).long() + + # transpose class as first dim and reshape + if len(preds.shape) > 1: + preds = preds.transpose(1, 0) + target = target.transpose(1, 0) + + return preds.reshape(num_classes, -1), target.reshape(num_classes, -1) + + +class Precision(Metric): + """ + Computes the precision metric. + + Works with binary, multiclass, and multilabel data. + Accepts logits from a model output or integer class values in prediction. + Works with multi-dimensional preds and target. + + Forward accepts + + - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes + - ``target`` (long tensor): ``(N, ...)`` + + If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument. + This is the case for binary and multi-label logits. + + If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + + Args: + num_classes: Number of classes in the dataset. + beta: Beta coefficient in the F measure. + threshold: + Threshold value for binary or multi-label logits. default: 0.5 + + average: + * `'micro'` computes metric globally + * `'macro'` computes metric for each class and then takes the mean + + multilabel: If predictions are from multilabel classification. + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. 
default: None (which selects the entire world) + + Example: + + >>> from pytorch_lightning.metrics import Precision + >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) + >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) + >>> precision = Precision(num_classes=3) + >>> precision(preds, target) + tensor(0.3333) + + """ + def __init__( + self, + num_classes: int = 1, + threshold: float = 0.5, + average: str = 'micro', + multilabel: bool = False, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + + self.num_classes = num_classes + self.threshold = threshold + self.average = average + self.multilabel = multilabel + + assert self.average in ('micro', 'macro'), \ + "average passed to the function must be either `micro` or `macro`" + + self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") + self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + preds, target = _input_format(self.num_classes, preds, target, self.threshold, self.multilabel) + + # multiply because we are counting (1, 1) pair for true positives + self.true_positives += torch.sum(preds * target, dim=1) + self.predicted_positives += torch.sum(preds, dim=1) + + def compute(self): + if self.average == 'micro': + return self.true_positives.sum().float() / (self.predicted_positives.sum() + METRIC_EPS) + elif self.average == 'macro': + return (self.true_positives.float() / (self.predicted_positives + METRIC_EPS)).mean() + + +class Recall(Metric): + """ + Computes the recall metric. + + Works with binary, multiclass, and multilabel data. + Accepts logits from a model output or integer class values in prediction. + Works with multi-dimensional preds and target. + + Forward accepts + + - ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes + - ``target`` (long tensor): ``(N, ...)`` + + If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument. + This is the case for binary and multi-label logits. + + If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. + + Args: + num_classes: Number of classes in the dataset. + beta: Beta coefficient in the F measure. + threshold: + Threshold value for binary or multi-label logits. default: 0.5 + + average: + * `'micro'` computes metric globally + * `'macro'` computes metric for each class and then takes the mean + + multilabel: If predictions are from multilabel classification. + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. 
default: None (which selects the entire world) + + Example: + + >>> from pytorch_lightning.metrics import Recall + >>> target = torch.tensor([0, 1, 2, 0, 1, 2]) + >>> preds = torch.tensor([0, 2, 1, 0, 0, 1]) + >>> recall = Recall(num_classes=3) + >>> recall(preds, target) + tensor(0.3333) + + """ + def __init__( + self, + num_classes: int = 1, + threshold: float = 0.5, + average: str = 'micro', + multilabel: bool = False, + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + ) + + self.num_classes = num_classes + self.threshold = threshold + self.average = average + self.multilabel = multilabel + + assert self.average in ('micro', 'macro'), \ + "average passed to the function must be either `micro` or `macro`" + + self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") + self.add_state("actual_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + preds, target = _input_format(self.num_classes, preds, target, self.threshold, self.multilabel) + + # multiply because we are counting (1, 1) pair for true positives + self.true_positives += torch.sum(preds * target, dim=1) + self.actual_positives += torch.sum(target, dim=1) + + def compute(self): + """ + Computes accuracy over state. + """ + if self.average == 'micro': + return self.true_positives.sum().float() / (self.actual_positives.sum() + METRIC_EPS) + elif self.average == 'macro': + return (self.true_positives.float() / (self.actual_positives + METRIC_EPS)).mean() diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 5a4da899a8591..b30229da1fa18 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -37,7 +37,7 @@ class Metric(nn.Module, ABC): Args: compute_on_step: Forward only calls ``update()`` and returns None if this is set to False. default: True - ddp_sync_on_step: + dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` before returning the value at the step. 
default: False process_group: @@ -46,12 +46,12 @@ class Metric(nn.Module, ABC): def __init__( self, compute_on_step: bool = True, - ddp_sync_on_step: bool = False, + dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): super().__init__() - self.ddp_sync_on_step = ddp_sync_on_step + self.dist_sync_on_step = dist_sync_on_step self.compute_on_step = compute_on_step self.process_group = process_group self._to_sync = True @@ -133,7 +133,7 @@ def forward(self, *args, **kwargs): self._forward_cache = None if self.compute_on_step: - self._to_sync = self.ddp_sync_on_step + self._to_sync = self.dist_sync_on_step # save context before switch self._cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} diff --git a/pytorch_lightning/metrics/regression/explained_variance.py b/pytorch_lightning/metrics/regression/explained_variance.py index 97aaab1511d5d..b20d26a22c26e 100644 --- a/pytorch_lightning/metrics/regression/explained_variance.py +++ b/pytorch_lightning/metrics/regression/explained_variance.py @@ -3,6 +3,7 @@ from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.metrics.utils import METRIC_EPS class ExplainedVariance(Metric): @@ -29,7 +30,7 @@ class ExplainedVariance(Metric): compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True - ddp_sync_on_step: + dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` before returning the value at the step. default: False process_group: @@ -55,12 +56,12 @@ def __init__( self, multioutput: str = 'uniform_average', compute_on_step: bool = True, - ddp_sync_on_step: bool = False, + dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): super().__init__( compute_on_step=compute_on_step, - ddp_sync_on_step=ddp_sync_on_step, + dist_sync_on_step=dist_sync_on_step, process_group=process_group, ) allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted') diff --git a/pytorch_lightning/metrics/regression/mean_absolute_error.py b/pytorch_lightning/metrics/regression/mean_absolute_error.py index 8b6db9dd16c3d..7f4899d705428 100644 --- a/pytorch_lightning/metrics/regression/mean_absolute_error.py +++ b/pytorch_lightning/metrics/regression/mean_absolute_error.py @@ -11,7 +11,7 @@ class MeanAbsoluteError(Metric): Args: compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True - ddp_sync_on_step: + dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` before returning the value at the step. default: False process_group: @@ -30,12 +30,12 @@ class MeanAbsoluteError(Metric): def __init__( self, compute_on_step: bool = True, - ddp_sync_on_step: bool = False, + dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): super().__init__( compute_on_step=compute_on_step, - ddp_sync_on_step=ddp_sync_on_step, + dist_sync_on_step=dist_sync_on_step, process_group=process_group, ) diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py index 79ff25b7ac826..cf9d5a1794fb3 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_error.py @@ -11,7 +11,7 @@ class MeanSquaredError(Metric): Args: compute_on_step: Forward only calls ``update()`` and return None if this is set to False. 
default: True - ddp_sync_on_step: + dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` before returning the value at the step. default: False process_group: @@ -31,12 +31,12 @@ class MeanSquaredError(Metric): def __init__( self, compute_on_step: bool = True, - ddp_sync_on_step: bool = False, + dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): super().__init__( compute_on_step=compute_on_step, - ddp_sync_on_step=ddp_sync_on_step, + dist_sync_on_step=dist_sync_on_step, process_group=process_group, ) diff --git a/pytorch_lightning/metrics/regression/mean_squared_log_error.py b/pytorch_lightning/metrics/regression/mean_squared_log_error.py index 5467104a324ff..0d4a2ae5c1628 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_log_error.py @@ -11,7 +11,7 @@ class MeanSquaredLogError(Metric): Args: compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True - ddp_sync_on_step: + dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` before returning the value at the step. default: False process_group: @@ -31,12 +31,12 @@ class MeanSquaredLogError(Metric): def __init__( self, compute_on_step: bool = True, - ddp_sync_on_step: bool = False, + dist_sync_on_step: bool = False, process_group: Optional[Any] = None, ): super().__init__( compute_on_step=compute_on_step, - ddp_sync_on_step=ddp_sync_on_step, + dist_sync_on_step=dist_sync_on_step, process_group=process_group, ) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index 850b3858b0848..fa8b35766d3c9 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -3,6 +3,9 @@ from typing import Any, Callable, Optional, Union +METRIC_EPS = 1e-6 + + def dim_zero_cat(x): return torch.cat(x, dim=0) @@ -17,3 +20,31 @@ def dim_zero_mean(x): def _flatten(x): return [item for sublist in x for item in sublist] + + +def to_onehot( + tensor: torch.Tensor, + num_classes: int, +) -> torch.Tensor: + """ + Converts a dense label tensor to one-hot format + + Args: + tensor: dense label tensor, with shape [N, d1, d2, ...] + num_classes: number of classes C + + Output: + A sparse label tensor with shape [N, C, d1, d2, ...] 
+ + Example: + >>> x = torch.tensor([1, 2, 3]) + >>> to_onehot(x, num_classes=4) + tensor([[0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + """ + dtype, device, shape = tensor.dtype, tensor.device, tensor.shape + tensor_onehot = torch.zeros(shape[0], num_classes, *shape[1:], + dtype=dtype, device=device) + index = tensor.long().unsqueeze(1).expand_as(tensor_onehot) + return tensor_onehot.scatter_(1, index, 1.0) diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index b9cad945bd6c1..e61ac5b49dd53 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -39,7 +39,7 @@ def _ddp_test_fn(rank, worldsize): metric_b = DummyMetric() metric_c = DummyMetric() - # ddp_sync_on_step is False by default + # dist_sync_on_step is False by default result = Result() for epoch in range(3): diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index e19602e9982b1..0e5477c52fb8f 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -3,47 +3,35 @@ import torch import numpy as np from collections import namedtuple +from functools import partial from pytorch_lightning.metrics.classification.accuracy import Accuracy from sklearn.metrics import accuracy_score -from tests.metrics.utils import compute_batch, NUM_BATCHES, BATCH_SIZE +from tests.metrics.utils import compute_batch, setup_ddp +from tests.metrics.utils import THRESHOLD + +from tests.metrics.classification.utils import ( + _binary_prob_inputs, + _binary_inputs, + _multilabel_prob_inputs, + _multilabel_inputs, + _multiclass_prob_inputs, + _multiclass_inputs, + _multidim_multiclass_prob_inputs, + _multidim_multiclass_inputs, +) torch.manual_seed(42) -# global vars -num_classes = 5 -threshold = 0.5 -extra_dim = 3 - -Input = namedtuple('Input', ["preds", "target"]) - - -def test_accuracy_invalid_shape(): - with pytest.raises(ValueError): - acc = Accuracy() - acc.update(preds=torch.rand(1), target=torch.rand(1, 2, 3)) - - -_binary_prob_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) -) - def _binary_prob_sk_metric(preds, target): - sk_preds = (preds.view(-1).numpy() >= threshold).astype(np.uint8) + sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) sk_target = target.view(-1).numpy() return accuracy_score(y_true=sk_target, y_pred=sk_preds) -_binary_inputs = Input( - preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE,)), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE,)) -) - - def _binary_sk_metric(preds, target): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() @@ -51,25 +39,13 @@ def _binary_sk_metric(preds, target): return accuracy_score(y_true=sk_target, y_pred=sk_preds) -_multilabel_prob_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, num_classes)) -) - - def _multilabel_prob_sk_metric(preds, target): - sk_preds = (preds.view(-1).numpy() >= threshold).astype(np.uint8) + sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) sk_target = target.view(-1).numpy() return accuracy_score(y_true=sk_target, y_pred=sk_preds) -_multilabel_inputs = Input( - preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, num_classes)), - target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, num_classes)) -) - - def 
_multilabel_sk_metric(preds, target): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() @@ -77,12 +53,6 @@ def _multilabel_sk_metric(preds, target): return accuracy_score(y_true=sk_target, y_pred=sk_preds) -_multiclass_prob_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes), - target=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE)) -) - - def _multiclass_prob_sk_metric(preds, target): sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() sk_target = target.view(-1).numpy() @@ -90,12 +60,6 @@ def _multiclass_prob_sk_metric(preds, target): return accuracy_score(y_true=sk_target, y_pred=sk_preds) -_multiclass_inputs = Input( - preds=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE)), - target=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE)) -) - - def _multiclass_sk_metric(preds, target): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() @@ -103,12 +67,6 @@ def _multiclass_sk_metric(preds, target): return accuracy_score(y_true=sk_target, y_pred=sk_preds) -_multidim_multiclass_prob_inputs = Input( - preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_classes, extra_dim), - target=torch.randint(high=num_classes, size=(NUM_BATCHES, BATCH_SIZE, extra_dim)) -) - - def _multidim_multiclass_prob_sk_metric(preds, target): sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() sk_target = target.view(-1).numpy() @@ -116,12 +74,6 @@ def _multidim_multiclass_prob_sk_metric(preds, target): return accuracy_score(y_true=sk_target, y_pred=sk_preds) -_multidim_multiclass_inputs = Input( - preds=torch.randint(high=num_classes, size=(NUM_BATCHES, extra_dim, BATCH_SIZE)), - target=torch.randint(high=num_classes, size=(NUM_BATCHES, extra_dim, BATCH_SIZE)) -) - - def _multidim_multiclass_sk_metric(preds, target): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() @@ -129,8 +81,14 @@ def _multidim_multiclass_sk_metric(preds, target): return accuracy_score(y_true=sk_target, y_pred=sk_preds) +def test_accuracy_invalid_shape(): + with pytest.raises(ValueError): + acc = Accuracy() + acc.update(preds=torch.rand(1), target=torch.rand(1, 2, 3)) + + @pytest.mark.parametrize("ddp", [True, False]) -@pytest.mark.parametrize("ddp_sync_on_step", [True, False]) +@pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("preds, target, sk_metric", [ (_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric), (_binary_inputs.preds, _binary_inputs.target, _binary_sk_metric), @@ -149,5 +107,13 @@ def _multidim_multiclass_sk_metric(preds, target): _multidim_multiclass_sk_metric ) ]) -def test_accuracy(ddp, ddp_sync_on_step, preds, target, sk_metric): - compute_batch(preds, target, Accuracy, sk_metric, ddp_sync_on_step, ddp, metric_args={"threshold": threshold}) +def test_accuracy(ddp, dist_sync_on_step, preds, target, sk_metric): + compute_batch( + preds, + target, + Accuracy, + sk_metric, + dist_sync_on_step, + ddp, + metric_args={"threshold": THRESHOLD} + ) diff --git a/tests/metrics/classification/test_f_beta.py b/tests/metrics/classification/test_f_beta.py new file mode 100644 index 0000000000000..d339a032554b6 --- /dev/null +++ b/tests/metrics/classification/test_f_beta.py @@ -0,0 +1,145 @@ +import os +import pytest +import torch +import numpy as np +from collections import namedtuple + +from functools import partial + +from pytorch_lightning.metrics import Fbeta +from sklearn.metrics import fbeta_score + +from 
tests.metrics.utils import compute_batch, setup_ddp +from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE, NUM_CLASSES, THRESHOLD + +from tests.metrics.classification.utils import ( + _binary_prob_inputs, + _binary_inputs, + _multilabel_prob_inputs, + _multilabel_inputs, + _multiclass_prob_inputs, + _multiclass_inputs, + _multidim_multiclass_prob_inputs, + _multidim_multiclass_inputs, +) + +torch.manual_seed(42) + + +def _binary_prob_sk_metric(preds, target, average='micro', beta=1.): + sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) + sk_target = target.view(-1).numpy() + + return fbeta_score(y_true=sk_target, y_pred=sk_preds, average='binary', beta=beta) + + +def _binary_sk_metric(preds, target, average='micro', beta=1.): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return fbeta_score(y_true=sk_target, y_pred=sk_preds, average='binary', beta=beta) + + +def _multilabel_prob_sk_metric(preds, target, average='micro', beta=1.): + sk_preds = (preds.view(-1, NUM_CLASSES).numpy() >= THRESHOLD).astype(np.uint8) + sk_target = target.view(-1, NUM_CLASSES).numpy() + + return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta) + + +def _multilabel_sk_metric(preds, target, average='micro', beta=1.): + sk_preds = preds.view(-1, NUM_CLASSES).numpy() + sk_target = target.view(-1, NUM_CLASSES).numpy() + + return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta) + + +def _multiclass_prob_sk_metric(preds, target, average='micro', beta=1.): + sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() + sk_target = target.view(-1).numpy() + + return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta) + + +def _multiclass_sk_metric(preds, target, average='micro', beta=1.): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta) + + +def _multidim_multiclass_prob_sk_metric(preds, target, average='micro', beta=1.): + sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() + sk_target = target.view(-1).numpy() + + return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta) + + +def _multidim_multiclass_sk_metric(preds, target, average='micro', beta=1.): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return fbeta_score(y_true=sk_target, y_pred=sk_preds, average=average, beta=beta) + + +@pytest.mark.parametrize("ddp", [True, False]) +@pytest.mark.parametrize("dist_sync_on_step", [True, False]) +@pytest.mark.parametrize("average", ['micro', 'macro']) +@pytest.mark.parametrize("preds, target, sk_metric, num_classes, multilabel", [ + (_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric, 1, False), + (_binary_inputs.preds, _binary_inputs.target, _binary_sk_metric, 1, False), + (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _multilabel_prob_sk_metric, NUM_CLASSES, True), + (_multilabel_inputs.preds, _multilabel_inputs.target, _multilabel_sk_metric, NUM_CLASSES, True), + (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _multiclass_prob_sk_metric, NUM_CLASSES, False), + (_multiclass_inputs.preds, _multiclass_inputs.target, _multiclass_sk_metric, NUM_CLASSES, False), + ( + _multidim_multiclass_prob_inputs.preds, + _multidim_multiclass_prob_inputs.target, + _multidim_multiclass_prob_sk_metric, + NUM_CLASSES, + False + ), + ( + 
_multidim_multiclass_inputs.preds, + _multidim_multiclass_inputs.target, + _multidim_multiclass_sk_metric, + NUM_CLASSES, + False + ) +]) +@pytest.mark.parametrize( + "metric_class, beta", + [ + (Fbeta, 0.5), + (Fbeta, 1.), + ], +) +def test_fbeta( + ddp, + dist_sync_on_step, + preds, + target, + sk_metric, + metric_class, + beta, + num_classes, + multilabel, + average +): + compute_batch( + preds, + target, + metric_class, + partial(sk_metric, average=average, beta=beta), + dist_sync_on_step, + ddp, + metric_args={ + "beta": beta, + "num_classes": num_classes, + "average": average, + "multilabel": multilabel, + "threshold": THRESHOLD + }, + check_dist_sync_on_step=False, + check_batch=False, + ) diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py new file mode 100644 index 0000000000000..f2ccd7a0ef71b --- /dev/null +++ b/tests/metrics/classification/test_precision_recall.py @@ -0,0 +1,144 @@ +import os +import pytest +import torch +import numpy as np +from collections import namedtuple + +from functools import partial + +from pytorch_lightning.metrics import Precision, Recall +from sklearn.metrics import precision_score, recall_score + +from tests.metrics.utils import compute_batch, setup_ddp +from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE, NUM_CLASSES, THRESHOLD + +from tests.metrics.classification.utils import ( + _binary_prob_inputs, + _binary_inputs, + _multilabel_prob_inputs, + _multilabel_inputs, + _multiclass_prob_inputs, + _multiclass_inputs, + _multidim_multiclass_prob_inputs, + _multidim_multiclass_inputs, +) + +torch.manual_seed(42) + + +def _binary_prob_sk_metric(preds, target, sk_fn=precision_score, average='micro'): + sk_preds = (preds.view(-1).numpy() >= THRESHOLD).astype(np.uint8) + sk_target = target.view(-1).numpy() + + return sk_fn(y_true=sk_target, y_pred=sk_preds, average='binary') + + +def _binary_sk_metric(preds, target, sk_fn=precision_score, average='micro'): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_fn(y_true=sk_target, y_pred=sk_preds, average='binary') + + +def _multilabel_prob_sk_metric(preds, target, sk_fn=precision_score, average='micro'): + sk_preds = (preds.view(-1, NUM_CLASSES).numpy() >= THRESHOLD).astype(np.uint8) + sk_target = target.view(-1, NUM_CLASSES).numpy() + + return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) + + +def _multilabel_sk_metric(preds, target, sk_fn=precision_score, average='micro'): + sk_preds = preds.view(-1, NUM_CLASSES).numpy() + sk_target = target.view(-1, NUM_CLASSES).numpy() + + return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) + + +def _multiclass_prob_sk_metric(preds, target, sk_fn=precision_score, average='micro'): + sk_preds = torch.argmax(preds, dim=len(preds.shape) - 1).view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) + + +def _multiclass_sk_metric(preds, target, sk_fn=precision_score, average='micro'): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) + + +def _multidim_multiclass_prob_sk_metric(preds, target, sk_fn=precision_score, average='micro'): + sk_preds = torch.argmax(preds, dim=len(preds.shape) - 2).view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) + + +def _multidim_multiclass_sk_metric(preds, target, 
sk_fn=precision_score, average='micro'): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + + return sk_fn(y_true=sk_target, y_pred=sk_preds, average=average) + + +@pytest.mark.parametrize("ddp", [True, False]) +@pytest.mark.parametrize("dist_sync_on_step", [True, False]) +@pytest.mark.parametrize("average", ['micro', 'macro']) +@pytest.mark.parametrize("preds, target, sk_metric, num_classes, multilabel", [ + (_binary_prob_inputs.preds, _binary_prob_inputs.target, _binary_prob_sk_metric, 1, False), + (_binary_inputs.preds, _binary_inputs.target, _binary_sk_metric, 1, False), + (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _multilabel_prob_sk_metric, NUM_CLASSES, True), + (_multilabel_inputs.preds, _multilabel_inputs.target, _multilabel_sk_metric, NUM_CLASSES, True), + (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _multiclass_prob_sk_metric, NUM_CLASSES, False), + (_multiclass_inputs.preds, _multiclass_inputs.target, _multiclass_sk_metric, NUM_CLASSES, False), + ( + _multidim_multiclass_prob_inputs.preds, + _multidim_multiclass_prob_inputs.target, + _multidim_multiclass_prob_sk_metric, + NUM_CLASSES, + False + ), + ( + _multidim_multiclass_inputs.preds, + _multidim_multiclass_inputs.target, + _multidim_multiclass_sk_metric, + NUM_CLASSES, + False + ) +]) +@pytest.mark.parametrize( + "metric_class, sk_fn", + [ + (Precision, precision_score), + (Recall, recall_score), + ], +) +def test_precision_recall( + ddp, + dist_sync_on_step, + preds, + target, + sk_metric, + metric_class, + sk_fn, + num_classes, + multilabel, + average +): + compute_batch( + preds, + target, + metric_class, + partial(sk_metric, sk_fn=sk_fn, average=average), + dist_sync_on_step, + ddp, + metric_args={ + "num_classes": num_classes, + "average": average, + "multilabel": multilabel, + "threshold": THRESHOLD + }, + check_dist_sync_on_step=False if average == 'macro' else True, + check_batch=False if average == 'macro' else True, + ) diff --git a/tests/metrics/classification/utils.py b/tests/metrics/classification/utils.py new file mode 100644 index 0000000000000..0aafde9448311 --- /dev/null +++ b/tests/metrics/classification/utils.py @@ -0,0 +1,64 @@ +import os +import pytest +import numpy as np +import torch + +from collections import namedtuple +from tests.metrics.utils import ( + NUM_BATCHES, + NUM_PROCESSES, + BATCH_SIZE, + NUM_CLASSES, + EXTRA_DIM, + THRESHOLD +) + +Input = namedtuple('Input', ["preds", "target"]) + + +_binary_prob_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE)) +) + + +_binary_inputs = Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE,)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE,)) +) + + +_multilabel_prob_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) +) + + +_multilabel_inputs = Input( + preds=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)), + target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES)) +) + + +_multiclass_prob_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)) +) + + +_multiclass_inputs = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE)) +) + + +_multidim_multiclass_prob_inputs = 
Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES, EXTRA_DIM), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM)) +) + + +_multidim_multiclass_inputs = Input( + preds=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, EXTRA_DIM, BATCH_SIZE)), + target=torch.randint(high=NUM_CLASSES, size=(NUM_BATCHES, EXTRA_DIM, BATCH_SIZE)) +) diff --git a/tests/metrics/regression/test_explained_variance.py b/tests/metrics/regression/test_explained_variance.py index de92fa524712b..a01efe5973f7f 100644 --- a/tests/metrics/regression/test_explained_variance.py +++ b/tests/metrics/regression/test_explained_variance.py @@ -38,7 +38,7 @@ def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score): @pytest.mark.parametrize("ddp", [True, False]) -@pytest.mark.parametrize("ddp_sync_on_step", [True, False]) +@pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize("multioutput", ['raw_values', 'uniform_average', 'variance_weighted']) @pytest.mark.parametrize( "preds, target, sk_metric", @@ -47,13 +47,13 @@ def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score): (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric), ], ) -def test_explained_variance(ddp, ddp_sync_on_step, multioutput, preds, target, sk_metric): +def test_explained_variance(ddp, dist_sync_on_step, multioutput, preds, target, sk_metric): compute_batch( preds, target, ExplainedVariance, partial(sk_metric, sk_fn=partial(explained_variance_score, multioutput=multioutput)), - ddp_sync_on_step, + dist_sync_on_step, ddp, metric_args=dict(multioutput=multioutput), ) diff --git a/tests/metrics/regression/test_mean_error.py b/tests/metrics/regression/test_mean_error.py index 03180bda9e4df..21bedefe69829 100644 --- a/tests/metrics/regression/test_mean_error.py +++ b/tests/metrics/regression/test_mean_error.py @@ -38,7 +38,7 @@ def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error): @pytest.mark.parametrize("ddp", [True, False]) -@pytest.mark.parametrize("ddp_sync_on_step", [True, False]) +@pytest.mark.parametrize("dist_sync_on_step", [True, False]) @pytest.mark.parametrize( "preds, target, sk_metric", [ @@ -54,8 +54,8 @@ def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error): (MeanSquaredLogError, mean_squared_log_error), ], ) -def test_mean_error(ddp, ddp_sync_on_step, preds, target, sk_metric, metric_class, sk_fn): - compute_batch(preds, target, metric_class, partial(sk_metric, sk_fn=sk_fn), ddp_sync_on_step, ddp) +def test_mean_error(ddp, dist_sync_on_step, preds, target, sk_metric, metric_class, sk_fn): + compute_batch(preds, target, metric_class, partial(sk_metric, sk_fn=sk_fn), dist_sync_on_step, ddp) @pytest.mark.parametrize("metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError]) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 29671be816449..9c6440ea2c6db 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -12,6 +12,9 @@ NUM_PROCESSES = 2 NUM_BATCHES = 10 BATCH_SIZE = 16 +NUM_CLASSES = 5 +EXTRA_DIM = 3 +THRESHOLD = 0.5 def setup_ddp(rank, world_size): @@ -21,20 +24,23 @@ def setup_ddp(rank, world_size): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) -def _compute_batch(rank: int, - preds: torch.Tensor, - target: torch.Tensor, - metric_class: Metric, - sk_metric: Callable, - ddp_sync_on_step: bool, - worldsize: int = 1, - metric_args: dict = {} - ): +def _compute_batch( + rank: int, + preds: torch.Tensor, 
+ target: torch.Tensor, + metric_class: Metric, + sk_metric: Callable, + dist_sync_on_step: bool, + worldsize: int = 1, + metric_args: dict = {}, + check_dist_sync_on_step: bool = True, + check_batch: bool = True, +): """ Utility function doing the actual comparison between lightning metric and reference metric """ # Instanciate lightning metric - metric = metric_class(compute_on_step=True, ddp_sync_on_step=ddp_sync_on_step, **metric_args) + metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args) # verify metrics work after being loaded from pickled state pickled_metric = pickle.dumps(metric) @@ -47,15 +53,19 @@ def _compute_batch(rank: int, for i in range(rank, NUM_BATCHES, worldsize): batch_result = metric(preds[i], target[i]) - if metric.ddp_sync_on_step: + if metric.dist_sync_on_step: if rank == 0: ddp_preds = torch.stack([preds[i + r] for r in range(worldsize)]) ddp_target = torch.stack([target[i + r] for r in range(worldsize)]) sk_batch_result = sk_metric(ddp_preds, ddp_target) - assert np.allclose(batch_result.numpy(), sk_batch_result) + # assert for dist_sync_on_step + if check_dist_sync_on_step: + assert np.allclose(batch_result.numpy(), sk_batch_result) else: sk_batch_result = sk_metric(preds[i], target[i]) - assert np.allclose(batch_result.numpy(), sk_batch_result) + # assert for batch + if check_batch: + assert np.allclose(batch_result.numpy(), sk_batch_result) # check on all batches on all ranks result = metric.compute() @@ -65,17 +75,21 @@ def _compute_batch(rank: int, total_target = torch.stack([target[i] for i in range(NUM_BATCHES)]) sk_result = sk_metric(total_preds, total_target) + # assert after aggregation assert np.allclose(result.numpy(), sk_result) -def compute_batch(preds: torch.Tensor, - target: torch.Tensor, - metric_class: Metric, - sk_metric: Callable, - ddp_sync_on_step: bool, - ddp: bool = False, - metric_args: dict = {} - ): +def compute_batch( + preds: torch.Tensor, + target: torch.Tensor, + metric_class: Metric, + sk_metric: Callable, + dist_sync_on_step: bool, + ddp: bool = False, + metric_args: dict = {}, + check_dist_sync_on_step: bool = True, + check_batch: bool = True, +): """ Utility function for comparing the result between a lightning class metric and another metric (often sklearns) @@ -84,19 +98,42 @@ def compute_batch(preds: torch.Tensor, target: target tensor metric_class: lightning metric class to test sk_metric: function to compare with - ddp_sync_on_step: bool, determine if values should be reduce on step + dist_sync_on_step: bool, determine if values should be reduce on step ddp: bool, determine if test should run in ddp mode metric_args: dict, additional kwargs that are use when instanciating the lightning metric + check_dist_sync_on_step: assert for dist_sync_on_step + check_batch: assert for each batch """ if ddp: if sys.platform == "win32": pytest.skip("DDP not supported on windows") torch.multiprocessing.spawn( - _compute_batch, args=(preds, target, metric_class, sk_metric, ddp_sync_on_step, NUM_PROCESSES, metric_args), + _compute_batch, args=( + preds, + target, + metric_class, + sk_metric, + dist_sync_on_step, + NUM_PROCESSES, + metric_args, + check_dist_sync_on_step, + check_batch, + ), nprocs=NUM_PROCESSES ) else: # first args: rank, last args: world size - _compute_batch(0, preds, target, metric_class, sk_metric, ddp_sync_on_step, 1, metric_args) + _compute_batch( + rank=0, + preds=preds, + target=target, + metric_class=metric_class, + sk_metric=sk_metric, + 
dist_sync_on_step=dist_sync_on_step, + worldsize=1, + metric_args=metric_args, + check_dist_sync_on_step=check_dist_sync_on_step, + check_batch=check_batch, + )
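
Usage note: the classification metrics introduced by this diff (Precision, Recall, Fbeta) follow the same forward()/update()/compute() contract as the existing Accuracy metric. The snippet below is a minimal usage sketch assembled from the docstring examples added above; the import path and the printed values are taken directly from those docstrings, and after a single batch compute() returns the same value as forward().

    >>> import torch
    >>> from pytorch_lightning.metrics import Precision, Recall, Fbeta
    >>> target = torch.tensor([0, 1, 2, 0, 1, 2])
    >>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
    >>> precision = Precision(num_classes=3)          # average='micro' is the default
    >>> recall = Recall(num_classes=3)
    >>> f_beta = Fbeta(num_classes=3, beta=0.5)
    >>> precision(preds, target)   # forward() updates state and returns the batch value
    tensor(0.3333)
    >>> recall(preds, target)
    tensor(0.3333)
    >>> f_beta(preds, target)
    tensor(0.3333)
    >>> precision.compute()        # value accumulated over all batches seen so far
    tensor(0.3333)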
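
The other change in this diff is the rename of ``ddp_sync_on_step`` to ``dist_sync_on_step`` on the Metric base class and every built-in metric. Below is a minimal sketch of a custom metric using the renamed argument; it extends the MyAccuracy example from docs/source/metrics.rst shown above, and the update()/compute() bodies are illustrative rather than copied from the docs.

    import torch
    from pytorch_lightning.metrics import Metric

    class MyAccuracy(Metric):
        def __init__(self, dist_sync_on_step=False):
            # dist_sync_on_step=True syncs metric state across processes at every forward()
            super().__init__(dist_sync_on_step=dist_sync_on_step)
            self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
            self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

        def update(self, preds: torch.Tensor, target: torch.Tensor):
            # illustrative body: count exact matches between hard predictions and targets
            self.correct += torch.sum(preds == target)
            self.total += target.numel()

        def compute(self):
            return self.correct.float() / self.total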