
[Metrics] AUROC error on multilabel + improved testing #3350

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* functional interface ([#3349](https://github.com/PyTorchLightning/pytorch-lightning/pull/3349))
* class based interface + tests ([#3358](https://github.com/PyTorchLightning/pytorch-lightning/pull/3358))

+ - Added error when AUROC metric is used for multiclass problems ([#3350](https://github.com/PyTorchLightning/pytorch-lightning/pull/3350))

### Changed

4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/classification.py
@@ -374,10 +374,10 @@ class AUROC(TensorMetric):
Example:

>>> pred = torch.tensor([0, 1, 2, 3])
- >>> target = torch.tensor([0, 1, 2, 2])
+ >>> target = torch.tensor([0, 1, 1, 0])
>>> metric = AUROC()
>>> metric(pred, target)
- tensor(0.3333)
+ tensor(0.5000)

"""

20 changes: 12 additions & 8 deletions pytorch_lightning/metrics/functional/classification.py
@@ -584,12 +584,12 @@ def roc(
Example:

>>> x = torch.tensor([0, 1, 2, 3])
- >>> y = torch.tensor([0, 1, 2, 2])
+ >>> y = torch.tensor([0, 1, 1, 1])
>>> fpr, tpr, thresholds = roc(x, y)
>>> fpr
- tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000])
+ tensor([0., 0., 0., 0., 1.])
>>> tpr
- tensor([0., 0., 0., 1., 1.])
+ tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
>>> thresholds
tensor([4, 3, 2, 1, 0])
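For reference, the new fpr/tpr values line up with scikit-learn's `roc_curve` when no thresholds are dropped; a small sketch, not part of the diff (the `drop_intermediate` flag is my addition):

```python
import torch
from sklearn.metrics import roc_curve

x = torch.tensor([0, 1, 2, 3])  # predictions used as scores
y = torch.tensor([0, 1, 1, 1])  # binary targets, as in the updated doctest

fpr, tpr, thresholds = roc_curve(y.numpy(), x.numpy(), drop_intermediate=False)
print(fpr)  # [0. 0. 0. 0. 1.]
print(tpr)  # [0. 0.3333 0.6667 1. 1.] (approximately)
# The first threshold may be reported as max+1 (matching the Lightning output)
# or as inf, depending on the sklearn version.
```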

@@ -682,12 +682,12 @@ def precision_recall_curve(
Example:

>>> pred = torch.tensor([0, 1, 2, 3])
- >>> target = torch.tensor([0, 1, 2, 2])
+ >>> target = torch.tensor([0, 1, 1, 0])
>>> precision, recall, thresholds = precision_recall_curve(pred, target)
>>> precision
- tensor([0.3333, 0.0000, 0.0000, 1.0000])
+ tensor([0.6667, 0.5000, 0.0000, 1.0000])
>>> recall
- tensor([1., 0., 0., 0.])
+ tensor([1.0000, 0.5000, 0.0000, 0.0000])
>>> thresholds
tensor([1, 2, 3])
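The same kind of cross-check works for the new precision/recall values; a minimal sketch against scikit-learn, not part of the diff:

```python
import torch
from sklearn.metrics import precision_recall_curve

pred = torch.tensor([0, 1, 2, 3])
target = torch.tensor([0, 1, 1, 0])

# sklearn returns the same curve as the updated doctest.
precision, recall, thresholds = precision_recall_curve(target.numpy(), pred.numpy())
print(precision)   # [0.6667 0.5 0. 1.] (approximately)
print(recall)      # [1. 0.5 0. 0.]
print(thresholds)  # [1 2 3]
```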

@@ -858,10 +858,14 @@ def auroc(
Example:

>>> x = torch.tensor([0, 1, 2, 3])
- >>> y = torch.tensor([0, 1, 2, 2])
+ >>> y = torch.tensor([0, 1, 1, 0])
>>> auroc(x, y)
- tensor(0.3333)
+ tensor(0.5000)
"""
+ if any(target > 1):
+     raise ValueError('AUROC metric is meant for binary classification, but'
+                      ' target tensor contains value different from 0 and 1.'
+                      ' Multiclass is currently not supported.')

@auc_decorator(reorder=True)
def _auroc(pred, target, sample_weight, pos_label):
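The guard added above turns a silently meaningless score into an explicit failure for non-binary targets. A minimal usage sketch of the new behaviour (mine, not part of the diff), importing from the functional module touched in this file:

```python
import torch
from pytorch_lightning.metrics.functional.classification import auroc

pred = torch.tensor([0.1, 0.9, 0.4, 0.8])

# Binary targets work as before: every positive outranks every negative here.
print(auroc(pred, torch.tensor([0, 1, 0, 1])))  # tensor(1.)

# Targets containing values other than 0/1 now fail fast.
try:
    auroc(pred, torch.tensor([0, 1, 2, 2]))
except ValueError as err:
    print(err)  # "AUROC metric is meant for binary classification, but ..."
```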
61 changes: 43 additions & 18 deletions tests/metrics/functional/test_classification.py
@@ -10,6 +10,9 @@
f1_score as sk_f1_score,
fbeta_score as sk_fbeta_score,
confusion_matrix as sk_confusion_matrix,
+ roc_curve as sk_roc_curve,
+ roc_auc_score as sk_roc_auc_score,
+ precision_recall_curve as sk_precision_recall_curve
)

from pytorch_lightning import seed_everything
@@ -36,29 +39,44 @@
)


- @pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
-     pytest.param(sk_accuracy, accuracy, id='accuracy'),
-     pytest.param(partial(sk_jaccard_score, average='macro'), iou, id='iou'),
-     pytest.param(partial(sk_precision, average='micro'), precision, id='precision'),
-     pytest.param(partial(sk_recall, average='micro'), recall, id='recall'),
-     pytest.param(partial(sk_f1_score, average='micro'), f1_score, id='f1_score'),
-     pytest.param(partial(sk_fbeta_score, average='micro', beta=2), partial(fbeta_score, beta=2), id='fbeta_score'),
-     pytest.param(sk_confusion_matrix, confusion_matrix, id='confusion_matrix')
+ @pytest.mark.parametrize(['sklearn_metric', 'torch_metric', 'only_binary'], [
+     pytest.param(sk_accuracy, accuracy, False, id='accuracy'),
+     pytest.param(partial(sk_jaccard_score, average='macro'), iou, False, id='iou'),
+     pytest.param(partial(sk_precision, average='micro'), precision, False, id='precision'),
+     pytest.param(partial(sk_recall, average='micro'), recall, False, id='recall'),
+     pytest.param(partial(sk_f1_score, average='micro'), f1_score, False, id='f1_score'),
+     pytest.param(partial(sk_fbeta_score, average='micro', beta=2),
+                  partial(fbeta_score, beta=2), False, id='fbeta_score'),
+     pytest.param(sk_confusion_matrix, confusion_matrix, False, id='confusion_matrix'),
+     pytest.param(sk_roc_curve, roc, True, id='roc'),
+     pytest.param(sk_precision_recall_curve, precision_recall_curve, True, id='precision_recall_curve'),
+     pytest.param(sk_roc_auc_score, auroc, True, id='auroc')
])
- def test_against_sklearn(sklearn_metric, torch_metric):
-     """Compare PL metrics to sklearn version."""
+ def test_against_sklearn(sklearn_metric, torch_metric, only_binary):
"""Compare PL metrics to sklearn version. """
device = 'cuda' if torch.cuda.is_available() else 'cpu'

- # iterate over different label counts in predictions and target
- for n_cls_pred, n_cls_target in [(10, 10), (5, 10), (10, 5)]:
+ # for metrics with only_binary=False, we try out different combinations of number
+ # of labels in pred and target (also test binary)
+ # for metrics with only_binary=True, target is always binary and pred will be
+ # (unnormalized) class probabilities
+ class_comb = [(5, 2)] if only_binary else [(10, 10), (5, 10), (10, 5), (2, 2)]
+ for n_cls_pred, n_cls_target in class_comb:
pred = torch.randint(n_cls_pred, (300,), device=device)
target = torch.randint(n_cls_target, (300,), device=device)

sk_score = sklearn_metric(target.cpu().detach().numpy(),
pred.cpu().detach().numpy())
- sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
pl_score = torch_metric(pred, target)
- assert torch.allclose(sk_score, pl_score)

+ # if multi output
+ if isinstance(sk_score, tuple):
+     sk_score = [torch.tensor(sk_s.copy(), dtype=torch.float, device=device) for sk_s in sk_score]
+     for sk_s, pl_s in zip(sk_score, pl_score):
+         assert torch.allclose(sk_s, pl_s.float())
+ else:
+     sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
+     assert torch.allclose(sk_score, pl_score)


@pytest.mark.parametrize('class_reduction', ['micro', 'macro', 'weighted'])
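The tuple branch added above exists because curve metrics such as roc and precision_recall_curve return several tensors rather than one score. A condensed, standalone sketch of that comparison for the binary roc case (the seed and the `drop_intermediate=False` flag are my additions, not part of the test):

```python
import torch
from sklearn.metrics import roc_curve as sk_roc_curve
from pytorch_lightning.metrics.functional.classification import roc

torch.manual_seed(0)
pred = torch.randint(5, (300,))    # unnormalized integer "scores"
target = torch.randint(2, (300,))  # binary targets, mirroring the (5, 2) case

sk_fpr, sk_tpr, _ = sk_roc_curve(target.numpy(), pred.numpy(), drop_intermediate=False)
pl_fpr, pl_tpr, _ = roc(pred, target)

# Element-wise comparison of the curve coordinates, as in the tuple branch above.
assert torch.allclose(torch.tensor(sk_fpr, dtype=torch.float), pl_fpr.float())
assert torch.allclose(torch.tensor(sk_tpr, dtype=torch.float), pl_tpr.float())
```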
@@ -384,6 +402,15 @@ def test_iou(half_ones, reduction, ignore_index, expected):
assert torch.allclose(iou_val, expected, atol=1e-9)


+ @pytest.mark.parametrize('metric', [auroc])
+ def test_error_on_multiclass_input(metric):
+     """ check that these metrics raise an error if they are used for multiclass problems """
+     pred = torch.randint(0, 10, (100, ))
+     target = torch.randint(0, 10, (100, ))
+     with pytest.raises(ValueError, match="AUROC metric is meant for binary classification"):
+         _ = metric(pred, target)


# TODO: When the jaccard_score of the sklearn version we use accepts `zero_division` (see
# https://github.com/scikit-learn/scikit-learn/pull/17866), consider adding a test here against our
# `absent_score`.
@@ -428,6 +455,8 @@ def test_iou_absent_score(pred, target, ignore_index, absent_score, num_classes,
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))


+ # example data taken from
+ # https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], [
# Ignoring an index outside of [0, num_classes-1] should have no effect.
pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]),
@@ -450,7 +479,3 @@ def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, ex
reduction=reduction,
)
assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val))


- # example data taken from
- # https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
4 changes: 2 additions & 2 deletions tests/metrics/test_classification.py
@@ -113,12 +113,12 @@ def test_average_precision(pos_label):
assert isinstance(ap, torch.Tensor)


- @pytest.mark.parametrize('pos_label', [1, 2])
+ @pytest.mark.parametrize('pos_label', [0, 1])
def test_auroc(pos_label):
auroc = AUROC(pos_label=pos_label)
assert auroc.name == 'auroc'

- pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 2, 0, 1])
+ pred, target = torch.tensor([1, 2, 3, 4]), torch.tensor([1, 1, 0, 1])
area = auroc(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
assert isinstance(area, torch.Tensor)
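For completeness, a minimal sketch (not from the PR) of the class-based interface this test exercises, assuming `pos_label` selects the target value treated as positive, as in sklearn:

```python
import torch
from pytorch_lightning.metrics.classification import AUROC

pred = torch.tensor([1, 2, 3, 4])
target = torch.tensor([1, 1, 0, 1])

metric = AUROC(pos_label=1)  # pos_label=0 would treat class 0 as the positive one
area = metric(pred=pred, target=target, sample_weight=[0.1, 0.2, 0.3, 0.4])
print(area)  # a scalar tensor in [0, 1]
```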
