From b5bc732ee74b685791185714c19867c5d850c389 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 10 Nov 2020 12:32:05 +0100 Subject: [PATCH 01/12] add r2metric --- .../metrics/functional/r2score.py | 90 +++++++++++++ .../metrics/regression/r2score.py | 121 ++++++++++++++++++ 2 files changed, 211 insertions(+) create mode 100644 pytorch_lightning/metrics/functional/r2score.py create mode 100644 pytorch_lightning/metrics/regression/r2score.py diff --git a/pytorch_lightning/metrics/functional/r2score.py b/pytorch_lightning/metrics/functional/r2score.py new file mode 100644 index 0000000000000..1f0b07d0ed82a --- /dev/null +++ b/pytorch_lightning/metrics/functional/r2score.py @@ -0,0 +1,90 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple + +import torch + +from pytorch_lightning.metrics.utils import _check_same_shape + + +def _r2score_update( + preds: torch.tensor, + target: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + _check_same_shape(preds, target) + if len(preds.shape)>2: + raise ValueError('Expected both prediction and target to 1D or 2D tensors,' + f' but recevied tensors with dimension {preds.shape}') + + sum_error = torch.sum(target, dim=0) + sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0) + residual = torch.sum(torch.pow(target-preds, 2.0), dim=0) + total = torch.sum(torch.ones_like(target), dim=0) + + return sum_squared_error, sum_error, residual, total + + +def _r2score_compute(sum_squared_error: torch.Tensor, + sum_error: torch.Tensor, + residual: torch.Tensor, + total: torch.Tensor, + multioutput: str = "uniform_average") -> torch.Tensor: + mean_error = sum_error / total + diff = sum_squared_error - sum_error * mean_error + raw_scores = 1 - (residual / diff) + + if multioutput == "raw_values": + return raw_scores + if multioutput == "uniform_average": + return torch.mean(raw_scores) + if multioutput == "variance_weighted": + diff_sum = torch.sum(diff) + return torch.sum(diff / diff_sum * raw_scores) + + raise ValueError('Argument `multioutput` must be either `raw_values`,' + f' `uniform_average` or `variance_weighted`. Received {multioutput}.') + +def r2score( + preds: torch.Tensor, + target: torch.Tensor, + multioutput: str = "uniform_average", +) -> torch.Tensor: + """ + Computes r2 score also known as coefficient of determination + + Args: + pred: estimated labels + target: ground truth labels + multioutput: Defines aggregation in the case of multiple output scores. Can be one + of the following strings (default is `'uniform_average'`.): + + * `'raw_values'` returns full set of scores + * `'uniform_average'` scores are uniformly averaged + * `'variance_weighted'` scores are weighted by their individual variances + + Example: + + >>> from pytorch_lightning.metrics.functional import r2score + >>> target = torch.tensor([3, -0.5, 2, 7]) + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> r2score(preds, target) + tensor(0.9486) + + >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) + >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) + >>> r2score(preds, target, multioutput='raw_values') + tensor([0.9654, 0.9082]) + """ + sum_squared_error, sum_error, residual, total = _r2score_update(preds, target) + return _r2score_compute(sum_squared_error, sum_error, residual, total, multioutput) \ No newline at end of file diff --git a/pytorch_lightning/metrics/regression/r2score.py b/pytorch_lightning/metrics/regression/r2score.py new file mode 100644 index 0000000000000..a5b67281c76f1 --- /dev/null +++ b/pytorch_lightning/metrics/regression/r2score.py @@ -0,0 +1,121 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Optional + +import torch + +from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.metrics.functional.r2score import ( + _r2score_update, + _r2score_compute +) + +class R2Score(Metric): + """ + Computes r2 score also known as coefficient of determination + + Forward accepts + + - ``preds`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput) + - ``target`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput) + + In the case of multioutput, as default the variances will be uniformly + averaged over the additional dimensions. Please see argument `multioutput` + for changing this behavior. + + Args: + num_outputs: + Number of outputs in multioutput setting (default is 1) + multioutput: + Defines aggregation in the case of multiple output scores. Can be one + of the following strings (default is `'uniform_average'`.): + + * `'raw_values'` returns full set of scores + * `'uniform_average'` scores are uniformly averaged + * `'variance_weighted'` scores are weighted by their individual variances + + compute_on_step: + Forward only calls ``update()`` and return None if this is set to False. default: True + dist_sync_on_step: + Synchronize metric state across processes at each ``forward()`` + before returning the value at the step. default: False + process_group: + Specify the process group on which synchronization is called. default: None (which selects the entire world) + + Example: + + >>> from pytorch_lightning.metrics import R2Score + >>> target = torch.tensor([3, -0.5, 2, 7]) + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> r2score = R2Score() + >>> r2score(preds, target) + tensor(0.9486) + + >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) + >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) + >>> r2score = R2Score(num_outputs=2, multioutput='raw_values') + >>> r2score(preds, target) + tensor([0.9654, 0.9082]) + """ + def __init__( + self, + num_outputs: int = 1, + multioutput: str = "uniform_average", + compute_on_step: bool = True, + dist_sync_on_step: bool = False, + process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, + ): + super().__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, + ) + + self.num_outputs = num_outputs + allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted') + if multioutput not in allowed_multioutput: + raise ValueError( + f'Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}' + ) + self.multioutput = multioutput + + self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") + self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") + self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") + self.add_state("total", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor): + """ + Update state with predictions and targets. + + Args: + preds: Predictions from model + target: Ground truth values + """ + sum_squared_error, sum_error, residual, total = _r2score_update(preds, target) + + self.sum_squared_error += sum_squared_error + self.sum_error += sum_error + self.residual += residual + self.total += total + + def compute(self) -> torch.Tensor: + """ + Computes r2 score over the metric states. + """ + return _r2score_compute(self.sum_squared_error, self.sum_error, self.residual, + self.total, self.multioutput) + From 425489fc2e32171a28ad1f9e4dc50c87f3a602de Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 10 Nov 2020 12:32:14 +0100 Subject: [PATCH 02/12] change init --- pytorch_lightning/metrics/__init__.py | 1 + pytorch_lightning/metrics/functional/__init__.py | 2 ++ pytorch_lightning/metrics/regression/__init__.py | 1 + 3 files changed, 4 insertions(+) diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index 59f3a9cec06cf..875084e46de95 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -32,4 +32,5 @@ ExplainedVariance, PSNR, SSIM, + R2Score ) diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index e38ab5f415c32..886a085b22686 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -43,3 +43,5 @@ from pytorch_lightning.metrics.functional.roc import roc from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity from pytorch_lightning.metrics.functional.ssim import ssim +from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix +from pytorch_lightning.metrics.functional.r2score import r2score diff --git a/pytorch_lightning/metrics/regression/__init__.py b/pytorch_lightning/metrics/regression/__init__.py index 3090b1fe712e8..988adcd39357f 100644 --- a/pytorch_lightning/metrics/regression/__init__.py +++ b/pytorch_lightning/metrics/regression/__init__.py @@ -17,3 +17,4 @@ from pytorch_lightning.metrics.regression.explained_variance import ExplainedVariance from pytorch_lightning.metrics.regression.psnr import PSNR from pytorch_lightning.metrics.regression.ssim import SSIM +from pytorch_lightning.metrics.regression.r2score import R2Score From 1ded8214ff12cc3a27892a69e44399026e6aaae5 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 10 Nov 2020 12:32:21 +0100 Subject: [PATCH 03/12] add test --- tests/metrics/regression/test_r2score.py | 67 ++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 tests/metrics/regression/test_r2score.py diff --git a/tests/metrics/regression/test_r2score.py b/tests/metrics/regression/test_r2score.py new file mode 100644 index 0000000000000..e717d97dbe5c5 --- /dev/null +++ b/tests/metrics/regression/test_r2score.py @@ -0,0 +1,67 @@ +from collections import namedtuple +from functools import partial + +import pytest +import torch +from sklearn.metrics import r2_score as sk_r2score + +from pytorch_lightning.metrics.regression import R2Score +from pytorch_lightning.metrics.functional import r2score +from tests.metrics.utils import BATCH_SIZE, NUM_BATCHES, MetricTester + +torch.manual_seed(42) + +num_targets = 5 + +Input = namedtuple('Input', ["preds", "target"]) + +_single_target_inputs = Input(preds=torch.rand(NUM_BATCHES, BATCH_SIZE), target=torch.rand(NUM_BATCHES, BATCH_SIZE),) + +_multi_target_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), +) + + +def _single_target_sk_metric(preds, target, multioutput): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + return sk_r2score(sk_target, sk_preds, multioutput=multioutput) + + +def _multi_target_sk_metric(preds, target, multioutput): + sk_preds = preds.view(-1, num_targets).numpy() + sk_target = target.view(-1, num_targets).numpy() + return sk_r2score(sk_target, sk_preds, multioutput=multioutput) + + +@pytest.mark.parametrize("multioutput", ['raw_values', 'uniform_average', 'variance_weighted']) +@pytest.mark.parametrize( + "preds, target, sk_metric, num_outputs", + [ + (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric, 1), + (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric, num_targets), + ], +) +class TestExplainedVariance(MetricTester): + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_explained_variance(self, multioutput, preds, target, sk_metric, num_outputs, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp, + preds, + target, + R2Score, + partial(sk_metric, multioutput=multioutput), + dist_sync_on_step, + metric_args=dict(multioutput=multioutput, + num_outputs=num_outputs), + ) + + def test_explained_variance_functional(self, multioutput, preds, target, sk_metric, num_outputs): + self.run_functional_metric_test( + preds, + target, + r2score, + partial(sk_metric, multioutput=multioutput), + metric_args=dict(multioutput=multioutput), + ) From ecbb4addd061d35052ccec061c4b474c00e5ad1b Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 10 Nov 2020 12:32:29 +0100 Subject: [PATCH 04/12] add docs --- docs/source/metrics.rst | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 387cbc3bd7482..c6a6e1a64a398 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -212,7 +212,7 @@ Classification Metrics Input types ----------- -For the purposes of classification metrics, inputs (predictions and targets) are split +For the purposes of classification metrics, inputs (predictions and targets) are split into these categories (``N`` stands for the batch size and ``C`` for number of classes): .. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1 @@ -227,10 +227,10 @@ into these categories (``N`` stands for the batch size and ``C`` for number of c "Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``" .. note:: - All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so + All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``. -When predictions or targets are integers, it is assumed that class labels start at 0, i.e. +When predictions or targets are integers, it is assumed that class labels start at 0, i.e. the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types .. testcode:: @@ -507,6 +507,12 @@ SSIM :noindex: +R2Score +~~~~~~~ + +.. autoclass:: pytorch_lightning.metrics.regression.R2Score + :noindex: + Functional Metrics (Regression) ------------------------------- @@ -551,6 +557,13 @@ ssim [func] .. autofunction:: pytorch_lightning.metrics.functional.ssim :noindex: +r2score [func] +~~~~~~~~~~~~~~ + +.. autofunction:: pytorch_lightning.metrics.functional.r2score + :noindex: + + *** NLP *** From 16e1cfd795f26087a1dddf705290da70dad90d42 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 23 Dec 2020 10:27:46 +0100 Subject: [PATCH 05/12] add math --- .../metrics/functional/__init__.py | 3 +-- .../metrics/functional/r2score.py | 19 +++++++++++++------ .../metrics/regression/r2score.py | 12 +++++++++--- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 886a085b22686..8be2c2017d3d2 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -30,8 +30,8 @@ to_categorical, to_onehot, ) -from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # TODO: unify metrics between class and functional, add below +from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix from pytorch_lightning.metrics.functional.explained_variance import explained_variance from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 from pytorch_lightning.metrics.functional.mean_absolute_error import mean_absolute_error @@ -43,5 +43,4 @@ from pytorch_lightning.metrics.functional.roc import roc from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity from pytorch_lightning.metrics.functional.ssim import ssim -from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix from pytorch_lightning.metrics.functional.r2score import r2score diff --git a/pytorch_lightning/metrics/functional/r2score.py b/pytorch_lightning/metrics/functional/r2score.py index 1f0b07d0ed82a..3f71814afa973 100644 --- a/pytorch_lightning/metrics/functional/r2score.py +++ b/pytorch_lightning/metrics/functional/r2score.py @@ -23,13 +23,13 @@ def _r2score_update( target: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: _check_same_shape(preds, target) - if len(preds.shape)>2: - raise ValueError('Expected both prediction and target to 1D or 2D tensors,' + if preds.ndim > 2: + raise ValueError('Expected both prediction and target to be 1D or 2D tensors,' f' but recevied tensors with dimension {preds.shape}') sum_error = torch.sum(target, dim=0) sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0) - residual = torch.sum(torch.pow(target-preds, 2.0), dim=0) + residual = torch.sum(torch.pow(target - preds, 2.0), dim=0) total = torch.sum(torch.ones_like(target), dim=0) return sum_squared_error, sum_error, residual, total @@ -55,13 +55,20 @@ def _r2score_compute(sum_squared_error: torch.Tensor, raise ValueError('Argument `multioutput` must be either `raw_values`,' f' `uniform_average` or `variance_weighted`. Received {multioutput}.') + def r2score( preds: torch.Tensor, target: torch.Tensor, multioutput: str = "uniform_average", ) -> torch.Tensor: - """ - Computes r2 score also known as coefficient of determination + r""" + Computes r2 score also known as `coefficient of determination + `_: + + .. math:: R^2 = 1 - \frac{SS_res}{SS_tot} + + where :math:`SS_res=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and + :math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. Args: pred: estimated labels @@ -87,4 +94,4 @@ def r2score( tensor([0.9654, 0.9082]) """ sum_squared_error, sum_error, residual, total = _r2score_update(preds, target) - return _r2score_compute(sum_squared_error, sum_error, residual, total, multioutput) \ No newline at end of file + return _r2score_compute(sum_squared_error, sum_error, residual, total, multioutput) diff --git a/pytorch_lightning/metrics/regression/r2score.py b/pytorch_lightning/metrics/regression/r2score.py index a5b67281c76f1..c8f01d0a17f91 100644 --- a/pytorch_lightning/metrics/regression/r2score.py +++ b/pytorch_lightning/metrics/regression/r2score.py @@ -21,9 +21,16 @@ _r2score_compute ) + class R2Score(Metric): - """ - Computes r2 score also known as coefficient of determination + r""" + Computes r2 score also known as `coefficient of determination + `_: + + .. math:: R^2 = 1 - \frac{SS_res}{SS_tot} + + where :math:`SS_res=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and + :math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. Forward accepts @@ -118,4 +125,3 @@ def compute(self) -> torch.Tensor: """ return _r2score_compute(self.sum_squared_error, self.sum_error, self.residual, self.total, self.multioutput) - From 656cf6cfb833b011a95affc6c3e0108c6040f349 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 23 Dec 2020 22:13:29 +0100 Subject: [PATCH 06/12] Apply suggestions from code review Co-authored-by: Teddy Koker Co-authored-by: Rohit Gupta --- tests/metrics/regression/test_r2score.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/metrics/regression/test_r2score.py b/tests/metrics/regression/test_r2score.py index e717d97dbe5c5..1dbcd1cecd346 100644 --- a/tests/metrics/regression/test_r2score.py +++ b/tests/metrics/regression/test_r2score.py @@ -42,10 +42,10 @@ def _multi_target_sk_metric(preds, target, multioutput): (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric, num_targets), ], ) -class TestExplainedVariance(MetricTester): +class TestR2Score(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_explained_variance(self, multioutput, preds, target, sk_metric, num_outputs, ddp, dist_sync_on_step): + def test_r2(self, multioutput, preds, target, sk_metric, num_outputs, ddp, dist_sync_on_step): self.run_class_metric_test( ddp, preds, @@ -57,7 +57,7 @@ def test_explained_variance(self, multioutput, preds, target, sk_metric, num_out num_outputs=num_outputs), ) - def test_explained_variance_functional(self, multioutput, preds, target, sk_metric, num_outputs): + def test_r2_functional(self, multioutput, preds, target, sk_metric, num_outputs): self.run_functional_metric_test( preds, target, From 11fe522375c759e0f4f7bf1103be788483d16222 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 26 Dec 2020 12:18:38 +0100 Subject: [PATCH 07/12] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a46f49211268..0014d600383b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `HammingDistance` metric to compute the hamming distance (loss) ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838)) +- Added `R2Score` metric ([#5241](https://github.com/PyTorchLightning/pytorch-lightning/pull/5241)) + ### Changed From 74b8b15dbe2de458cdc2d3fa56b037f2465d0965 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 26 Dec 2020 12:42:09 +0100 Subject: [PATCH 08/12] adjusted parameter --- .../metrics/functional/__init__.py | 2 +- .../metrics/functional/r2score.py | 38 ++++++++++++++----- .../metrics/regression/r2score.py | 20 +++++++++- tests/metrics/regression/test_r2score.py | 29 +++++++++----- 4 files changed, 67 insertions(+), 22 deletions(-) diff --git a/pytorch_lightning/metrics/functional/__init__.py b/pytorch_lightning/metrics/functional/__init__.py index 37442fa301ab8..62097b176b795 100644 --- a/pytorch_lightning/metrics/functional/__init__.py +++ b/pytorch_lightning/metrics/functional/__init__.py @@ -41,7 +41,7 @@ from pytorch_lightning.metrics.functional.nlp import bleu_score # noqa: F401 from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve # noqa: F401 from pytorch_lightning.metrics.functional.psnr import psnr # noqa: F401 -from pytorch_lightning.metrics.functional.r2score import r2score +from pytorch_lightning.metrics.functional.r2score import r2score # noqa: F401 from pytorch_lightning.metrics.functional.roc import roc # noqa: F401 from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity # noqa: F401 from pytorch_lightning.metrics.functional.ssim import ssim # noqa: F401 diff --git a/pytorch_lightning/metrics/functional/r2score.py b/pytorch_lightning/metrics/functional/r2score.py index 3f71814afa973..aad748ab3ae59 100644 --- a/pytorch_lightning/metrics/functional/r2score.py +++ b/pytorch_lightning/metrics/functional/r2score.py @@ -26,6 +26,8 @@ def _r2score_update( if preds.ndim > 2: raise ValueError('Expected both prediction and target to be 1D or 2D tensors,' f' but recevied tensors with dimension {preds.shape}') + if len(preds) < 2: + raise ValueError('Needs atleast two samples to calculate r2 score.') sum_error = torch.sum(target, dim=0) sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0) @@ -39,26 +41,36 @@ def _r2score_compute(sum_squared_error: torch.Tensor, sum_error: torch.Tensor, residual: torch.Tensor, total: torch.Tensor, + adjusted: int = 0, multioutput: str = "uniform_average") -> torch.Tensor: mean_error = sum_error / total diff = sum_squared_error - sum_error * mean_error raw_scores = 1 - (residual / diff) if multioutput == "raw_values": - return raw_scores - if multioutput == "uniform_average": - return torch.mean(raw_scores) - if multioutput == "variance_weighted": + r2score = raw_scores + elif multioutput == "uniform_average": + r2score = torch.mean(raw_scores) + elif multioutput == "variance_weighted": diff_sum = torch.sum(diff) - return torch.sum(diff / diff_sum * raw_scores) + r2score = torch.sum(diff / diff_sum * raw_scores) + else: + raise ValueError('Argument `multioutput` must be either `raw_values`,' + f' `uniform_average` or `variance_weighted`. Received {multioutput}.') - raise ValueError('Argument `multioutput` must be either `raw_values`,' - f' `uniform_average` or `variance_weighted`. Received {multioutput}.') + if adjusted < 0 or not isinstance(adjusted, int): + raise ValueError('`adjusted` parameter should be an integer larger or' + ' equal to 0.') + + if adjusted != 0: + r2score = 1 - (1 - r2score) * (total - 1) / (total - adjusted - 1) + return r2score def r2score( preds: torch.Tensor, target: torch.Tensor, + adjusted: int = 0, multioutput: str = "uniform_average", ) -> torch.Tensor: r""" @@ -68,11 +80,19 @@ def r2score( .. math:: R^2 = 1 - \frac{SS_res}{SS_tot} where :math:`SS_res=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and - :math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. + :math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate + adjusted r2 score given by + + .. math:: R^2_adj = 1 - \frac{(1-R^2)(n-1)}{n-k-1} + + where the parameter :math:`k` (the number of independent regressors) should + be provided as the `adjusted` argument. Args: pred: estimated labels target: ground truth labels + adjusted: number of independent regressors for calculating adjusted r2 score. + Default 0 (standard r2 score). multioutput: Defines aggregation in the case of multiple output scores. Can be one of the following strings (default is `'uniform_average'`.): @@ -94,4 +114,4 @@ def r2score( tensor([0.9654, 0.9082]) """ sum_squared_error, sum_error, residual, total = _r2score_update(preds, target) - return _r2score_compute(sum_squared_error, sum_error, residual, total, multioutput) + return _r2score_compute(sum_squared_error, sum_error, residual, total, adjusted, multioutput) diff --git a/pytorch_lightning/metrics/regression/r2score.py b/pytorch_lightning/metrics/regression/r2score.py index c8f01d0a17f91..40637c65f4606 100644 --- a/pytorch_lightning/metrics/regression/r2score.py +++ b/pytorch_lightning/metrics/regression/r2score.py @@ -30,7 +30,13 @@ class R2Score(Metric): .. math:: R^2 = 1 - \frac{SS_res}{SS_tot} where :math:`SS_res=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and - :math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. + :math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate + adjusted r2 score given by + + .. math:: R^2_adj = 1 - \frac{(1-R^2)(n-1)}{n-k-1} + + where the parameter :math:`k` (the number of independent regressors) should + be provided as the `adjusted` argument. Forward accepts @@ -44,6 +50,9 @@ class R2Score(Metric): Args: num_outputs: Number of outputs in multioutput setting (default is 1) + adjusted: + number of independent regressors for calculating adjusted r2 score. + Default 0 (standard r2 score). multioutput: Defines aggregation in the case of multiple output scores. Can be one of the following strings (default is `'uniform_average'`.): @@ -78,6 +87,7 @@ class R2Score(Metric): def __init__( self, num_outputs: int = 1, + adjusted: int = 0, multioutput: str = "uniform_average", compute_on_step: bool = True, dist_sync_on_step: bool = False, @@ -92,6 +102,12 @@ def __init__( ) self.num_outputs = num_outputs + + if adjusted < 0 or not isinstance(adjusted, int): + raise ValueError('`adjusted` parameter should be an integer larger or' + ' equal to 0.') + self.adjusted = adjusted + allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted') if multioutput not in allowed_multioutput: raise ValueError( @@ -124,4 +140,4 @@ def compute(self) -> torch.Tensor: Computes r2 score over the metric states. """ return _r2score_compute(self.sum_squared_error, self.sum_error, self.residual, - self.total, self.multioutput) + self.total, self.adjusted, self.multioutput) diff --git a/tests/metrics/regression/test_r2score.py b/tests/metrics/regression/test_r2score.py index 1dbcd1cecd346..7d9d382fe055b 100644 --- a/tests/metrics/regression/test_r2score.py +++ b/tests/metrics/regression/test_r2score.py @@ -22,18 +22,25 @@ ) -def _single_target_sk_metric(preds, target, multioutput): +def _single_target_sk_metric(preds, target, adjusted, multioutput): sk_preds = preds.view(-1).numpy() sk_target = target.view(-1).numpy() - return sk_r2score(sk_target, sk_preds, multioutput=multioutput) + r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput) + if adjusted != 0: + r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1) + return r2_score -def _multi_target_sk_metric(preds, target, multioutput): +def _multi_target_sk_metric(preds, target, adjusted, multioutput): sk_preds = preds.view(-1, num_targets).numpy() sk_target = target.view(-1, num_targets).numpy() - return sk_r2score(sk_target, sk_preds, multioutput=multioutput) + r2_score = sk_r2score(sk_target, sk_preds, multioutput=multioutput) + if adjusted != 0: + r2_score = 1 - (1 - r2_score) * (sk_preds.shape[0] - 1) / (sk_preds.shape[0] - adjusted - 1) + return r2_score +@pytest.mark.parametrize("adjusted", [0, 5, 10]) @pytest.mark.parametrize("multioutput", ['raw_values', 'uniform_average', 'variance_weighted']) @pytest.mark.parametrize( "preds, target, sk_metric, num_outputs", @@ -45,23 +52,25 @@ def _multi_target_sk_metric(preds, target, multioutput): class TestR2Score(MetricTester): @pytest.mark.parametrize("ddp", [True, False]) @pytest.mark.parametrize("dist_sync_on_step", [True, False]) - def test_r2(self, multioutput, preds, target, sk_metric, num_outputs, ddp, dist_sync_on_step): + def test_r2(self, adjusted, multioutput, preds, target, sk_metric, num_outputs, ddp, dist_sync_on_step): self.run_class_metric_test( ddp, preds, target, R2Score, - partial(sk_metric, multioutput=multioutput), + partial(sk_metric, adjusted=adjusted, multioutput=multioutput), dist_sync_on_step, - metric_args=dict(multioutput=multioutput, + metric_args=dict(adjusted=adjusted, + multioutput=multioutput, num_outputs=num_outputs), ) - def test_r2_functional(self, multioutput, preds, target, sk_metric, num_outputs): + def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, num_outputs): self.run_functional_metric_test( preds, target, r2score, - partial(sk_metric, multioutput=multioutput), - metric_args=dict(multioutput=multioutput), + partial(sk_metric, adjusted=adjusted, multioutput=multioutput), + metric_args=dict(adjusted=adjusted, + multioutput=multioutput), ) From cea47d95cd16d2e74def928a850db76e2d2820aa Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 29 Dec 2020 18:37:40 +0100 Subject: [PATCH 09/12] add more test --- tests/metrics/regression/test_r2score.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/metrics/regression/test_r2score.py b/tests/metrics/regression/test_r2score.py index 7d9d382fe055b..5fe54828f2cc2 100644 --- a/tests/metrics/regression/test_r2score.py +++ b/tests/metrics/regression/test_r2score.py @@ -74,3 +74,22 @@ def test_r2_functional(self, adjusted, multioutput, preds, target, sk_metric, nu metric_args=dict(adjusted=adjusted, multioutput=multioutput), ) + + +def test_error_on_different_shape(metric_class=R2Score): + metric = metric_class() + with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'): + metric(torch.randn(100,), torch.randn(50,)) + + +def test_error_on_multidim_tensors(metric_class=R2Score): + metric = metric_class() + with pytest.raises(ValueError, match=r'Expected both prediction and target to be 1D or 2D tensors,' + r' but recevied tensors with dimension .'): + metric(torch.randn(10,20,5), torch.randn(10,20,5)) + + +def test_error_on_too_few_samples(metric_class=R2Score): + metric = metric_class() + with pytest.raises(ValueError, match='Needs atleast two samples to calculate r2 score.'): + metric(torch.randn(1,), torch.randn(1,)) From 47de40312e41e2ef256763d49b031ec87eea5566 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 29 Dec 2020 18:39:07 +0100 Subject: [PATCH 10/12] pep8 --- tests/metrics/regression/test_r2score.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/regression/test_r2score.py b/tests/metrics/regression/test_r2score.py index 5fe54828f2cc2..2ec785c6fd21f 100644 --- a/tests/metrics/regression/test_r2score.py +++ b/tests/metrics/regression/test_r2score.py @@ -86,7 +86,7 @@ def test_error_on_multidim_tensors(metric_class=R2Score): metric = metric_class() with pytest.raises(ValueError, match=r'Expected both prediction and target to be 1D or 2D tensors,' r' but recevied tensors with dimension .'): - metric(torch.randn(10,20,5), torch.randn(10,20,5)) + metric(torch.randn(10, 20, 5), torch.randn(10, 20, 5)) def test_error_on_too_few_samples(metric_class=R2Score): From a4d684773d1d00f5f891044b04174da6be25c1ed Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 31 Dec 2020 15:07:56 +0100 Subject: [PATCH 11/12] Apply suggestions from code review Co-authored-by: Rohit Gupta --- pytorch_lightning/metrics/functional/r2score.py | 12 ++++++------ pytorch_lightning/metrics/regression/r2score.py | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/metrics/functional/r2score.py b/pytorch_lightning/metrics/functional/r2score.py index aad748ab3ae59..e0fcacd3dd83a 100644 --- a/pytorch_lightning/metrics/functional/r2score.py +++ b/pytorch_lightning/metrics/functional/r2score.py @@ -32,7 +32,7 @@ def _r2score_update( sum_error = torch.sum(target, dim=0) sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0) residual = torch.sum(torch.pow(target - preds, 2.0), dim=0) - total = torch.sum(torch.ones_like(target), dim=0) + total = target.size(0) return sum_squared_error, sum_error, residual, total @@ -86,7 +86,7 @@ def r2score( .. math:: R^2_adj = 1 - \frac{(1-R^2)(n-1)}{n-k-1} where the parameter :math:`k` (the number of independent regressors) should - be provided as the `adjusted` argument. + be provided as the ``adjusted`` argument. Args: pred: estimated labels @@ -94,11 +94,11 @@ def r2score( adjusted: number of independent regressors for calculating adjusted r2 score. Default 0 (standard r2 score). multioutput: Defines aggregation in the case of multiple output scores. Can be one - of the following strings (default is `'uniform_average'`.): + of the following strings (default is ``'uniform_average'``.): - * `'raw_values'` returns full set of scores - * `'uniform_average'` scores are uniformly averaged - * `'variance_weighted'` scores are weighted by their individual variances + * ``'raw_values'`` returns full set of scores + * ``'uniform_average'`` scores are uniformly averaged + * ``'variance_weighted'`` scores are weighted by their individual variances Example: diff --git a/pytorch_lightning/metrics/regression/r2score.py b/pytorch_lightning/metrics/regression/r2score.py index 40637c65f4606..8394eef75f09c 100644 --- a/pytorch_lightning/metrics/regression/r2score.py +++ b/pytorch_lightning/metrics/regression/r2score.py @@ -55,11 +55,11 @@ class R2Score(Metric): Default 0 (standard r2 score). multioutput: Defines aggregation in the case of multiple output scores. Can be one - of the following strings (default is `'uniform_average'`.): + of the following strings (default is ``'uniform_average'``.): - * `'raw_values'` returns full set of scores - * `'uniform_average'` scores are uniformly averaged - * `'variance_weighted'` scores are weighted by their individual variances + * ``'raw_values'`` returns full set of scores + * ``'uniform_average'`` scores are uniformly averaged + * ``'variance_weighted'`` scores are weighted by their individual variances compute_on_step: Forward only calls ``update()`` and return None if this is set to False. default: True From 0218f5dce346343e7284e9e98b057b8d60b8627c Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Thu, 31 Dec 2020 15:27:49 +0100 Subject: [PATCH 12/12] add warnings for adjusted score --- pytorch_lightning/metrics/functional/r2score.py | 11 ++++++++++- pytorch_lightning/metrics/regression/r2score.py | 2 +- tests/metrics/regression/test_r2score.py | 14 ++++++++++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/metrics/functional/r2score.py b/pytorch_lightning/metrics/functional/r2score.py index e0fcacd3dd83a..f689e3ac9cac1 100644 --- a/pytorch_lightning/metrics/functional/r2score.py +++ b/pytorch_lightning/metrics/functional/r2score.py @@ -16,6 +16,7 @@ import torch from pytorch_lightning.metrics.utils import _check_same_shape +from pytorch_lightning.utilities import rank_zero_warn def _r2score_update( @@ -63,7 +64,15 @@ def _r2score_compute(sum_squared_error: torch.Tensor, ' equal to 0.') if adjusted != 0: - r2score = 1 - (1 - r2score) * (total - 1) / (total - adjusted - 1) + if adjusted > total - 1: + rank_zero_warn("More independent regressions than datapoints in" + " adjusted r2 score. Falls back to standard r2 score.", + UserWarning) + elif adjusted == total - 1: + rank_zero_warn("Division by zero in adjusted r2 score. Falls back to" + " standard r2 score.", UserWarning) + else: + r2score = 1 - (1 - r2score) * (total - 1) / (total - adjusted - 1) return r2score diff --git a/pytorch_lightning/metrics/regression/r2score.py b/pytorch_lightning/metrics/regression/r2score.py index 8394eef75f09c..f8f6e98b790c4 100644 --- a/pytorch_lightning/metrics/regression/r2score.py +++ b/pytorch_lightning/metrics/regression/r2score.py @@ -118,7 +118,7 @@ def __init__( self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") - self.add_state("total", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") def update(self, preds: torch.Tensor, target: torch.Tensor): """ diff --git a/tests/metrics/regression/test_r2score.py b/tests/metrics/regression/test_r2score.py index 2ec785c6fd21f..ef3ec89721b99 100644 --- a/tests/metrics/regression/test_r2score.py +++ b/tests/metrics/regression/test_r2score.py @@ -93,3 +93,17 @@ def test_error_on_too_few_samples(metric_class=R2Score): metric = metric_class() with pytest.raises(ValueError, match='Needs atleast two samples to calculate r2 score.'): metric(torch.randn(1,), torch.randn(1,)) + + +def test_warning_on_too_large_adjusted(metric_class=R2Score): + metric = metric_class(adjusted=10) + + with pytest.warns(UserWarning, + match="More independent regressions than datapoints in" + " adjusted r2 score. Falls back to standard r2 score."): + metric(torch.randn(10,), torch.randn(10,)) + + with pytest.warns(UserWarning, + match="Division by zero in adjusted r2 score. Falls back to" + " standard r2 score."): + metric(torch.randn(11,), torch.randn(11,))