Update explained variance metric #4024

Merged 6 commits on Oct 9, 2020
5 changes: 5 additions & 0 deletions pytorch_lightning/metrics/metric.py
@@ -234,3 +234,8 @@ def __setstate__(self, state):
self.__dict__.update(state)
self.update = self._wrap_update(self.update)
self.compute = self._wrap_compute(self.compute)

def _check_same_shape(self, pred: torch.Tensor, target: torch.Tensor):
""" Check that predictions and target have the same shape, else raise error """
if pred.shape != target.shape:
raise RuntimeError('Predictions and targets are expected to have the same shape')
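
As a usage sketch (a hypothetical minimal metric, not part of this PR, loosely following the MeanAbsoluteError changes further down in this diff), a subclass would call the new helper at the top of its update():

import torch
from pytorch_lightning.metrics.metric import Metric

class SumAbsoluteError(Metric):
    """Hypothetical minimal metric showing the shared shape check."""

    def __init__(self):
        super().__init__()
        self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # The base-class helper raises RuntimeError when the shapes differ
        self._check_same_shape(preds, target)
        self.sum_abs_error += torch.sum(torch.abs(preds - target))

    def compute(self):
        return self.sum_abs_error
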
67 changes: 62 additions & 5 deletions pytorch_lightning/metrics/regression/explained_variance.py
@@ -1,13 +1,40 @@
import torch
from typing import Any, Callable, Optional, Union
from typing import Any, Optional

from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import rank_zero_warn


class ExplainedVariance(Metric):
"""
Computes explained variance.

Forward accepts

- ``preds`` (float tensor): ``(N,)`` or ``(N, ...)`` (multioutput)
- ``target`` (long tensor): ``(N,)`` or ``(N, ...)`` (multioutput)

In the case of multioutput, the variances will by default be uniformly
averaged over the additional dimensions. Please see the `multioutput`
argument for changing this behavior.

Args:
multioutput:
Defines aggregation in the case of multiple output scores. Can be one
of the following strings (default is `'uniform_average'`):

* `'raw_values'` returns full set of scores
* `'uniform_average'` scores are uniformly averaged
* `'variance_weighted'` scores are weighted by their individual variances

compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False. default: True
ddp_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)

Example:

>>> from pytorch_lightning.metrics import ExplainedVariance
@@ -17,11 +44,16 @@ class ExplainedVariance(Metric):
>>> explained_variance(preds, target)
tensor(0.9572)


>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> explained_variance = ExplainedVariance(multioutput='raw_values')
>>> explained_variance(preds, target)
tensor([0.9677, 1.0000])
"""

def __init__(
self,
multioutput: str = 'uniform_average',
compute_on_step: bool = True,
ddp_sync_on_step: bool = False,
process_group: Optional[Any] = None,
@@ -31,10 +63,19 @@ def __init__(
ddp_sync_on_step=ddp_sync_on_step,
process_group=process_group,
)

allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted')
if multioutput not in allowed_multioutput:
raise ValueError(
f'Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}'
)
self.multioutput = multioutput
self.add_state("y", default=[], dist_reduce_fx=None)
self.add_state("y_pred", default=[], dist_reduce_fx=None)

rank_zero_warn('Metric `ExplainedVariance` will save all targets and'
' predictions in buffer. For large datasets, this may lead'
' to a large memory footprint.')

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.
@@ -43,6 +84,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
preds: Predictions from model
target: Ground truth values
"""
self._check_same_shape(preds, target)
self.y.append(target)
self.y_pred.append(preds)

@@ -59,5 +101,20 @@ def compute(self):
y_true_avg = torch.mean(y_true, dim=0)
denominator = torch.mean((y_true - y_true_avg) ** 2, dim=0)

# TODO: multioutput
return 1.0 - torch.mean(numerator / denominator)
# Take care of division by zero
nonzero_numerator = numerator != 0
nonzero_denominator = denominator != 0
valid_score = nonzero_numerator & nonzero_denominator
output_scores = torch.ones_like(y_diff_avg)
output_scores[valid_score] = 1.0 - (numerator[valid_score] / denominator[valid_score])
output_scores[nonzero_numerator & ~nonzero_denominator] = 0.

# Decide what to do in multioutput case
# Todo: allow user to pass in tensor with weights
if self.multioutput == 'raw_values':
return output_scores
if self.multioutput == 'uniform_average':
return torch.mean(output_scores)
if self.multioutput == 'variance_weighted':
denom_sum = torch.sum(denominator)
return torch.sum(denominator / denom_sum * output_scores)
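
For reference, the score the new compute() produces for each output dimension is 1 - Var(target - preds) / Var(target), with zero denominators handled as in the code above (and in sklearn's explained_variance_score): the score is 1 when both numerator and denominator vanish and 0 when only the denominator does. A minimal functional sketch of the same computation (a hypothetical standalone helper, assuming already-stacked preds/target tensors rather than metric state):

import torch

def explained_variance(preds: torch.Tensor, target: torch.Tensor,
                       multioutput: str = 'uniform_average') -> torch.Tensor:
    # Per-output variance of the residuals and of the targets
    diff_avg = torch.mean(target - preds, dim=0)
    numerator = torch.mean((target - preds - diff_avg) ** 2, dim=0)
    target_avg = torch.mean(target, dim=0)
    denominator = torch.mean((target - target_avg) ** 2, dim=0)

    # Division-by-zero handling mirroring the metric above
    nonzero_numerator = numerator != 0
    nonzero_denominator = denominator != 0
    valid = nonzero_numerator & nonzero_denominator
    scores = torch.ones_like(diff_avg)
    scores[valid] = 1.0 - numerator[valid] / denominator[valid]
    scores[nonzero_numerator & ~nonzero_denominator] = 0.0

    if multioutput == 'raw_values':
        return scores
    if multioutput == 'variance_weighted':
        return torch.sum(denominator / torch.sum(denominator) * scores)
    return torch.mean(scores)  # 'uniform_average'

For the multi-output example from the docstring above, 'variance_weighted' weights the raw scores [0.9677, 1.0000] by the per-output target variances (roughly 12.06 and 10.89), giving about 0.9831, whereas 'uniform_average' gives about 0.9839.
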
@@ -50,8 +50,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
preds: Predictions from model
target: Ground truth values
"""
assert preds.shape == target.shape, \
'Predictions and targets are expected to have the same shape'
self._check_same_shape(preds, target)
abs_error = torch.abs(preds - target)

self.sum_abs_error += torch.sum(abs_error)
@@ -51,8 +51,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
preds: Predictions from model
target: Ground truth values
"""
assert preds.shape == target.shape, \
'Predictions and targets are expected to have the same shape'
self._check_same_shape(preds, target)
squared_error = torch.pow(preds - target, 2)

self.sum_squared_error += torch.sum(squared_error)
@@ -51,8 +51,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
preds: Predictions from model
target: Ground truth values
"""
assert preds.shape == target.shape, \
'Predictions and targets are expected to have the same shape'
self._check_same_shape(preds, target)
squared_log_error = torch.pow(torch.log1p(preds) - torch.log1p(target), 2)

self.sum_squared_log_error += torch.sum(squared_log_error)
4 changes: 1 addition & 3 deletions tests/metrics/classification/test_accuracy.py
@@ -1,15 +1,13 @@
import os
import pytest
import torch
import os
import numpy as np
from collections import namedtuple

from pytorch_lightning.metrics.classification.accuracy import Accuracy
from sklearn.metrics import accuracy_score

from tests.metrics.utils import compute_batch, setup_ddp
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE
from tests.metrics.utils import compute_batch, NUM_BATCHES, BATCH_SIZE

torch.manual_seed(42)

23 changes: 19 additions & 4 deletions tests/metrics/regression/test_explained_variance.py
@@ -6,8 +6,7 @@
from pytorch_lightning.metrics.regression import ExplainedVariance
from sklearn.metrics import explained_variance_score

from tests.metrics.utils import compute_batch, setup_ddp
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE
from tests.metrics.utils import compute_batch, NUM_BATCHES, BATCH_SIZE

torch.manual_seed(42)

@@ -40,12 +39,28 @@ def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score):

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("ddp_sync_on_step", [True, False])
@pytest.mark.parametrize("multioutput", ['raw_values', 'uniform_average', 'variance_weighted'])
@pytest.mark.parametrize(
"preds, target, sk_metric",
[
(_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric),
(_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric),
],
)
def test_explained_variance(ddp, ddp_sync_on_step, preds, target, sk_metric):
compute_batch(preds, target, ExplainedVariance, sk_metric, ddp_sync_on_step, ddp)
def test_explained_variance(ddp, ddp_sync_on_step, multioutput, preds, target, sk_metric):
compute_batch(
preds,
target,
ExplainedVariance,
partial(sk_metric, sk_fn=partial(explained_variance_score, multioutput=multioutput)),
ddp_sync_on_step,
ddp,
metric_args=dict(multioutput=multioutput),
)


def test_error_on_different_shape(metric_class=ExplainedVariance):
metric = metric_class()
with pytest.raises(RuntimeError,
match='Predictions and targets are expected to have the same shape'):
metric(torch.randn(100,), torch.randn(50,))
11 changes: 9 additions & 2 deletions tests/metrics/regression/test_mean_error.py
@@ -6,8 +6,7 @@
from pytorch_lightning.metrics.regression import MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError
from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_squared_log_error

from tests.metrics.utils import compute_batch, setup_ddp
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE
from tests.metrics.utils import compute_batch, NUM_BATCHES, BATCH_SIZE

torch.manual_seed(42)

@@ -57,3 +56,11 @@ def _multi_target_sk_metric(preds, target, sk_fn=mean_squared_error):
)
def test_mean_error(ddp, ddp_sync_on_step, preds, target, sk_metric, metric_class, sk_fn):
compute_batch(preds, target, metric_class, partial(sk_metric, sk_fn=sk_fn), ddp_sync_on_step, ddp)


@pytest.mark.parametrize("metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError])
def test_error_on_different_shape(metric_class):
metric = metric_class()
with pytest.raises(RuntimeError,
match='Predictions and targets are expected to have the same shape'):
metric(torch.randn(100,), torch.randn(50,))
46 changes: 41 additions & 5 deletions tests/metrics/utils.py
@@ -1,23 +1,39 @@
import torch
import numpy as np
import os
import sys
import pytest
import pickle
from typing import Callable

import torch
import numpy as np

from pytorch_lightning.metrics import Metric

NUM_PROCESSES = 2
NUM_BATCHES = 10
BATCH_SIZE = 16


def setup_ddp(rank, world_size):
""" Setup ddp enviroment """
os.environ["MASTER_ADDR"] = 'localhost'
os.environ['MASTER_PORT'] = '8088'
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)


def _compute_batch(rank, preds, target, metric_class, sk_metric, ddp_sync_on_step, worldsize=1, metric_args={}):

def _compute_batch(rank: int,
preds: torch.Tensor,
target: torch.Tensor,
metric_class: Metric,
sk_metric: Callable,
ddp_sync_on_step: bool,
worldsize: int = 1,
metric_args: dict = {}
):
""" Utility function doing the actual comparison between lightning metric
and reference metric
"""
# Instantiate lightning metric
metric = metric_class(compute_on_step=True, ddp_sync_on_step=ddp_sync_on_step, **metric_args)

# verify metrics work after being loaded from pickled state
@@ -52,7 +68,27 @@ def _compute_batch(rank, preds, target, metric_class, sk_metric, ddp_sync_on_ste
assert np.allclose(result.numpy(), sk_result)


def compute_batch(preds, target, metric_class, sk_metric, ddp_sync_on_step, ddp=False, metric_args={}):
def compute_batch(preds: torch.Tensor,
target: torch.Tensor,
metric_class: Metric,
sk_metric: Callable,
ddp_sync_on_step: bool,
ddp: bool = False,
metric_args: dict = {}
):
""" Utility function for comparing the result between a lightning class
metric and another metric (often sklearn's)

Args:
preds: prediction tensor
target: target tensor
metric_class: lightning metric class to test
sk_metric: function to compare with
ddp_sync_on_step: bool, determines if values should be reduced on step
ddp: bool, determines if the test should run in ddp mode
metric_args: dict, additional kwargs that are used when instantiating
the lightning metric
"""
if ddp:
if sys.platform == "win32":
pytest.skip("DDP not supported on windows")