[Metrics] R2Score #5241

Merged
merged 15 commits on Jan 1, 2021
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `max_fpr` parameter to `auroc` metric for computing partial auroc metric ([#3790](https://github.com/PyTorchLightning/pytorch-lightning/pull/3790))

- Added `R2Score` metric ([#5241](https://github.com/PyTorchLightning/pytorch-lightning/pull/5241))

### Changed


19 changes: 16 additions & 3 deletions docs/source/metrics.rst
@@ -212,7 +212,7 @@ Classification Metrics
Input types
-----------

For the purposes of classification metrics, inputs (predictions and targets) are split
into these categories (``N`` stands for the batch size and ``C`` for number of classes):

.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1
@@ -227,10 +227,10 @@ into these categories (``N`` stands for the batch size and ``C`` for number of classes):
"Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``"

.. note::
    All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so
that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``.

When predictions or targets are integers, it is assumed that class labels start at 0, i.e.
the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types

.. testcode::
@@ -517,6 +517,12 @@ SSIM
:noindex:


R2Score
~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.R2Score
:noindex:

Functional Metrics (Regression)
-------------------------------

@@ -561,6 +567,13 @@ ssim [func]
.. autofunction:: pytorch_lightning.metrics.functional.ssim
:noindex:

r2score [func]
~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.r2score
:noindex:


***
NLP
***
1 change: 1 addition & 0 deletions pytorch_lightning/metrics/__init__.py
@@ -33,4 +33,5 @@
ExplainedVariance,
PSNR,
SSIM,
R2Score
)
3 changes: 2 additions & 1 deletion pytorch_lightning/metrics/functional/__init__.py
@@ -29,9 +29,9 @@
to_categorical,
to_onehot,
)
# TODO: unify metrics between class and functional, add below
from pytorch_lightning.metrics.functional.accuracy import accuracy # noqa: F401
from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # noqa: F401
from pytorch_lightning.metrics.functional.explained_variance import explained_variance # noqa: F401
from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 # noqa: F401
from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance # noqa: F401
@@ -41,6 +41,7 @@
from pytorch_lightning.metrics.functional.nlp import bleu_score # noqa: F401
from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve # noqa: F401
from pytorch_lightning.metrics.functional.psnr import psnr # noqa: F401
from pytorch_lightning.metrics.functional.r2score import r2score # noqa: F401
from pytorch_lightning.metrics.functional.roc import roc # noqa: F401
from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity # noqa: F401
from pytorch_lightning.metrics.functional.ssim import ssim # noqa: F401
117 changes: 117 additions & 0 deletions pytorch_lightning/metrics/functional/r2score.py
@@ -0,0 +1,117 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch

from pytorch_lightning.metrics.utils import _check_same_shape


def _r2score_update(
preds: torch.Tensor,
target: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
_check_same_shape(preds, target)
if preds.ndim > 2:
raise ValueError('Expected both prediction and target to be 1D or 2D tensors,'
f' but received tensors with shape {preds.shape}')
if len(preds) < 2:
raise ValueError('Need at least two samples to calculate the R2 score.')

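# Accumulate the sufficient statistics sum(y), sum(y^2), the residual sum of
# squares and the sample count, so R^2 can be computed from batched updates.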
sum_error = torch.sum(target, dim=0)
sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0)
residual = torch.sum(torch.pow(target - preds, 2.0), dim=0)
total = torch.sum(torch.ones_like(target), dim=0)

return sum_squared_error, sum_error, residual, total


def _r2score_compute(sum_squared_error: torch.Tensor,
sum_error: torch.Tensor,
residual: torch.Tensor,
total: torch.Tensor,
adjusted: int = 0,
multioutput: str = "uniform_average") -> torch.Tensor:
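# Reconstruct the total sum of squares from the accumulated statistics:
# SS_tot = sum(y^2) - sum(y) * mean(y).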
mean_error = sum_error / total
diff = sum_squared_error - sum_error * mean_error
raw_scores = 1 - (residual / diff)

if multioutput == "raw_values":
r2score = raw_scores
elif multioutput == "uniform_average":
r2score = torch.mean(raw_scores)
elif multioutput == "variance_weighted":
diff_sum = torch.sum(diff)
r2score = torch.sum(diff / diff_sum * raw_scores)
else:
raise ValueError('Argument `multioutput` must be either `raw_values`,'
f' `uniform_average` or `variance_weighted`. Received {multioutput}.')

if not isinstance(adjusted, int) or adjusted < 0:
raise ValueError('`adjusted` parameter should be an integer greater than'
' or equal to 0.')

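# Adjusted R^2 = 1 - (1 - R^2) * (n - 1) / (n - k - 1), with k = `adjusted`.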
if adjusted != 0:
r2score = 1 - (1 - r2score) * (total - 1) / (total - adjusted - 1)
return r2score


def r2score(
preds: torch.Tensor,
target: torch.Tensor,
adjusted: int = 0,
multioutput: str = "uniform_average",
) -> torch.Tensor:
r"""
Computes R2 score, also known as the `coefficient of determination
<https://en.wikipedia.org/wiki/Coefficient_of_determination>`_:

.. math:: R^2 = 1 - \frac{SS_{res}}{SS_{tot}}

where :math:`SS_{res}=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and
:math:`SS_{tot}=\sum_i (y_i - \bar{y})^2` is the total sum of squares. It can also calculate
the adjusted R2 score, given by

.. math:: R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1}

where the parameter :math:`k` (the number of independent regressors) should
be provided as the `adjusted` argument.

Args:
preds: estimated labels
target: ground truth labels
adjusted: number of independent regressors for calculating adjusted r2 score.
Default 0 (standard r2 score).
multioutput: Defines aggregation in the case of multiple output scores. Can be one
of the following strings (default is ``'uniform_average'``):
* ``'raw_values'`` returns full set of scores
* ``'uniform_average'`` scores are uniformly averaged
* ``'variance_weighted'`` scores are weighted by their individual variances

Example:

>>> from pytorch_lightning.metrics.functional import r2score
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> r2score(preds, target)
tensor(0.9486)
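>>> # illustrative: the adjusted R^2 for k=1 independent regressor
>>> r2score(preds, target, adjusted=1)
tensor(0.9229)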

>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score(preds, target, multioutput='raw_values')
tensor([0.9654, 0.9082])
"""
sum_squared_error, sum_error, residual, total = _r2score_update(preds, target)
return _r2score_compute(sum_squared_error, sum_error, residual, total, adjusted, multioutput)
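
A minimal cross-check (not part of this diff), assuming scikit-learn is installed: the functional `r2score` should agree with `sklearn.metrics.r2_score` for every `multioutput` mode.

import torch
from sklearn.metrics import r2_score as sk_r2score

from pytorch_lightning.metrics.functional import r2score

target = torch.tensor([[0.5, 1.0], [-1.0, 1.0], [7.0, -6.0]])
preds = torch.tensor([[0.0, 2.0], [-1.0, 2.0], [8.0, -5.0]])

for mode in ("raw_values", "uniform_average", "variance_weighted"):
    # Both compute 1 - SS_res / SS_tot per output, then aggregate according
    # to `multioutput`; the values agree up to floating-point precision.
    print(mode, r2score(preds, target, multioutput=mode))
    print(mode, sk_r2score(target.numpy(), preds.numpy(), multioutput=mode))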
1 change: 1 addition & 0 deletions pytorch_lightning/metrics/regression/__init__.py
Expand Up @@ -17,3 +17,4 @@
from pytorch_lightning.metrics.regression.explained_variance import ExplainedVariance # noqa: F401
from pytorch_lightning.metrics.regression.psnr import PSNR # noqa: F401
from pytorch_lightning.metrics.regression.ssim import SSIM # noqa: F401
from pytorch_lightning.metrics.regression.r2score import R2Score # noqa: F401
143 changes: 143 additions & 0 deletions pytorch_lightning/metrics/regression/r2score.py
@@ -0,0 +1,143 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional

import torch

from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.metrics.functional.r2score import (
_r2score_update,
_r2score_compute
)


class R2Score(Metric):
r"""
Computes R2 score, also known as the `coefficient of determination
<https://en.wikipedia.org/wiki/Coefficient_of_determination>`_:

.. math:: R^2 = 1 - \frac{SS_{res}}{SS_{tot}}

where :math:`SS_{res}=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and
:math:`SS_{tot}=\sum_i (y_i - \bar{y})^2` is the total sum of squares. It can also calculate
the adjusted R2 score, given by

.. math:: R^2_{adj} = 1 - \frac{(1-R^2)(n-1)}{n-k-1}

where the parameter :math:`k` (the number of independent regressors) should
be provided as the `adjusted` argument.

Forward accepts

- ``preds`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput)
- ``target`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput)

In the case of multioutput, the scores are by default uniformly averaged
over the outputs. See the `multioutput` argument for changing this behavior.

Args:
num_outputs:
Number of outputs in multioutput setting (default is 1)
adjusted:
number of independent regressors for calculating adjusted r2 score.
Default 0 (standard r2 score).
multioutput:
Defines aggregation in the case of multiple output scores. Can be one
of the following strings (default is ``'uniform_average'``):

* ``'raw_values'`` returns full set of scores
* ``'uniform_average'`` scores are uniformly averaged
* ``'variance_weighted'`` scores are weighted by their individual variances

compute_on_step:
Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)

Example:

>>> from pytorch_lightning.metrics import R2Score
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> r2score = R2Score()
>>> r2score(preds, target)
tensor(0.9486)

>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = R2Score(num_outputs=2, multioutput='raw_values')
>>> r2score(preds, target)
tensor([0.9654, 0.9082])
"""
def __init__(
self,
num_outputs: int = 1,
adjusted: int = 0,
multioutput: str = "uniform_average",
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)

self.num_outputs = num_outputs

if not isinstance(adjusted, int) or adjusted < 0:
raise ValueError('`adjusted` parameter should be an integer greater than'
' or equal to 0.')
self.adjusted = adjusted

allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted')
if multioutput not in allowed_multioutput:
raise ValueError(
f'Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}'
)
self.multioutput = multioutput

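# Each state is a per-output running sum, so summing across processes
# reproduces the single-process statistics under distributed training.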
self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("total", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.

Args:
preds: Predictions from model
target: Ground truth values
"""
sum_squared_error, sum_error, residual, total = _r2score_update(preds, target)

self.sum_squared_error += sum_squared_error
self.sum_error += sum_error
self.residual += residual
self.total += total

def compute(self) -> torch.Tensor:
"""
Computes r2 score over the metric states.
"""
return _r2score_compute(self.sum_squared_error, self.sum_error, self.residual,
self.total, self.adjusted, self.multioutput)
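
A minimal usage sketch (not part of this diff) showing that the metric state accumulates across batches: feeding the docstring example in two batches of two samples gives the same result as a single pass over all four.

import torch
from pytorch_lightning.metrics import R2Score

metric = R2Score()
batches = [
    (torch.tensor([2.5, 0.0]), torch.tensor([3.0, -0.5])),
    (torch.tensor([2.0, 8.0]), torch.tensor([2.0, 7.0])),
]
for preds, target in batches:
    # Each update adds sum(y), sum(y^2), SS_res and the sample count to the state.
    metric.update(preds, target)
print(metric.compute())  # tensor(0.9486), identical to one call on all four samples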