Skip to content

Commit

Permalink
Fix RMSLE metric (#3188)
Browse files Browse the repository at this point in the history
* fix rmsle

* Updated test to match rmsle fix

* Updated RMSLE example result to match functional

* chlog

* add randomized test

* fix pep8

Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
  • Loading branch information
3 people authored Aug 26, 2020
1 parent ae3bf91 commit 888340d
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `num_sanity_val_steps` is clipped to `limit_val_batches` ([#2917](https://github.com/PyTorchLightning/pytorch-lightning/pull/2917))

- Fixed RMSLE metric ([#3188](https://github.com/PyTorchLightning/pytorch-lightning/pull/3188))

## [0.9.0] - YYYY-MM-DD

### Added
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,10 @@ def rmsle(
>>> x = torch.tensor([0., 1, 2, 3])
>>> y = torch.tensor([0., 1, 2, 2])
>>> rmsle(x, y)
tensor(0.0207)
tensor(0.1438)
"""
rmsle = mse(torch.log(pred + 1), torch.log(target + 1), reduction=reduction)
rmsle = rmse(torch.log(pred + 1), torch.log(target + 1), reduction=reduction)
return rmsle


Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class RMSLE(Metric):
>>> target = torch.tensor([0., 1, 2, 2])
>>> metric = RMSLE()
>>> metric(pred, target)
tensor(0.0207)
tensor(0.1438)
"""

Expand Down
38 changes: 34 additions & 4 deletions tests/metrics/functional/test_regression.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import numpy as np
import pytest
import torch
from skimage.metrics import peak_signal_noise_ratio as ski_psnr
from skimage.metrics import structural_similarity as ski_ssim
from functools import partial
from math import sqrt
from skimage.metrics import (
peak_signal_noise_ratio as ski_psnr,
structural_similarity as ski_ssim
)
from sklearn.metrics import (
mean_absolute_error as mae_sk,
mean_squared_error as mse_sk,
mean_squared_log_error as msle_sk
)

from pytorch_lightning.metrics.functional import (
mae,
Expand All @@ -14,6 +23,27 @@
)


@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
pytest.param(mae_sk, mae, id='mean_absolute_error'),
pytest.param(mse_sk, mse, id='mean_squared_error'),
pytest.param(partial(mse_sk, squared=False), rmse, id='root_mean_squared_error'),
pytest.param(lambda x, y: sqrt(msle_sk(x, y)), rmsle, id='root_mean_squared_log_error')
])
def test_against_sklearn(sklearn_metric, torch_metric):
"""Compare PL metrics to sklearn version."""
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# iterate over different label counts in predictions and target
pred = torch.rand(300, device=device)
target = torch.rand(300, device=device)

sk_score = sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy())
sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
pl_score = torch_metric(pred, target)
assert torch.allclose(sk_score, pl_score)


@pytest.mark.parametrize(['pred', 'target', 'expected'], [
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.25),
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 3.0),
Expand Down Expand Up @@ -45,8 +75,8 @@ def test_mae(pred, target, expected):

@pytest.mark.parametrize(['pred', 'target', 'expected'], [
pytest.param([0., 1, 2, 3], [0., 1, 2, 3], 0.0),
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.0207),
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 0.2841),
pytest.param([0., 1, 2, 3], [0., 1, 2, 2], 0.1438),
pytest.param([4., 3, 2, 1], [1., 4, 3, 2], 0.5330),
])
def test_rmsle(pred, target, expected):
score = rmsle(torch.tensor(pred), torch.tensor(target))
Expand Down

0 comments on commit 888340d

Please sign in to comment.