[Metrics] AUROC error on multilabel + improved testing (#3350)
* error on multilabel

* fix tests

* fix pep8

* changelog

* update doc test

* fix doctest

* fix doctest

* update from suggestion

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update test_classification.py

* Update test_classification.py

* retrigger test

* 'pep8

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
SkafteNicki and awaelchli authored Sep 21, 2020
1 parent c346679 commit b1347c9
Showing 5 changed files with 60 additions and 30 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* functional interface ([#3349](https://github.com/PyTorchLightning/pytorch-lightning/pull/3349))
* class based interface + tests ([#3358](https://github.com/PyTorchLightning/pytorch-lightning/pull/3358))

+ - Added error when AUROC metric is used for multiclass problems ([#3350](https://github.com/PyTorchLightning/pytorch-lightning/pull/3350))

### Changed

4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/classification.py
@@ -374,10 +374,10 @@ class AUROC(TensorMetric):
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
- >>> target = torch.tensor([0, 1, 2, 2])
+ >>> target = torch.tensor([0, 1, 1, 0])
>>> metric = AUROC()
>>> metric(pred, target)
- tensor(0.3333)
+ tensor(0.5000)
"""

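As a sanity check, the corrected doctest value can be reproduced with scikit-learn's roc_auc_score, which likewise treats pred as scores for the positive class. A minimal sketch against the 0.9.x-era API edited above:

    import torch
    from sklearn.metrics import roc_auc_score
    from pytorch_lightning.metrics.classification import AUROC

    pred = torch.tensor([0, 1, 2, 3])
    target = torch.tensor([0, 1, 1, 0])

    metric = AUROC()
    pl_auc = metric(pred, target)                          # tensor(0.5000)
    sk_auc = roc_auc_score(target.numpy(), pred.numpy())   # 0.5
    assert torch.allclose(pl_auc, torch.tensor(sk_auc, dtype=pl_auc.dtype))
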
20 changes: 12 additions & 8 deletions pytorch_lightning/metrics/functional/classification.py
@@ -584,12 +584,12 @@ def roc(
Example:
>>> x = torch.tensor([0, 1, 2, 3])
- >>> y = torch.tensor([0, 1, 2, 2])
+ >>> y = torch.tensor([0, 1, 1, 1])
>>> fpr, tpr, thresholds = roc(x, y)
>>> fpr
- tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000])
+ tensor([0., 0., 0., 0., 1.])
>>> tpr
- tensor([0., 0., 0., 1., 1.])
+ tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
>>> thresholds
tensor([4, 3, 2, 1, 0])
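
With the binary target, the corrected fpr/tpr values match scikit-learn's roc_curve point for point. A small sketch under that assumption (drop_intermediate is disabled so both curves keep every threshold):

    import torch
    from sklearn.metrics import roc_curve as sk_roc_curve
    from pytorch_lightning.metrics.functional.classification import roc

    x = torch.tensor([0, 1, 2, 3])
    y = torch.tensor([0, 1, 1, 1])

    fpr, tpr, _ = roc(x, y)
    sk_fpr, sk_tpr, _ = sk_roc_curve(y.numpy(), x.numpy(), drop_intermediate=False)

    assert torch.allclose(fpr, torch.tensor(sk_fpr, dtype=fpr.dtype))
    assert torch.allclose(tpr, torch.tensor(sk_tpr, dtype=tpr.dtype))
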
@@ -682,12 +682,12 @@ def precision_recall_curve(
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
- >>> target = torch.tensor([0, 1, 2, 2])
+ >>> target = torch.tensor([0, 1, 1, 0])
>>> precision, recall, thresholds = precision_recall_curve(pred, target)
>>> precision
- tensor([0.3333, 0.0000, 0.0000, 1.0000])
+ tensor([0.6667, 0.5000, 0.0000, 1.0000])
>>> recall
- tensor([1., 0., 0., 0.])
+ tensor([1.0000, 0.5000, 0.0000, 0.0000])
>>> thresholds
tensor([1, 2, 3])
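
The same cross-check works for the corrected precision/recall doctest; here even the thresholds agree. A sketch under the same assumptions:

    import torch
    from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve
    from pytorch_lightning.metrics.functional.classification import precision_recall_curve

    pred = torch.tensor([0, 1, 2, 3])
    target = torch.tensor([0, 1, 1, 0])

    precision, recall, thresholds = precision_recall_curve(pred, target)
    sk_p, sk_r, sk_t = sk_precision_recall_curve(target.numpy(), pred.numpy())

    assert torch.allclose(precision, torch.tensor(sk_p, dtype=precision.dtype))
    assert torch.allclose(recall, torch.tensor(sk_r, dtype=recall.dtype))
    assert torch.allclose(thresholds.float(), torch.tensor(sk_t, dtype=torch.float))
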
@@ -858,10 +858,14 @@ def auroc(
Example:
>>> x = torch.tensor([0, 1, 2, 3])
- >>> y = torch.tensor([0, 1, 2, 2])
+ >>> y = torch.tensor([0, 1, 1, 0])
>>> auroc(x, y)
- tensor(0.3333)
+ tensor(0.5000)
"""
+     if any(target > 1):
+         raise ValueError('AUROC metric is meant for binary classification, but'
+                          ' target tensor contains value different from 0 and 1.'
+                          ' Multiclass is currently not supported.')

@auc_decorator(reorder=True)
def _auroc(pred, target, sample_weight, pos_label):
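With the guard above, a target containing class indices other than 0 and 1 now fails fast instead of silently producing a meaningless area. A minimal sketch of the new behaviour:

    import torch
    from pytorch_lightning.metrics.functional.classification import auroc

    pred = torch.tensor([0.1, 0.6, 0.3, 0.8])

    print(auroc(pred, torch.tensor([0, 1, 0, 1])))   # binary target still works: tensor(1.)

    try:
        auroc(pred, torch.tensor([0, 2, 1, 2]))      # class index 2 -> rejected
    except ValueError as err:
        print(err)  # AUROC metric is meant for binary classification, but ...
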
61 changes: 43 additions & 18 deletions tests/metrics/functional/test_classification.py
@@ -10,6 +10,9 @@
f1_score as sk_f1_score,
fbeta_score as sk_fbeta_score,
confusion_matrix as sk_confusion_matrix,
+     roc_curve as sk_roc_curve,
+     roc_auc_score as sk_roc_auc_score,
+     precision_recall_curve as sk_precision_recall_curve
)

from pytorch_lightning import seed_everything
@@ -36,29 +39,44 @@
)


- @pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
-     pytest.param(sk_accuracy, accuracy, id='accuracy'),
-     pytest.param(partial(sk_jaccard_score, average='macro'), iou, id='iou'),
-     pytest.param(partial(sk_precision, average='micro'), precision, id='precision'),
-     pytest.param(partial(sk_recall, average='micro'), recall, id='recall'),
-     pytest.param(partial(sk_f1_score, average='micro'), f1_score, id='f1_score'),
-     pytest.param(partial(sk_fbeta_score, average='micro', beta=2), partial(fbeta_score, beta=2), id='fbeta_score'),
-     pytest.param(sk_confusion_matrix, confusion_matrix, id='confusion_matrix')
+ @pytest.mark.parametrize(['sklearn_metric', 'torch_metric', 'only_binary'], [
+     pytest.param(sk_accuracy, accuracy, False, id='accuracy'),
+     pytest.param(partial(sk_jaccard_score, average='macro'), iou, False, id='iou'),
+     pytest.param(partial(sk_precision, average='micro'), precision, False, id='precision'),
+     pytest.param(partial(sk_recall, average='micro'), recall, False, id='recall'),
+     pytest.param(partial(sk_f1_score, average='micro'), f1_score, False, id='f1_score'),
+     pytest.param(partial(sk_fbeta_score, average='micro', beta=2),
+                  partial(fbeta_score, beta=2), False, id='fbeta_score'),
+     pytest.param(sk_confusion_matrix, confusion_matrix, False, id='confusion_matrix'),
+     pytest.param(sk_roc_curve, roc, True, id='roc'),
+     pytest.param(sk_precision_recall_curve, precision_recall_curve, True, id='precision_recall_curve'),
+     pytest.param(sk_roc_auc_score, auroc, True, id='auroc')
])
- def test_against_sklearn(sklearn_metric, torch_metric):
-     """Compare PL metrics to sklearn version."""
+ def test_against_sklearn(sklearn_metric, torch_metric, only_binary):
+     """Compare PL metrics to sklearn version. """
device = 'cuda' if torch.cuda.is_available() else 'cpu'

-     # iterate over different label counts in predictions and target
-     for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]:
+     # for metrics with only_binary=False, we try out different combinations of number
+     # of labels in pred and target (also test binary)
+     # for metrics with only_binary=True, target is always binary and pred will be
+     # (unnormalized) class probabilities
+     class_comb = [(5, 2)] if only_binary else [(10, 10), (5, 10), (10, 5), (2, 2)]
+     for n_cls_pred, n_cls_target in class_comb:
pred = torch.randint(n_cls_pred, (300,), device=device)
target = torch.randint(n_cls_target, (300,), device=device)

sk_score = sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy())
-         sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
pl_score = torch_metric(pred, target)
-         assert torch.allclose(sk_score, pl_score)

+         # if multi output
+         if isinstance(sk_score, tuple):
+             sk_score = [torch.tensor(sk_s.copy(), dtype=torch.float, device=device) for sk_s in sk_score]
+             for sk_s, pl_s in zip(sk_score, pl_score):
+                 assert torch.allclose(sk_s, pl_s.float())
+         else:
+             sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
+             assert torch.allclose(sk_score, pl_score)


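The new comparison branches on isinstance(sk_score, tuple) because the curve metrics added to the parametrization return a tuple of tensors rather than a single score, so the sklearn result has to be checked element-wise. Roughly, as a sketch for illustration:

    import torch
    from pytorch_lightning.metrics.functional.classification import auroc, roc

    pred = torch.tensor([3, 2, 1, 0])
    target = torch.tensor([1, 1, 0, 0])

    print(type(auroc(pred, target)))  # <class 'torch.Tensor'> -> one allclose is enough
    print(type(roc(pred, target)))    # <class 'tuple'>: (fpr, tpr, thresholds), compared per element
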
@pytest.mark.parametrize('class_reduction', ['micro', 'macro', 'weighted'])
@@ -384,6 +402,15 @@ def test_iou(half_ones, reduction, ignore_index, expected):
assert torch.allclose(iou_val, expected, atol=1e-9)


+ @pytest.mark.parametrize('metric', [auroc])
+ def test_error_on_multiclass_input(metric):
+     """ check that these metrics raise an error if they are used for multiclass problems """
+     pred = torch.randint(0, 10, (100, ))
+     target = torch.randint(0, 10, (100, ))
+     with pytest.raises(ValueError, match="AUROC metric is meant for binary classification"):
+         _ = metric(pred, target)


# TODO: When the jaccard_score of the sklearn version we use accepts `zero_division` (see
# https://github.com/scikit-learn/scikit-learn/pull/17866), consider adding a test here against our
# `absent_score`.
@@ -428,6 +455,8 @@ def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes,
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))


+ # example data taken from
+ # https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], [
# Ignoring an index outside of [0, num_classes-1] should have no effect.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]),
@@ -450,7 +479,3 @@ def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, ex
reduction=reduction,
)
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))


- # example data taken from
- # https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
4 changes: 2 additions & 2 deletions tests/metrics/test_classification.py
@@ -113,12 +113,12 @@ def test_average_precision(pos_label):
assert isinstance(ap, torch.Tensor)


- @pytest.mark.parametrize('pos_label', [1, 2])
+ @pytest.mark.parametrize('pos_label', [0, 1])
def test_auroc(pos_label):
auroc = AUROC(pos_label=pos_label)
assert auroc.name == 'auroc'

-     pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 0, 1])
+     pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 1, 0, 1])
area = auroc(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
assert isinstance(area, torch.Tensor)

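The class-based test now draws pos_label from {0, 1} and keeps the target strictly binary, matching the new functional guard. For a binary target with untied scores the two choices of pos_label should give complementary areas; a small sketch under that assumption:

    import torch
    from pytorch_lightning.metrics.classification import AUROC

    pred = torch.tensor([1., 2., 3., 4.])
    target = torch.tensor([1, 1, 0, 1])

    auc_pos1 = AUROC(pos_label=1)(pred, target)  # label 1 is positive; by hand: 1/3
    auc_pos0 = AUROC(pos_label=0)(pred, target)  # label 0 is positive; by hand: 2/3

    # with no tied scores the two areas should sum to 1
    assert torch.allclose(auc_pos1 + auc_pos0, torch.tensor(1.))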
