[Metrics] Add multiclass auroc (#4236)
* Add functional multiclass AUROC metric

* Add multiclass_auroc tests

* fixup! Add functional multiclass AUROC metric

* fixup! fixup! Add functional multiclass AUROC metric

* Add multiclass_auroc doc reference

* Update CHANGELOG

* formatting

* Shorter error message regex match in tests

* Set num classes as pytest parameter

* formatting

* Update CHANGELOG

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
(cherry picked from commit 38bb4e2)
ddrevicky authored and SeanNaren committed Nov 10, 2020
1 parent 09185b6 commit 48ed664
Showing 5 changed files with 111 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -19,6 +19,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added PyTorch 1.7 Stable support ([#3821](https://github.com/PyTorchLightning/pytorch-lightning/pull/3821))
- Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340))

- Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236))

### Changed

- W&B log in sync with `Trainer` step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405))
7 changes: 7 additions & 0 deletions docs/source/metrics.rst
@@ -284,6 +284,13 @@ auroc [func]
:noindex:


multiclass_auroc [func]
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.classification.multiclass_auroc
:noindex:


average_precision [func]
~~~~~~~~~~~~~~~~~~~~~~~~

1 change: 1 addition & 0 deletions pytorch_lightning/metrics/functional/__init__.py
@@ -21,6 +21,7 @@
fbeta_score,
multiclass_precision_recall_curve,
multiclass_roc,
multiclass_auroc,
precision,
precision_recall,
precision_recall_curve,
61 changes: 59 additions & 2 deletions pytorch_lightning/metrics/functional/classification.py
@@ -817,13 +817,14 @@ def new_func(*args, **kwargs) -> torch.Tensor:

def multiclass_auc_decorator(reorder: bool = True) -> Callable:
def wrapper(func_to_decorate: Callable) -> Callable:
@wraps(func_to_decorate)
def new_func(*args, **kwargs) -> torch.Tensor:
results = []
for class_result in func_to_decorate(*args, **kwargs):
x, y = class_result[:2]
results.append(auc(x, y, reorder=reorder))

return torch.cat(results)
return torch.stack(results)

return new_func
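
A note on the one-line change above: `auc` returns a zero-dimensional tensor for each class, and `torch.cat` cannot concatenate zero-dimensional tensors, so `torch.stack` is needed to collect the per-class scores into a 1-D tensor. A minimal standalone sketch of the difference (not part of this commit):

```python
import torch

# Per-class AUC values come back as zero-dimensional tensors:
scores = [torch.tensor(0.75), torch.tensor(0.50)]

torch.stack(scores)   # tensor([0.7500, 0.5000]) -- works
# torch.cat(scores)   # RuntimeError: zero-dimensional tensor cannot be concatenated
```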

@@ -858,7 +859,7 @@ def auroc(
if any(target > 1):
raise ValueError('AUROC metric is meant for binary classification, but'
' target tensor contains value different from 0 and 1.'
' Multiclass is currently not supported.')
' Use `multiclass_auroc` for multi class classification.')

@auc_decorator(reorder=True)
def _auroc(pred, target, sample_weight, pos_label):
@@ -867,6 +868,62 @@ def _auroc(pred, target, sample_weight, pos_label):
return _auroc(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label)


def multiclass_auroc(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
num_classes: Optional[int] = None,
) -> torch.Tensor:
"""
Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC) from multiclass
prediction scores

Args:
pred: estimated probabilities, with shape [N, C]
target: ground-truth labels, with shape [N,]
sample_weight: sample weights
num_classes: number of classes (default: None, computes automatically from data)

Return:
Tensor containing ROCAUC score

Example:

>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
... [0.05, 0.85, 0.05, 0.05],
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> multiclass_auroc(pred, target) # doctest: +NORMALIZE_WHITESPACE
tensor(0.6667)
"""
if not torch.allclose(pred.sum(dim=1), torch.tensor(1.0)):
raise ValueError(
"Multiclass AUROC metric expects the predicted scores to be"
" probabilities, i.e. they should sum up to 1.0 over classes")

if torch.unique(target).size(0) != pred.size(1):
raise ValueError(
f"Number of classes found in 'target' ({torch.unique(target).size(0)})"
f" does not equal the number of columns in 'pred' ({pred.size(1)})."
" Multiclass AUROC is not defined when all of the classes do not"
" occur in the target labels.")

if num_classes is not None and num_classes != pred.size(1):
raise ValueError(
f"Number of classes deduced from 'pred' ({pred.size(1)}) does not equal"
f" the number of classes passed in 'num_classes' ({num_classes}).")

@multiclass_auc_decorator(reorder=False)
def _multiclass_auroc(pred, target, sample_weight, num_classes):
return multiclass_roc(pred, target, sample_weight, num_classes)

class_aurocs = _multiclass_auroc(pred=pred, target=target,
sample_weight=sample_weight,
num_classes=num_classes)
return torch.mean(class_aurocs)


def average_precision(
pred: torch.Tensor,
target: torch.Tensor,
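For quick reference, a minimal sketch of calling the new functional metric end to end, mirroring the docstring example (the import path follows this commit's module layout):

```python
import torch
from pytorch_lightning.metrics.functional.classification import multiclass_auroc

# Rows are per-sample class probabilities (each row sums to 1.0);
# every class label must appear at least once in `target`.
pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
                     [0.05, 0.85, 0.05, 0.05],
                     [0.05, 0.05, 0.85, 0.05],
                     [0.05, 0.05, 0.05, 0.85]])
target = torch.tensor([0, 1, 3, 2])

print(multiclass_auroc(pred, target))  # tensor(0.6667), the mean of per-class AUROCs
```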
42 changes: 42 additions & 0 deletions tests/metrics/functional/test_classification.py
@@ -30,6 +30,7 @@
dice_score,
average_precision,
auroc,
multiclass_auroc,
precision_recall_curve,
roc,
auc,
@@ -316,6 +317,47 @@ def test_auroc(pred, target, expected):
assert score == expected


def test_multiclass_auroc():
with pytest.raises(ValueError,
match=r".*probabilities, i.e. they should sum up to 1.0 over classes"):
_ = multiclass_auroc(pred=torch.tensor([[0.9, 0.9],
[1.0, 0]]),
target=torch.tensor([0, 1]))

with pytest.raises(ValueError,
match=r".*not defined when all of the classes do not occur in the target.*"):
_ = multiclass_auroc(pred=torch.rand((4, 3)).softmax(dim=1),
target=torch.tensor([1, 0, 1, 0]))

with pytest.raises(ValueError,
match=r".*does not equal the number of classes passed in 'num_classes'.*"):
_ = multiclass_auroc(pred=torch.rand((5, 4)).softmax(dim=1),
target=torch.tensor([0, 1, 2, 2, 3]),
num_classes=6)


@pytest.mark.parametrize('n_cls', [2, 5, 10, 50])
def test_multiclass_auroc_against_sklearn(n_cls):
device = 'cuda' if torch.cuda.is_available() else 'cpu'

n_samples = 300
pred = torch.rand(n_samples, n_cls, device=device).softmax(dim=1)
target = torch.randint(n_cls, (n_samples,), device=device)
# Make sure target includes all class labels so that multiclass AUROC is defined
target[10:10 + n_cls] = torch.arange(n_cls)

pl_score = multiclass_auroc(pred, target)
# For the binary case, sklearn expects an (n_samples,) array of probabilities of
# the positive class
pred = pred[:, 1] if n_cls == 2 else pred
sk_score = sk_roc_auc_score(target.cpu().detach().numpy(),
pred.cpu().detach().numpy(),
multi_class="ovr")

sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
assert torch.allclose(sk_score, pl_score)


@pytest.mark.parametrize(['x', 'y', 'expected'], [
pytest.param([0, 1], [0, 1], 0.5),
pytest.param([1, 0], [0, 1], 0.5),
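The `pred[:, 1]` slice in the sklearn comparison test above reflects scikit-learn's convention for binary problems: `roc_auc_score` expects a 1-D array of positive-class probabilities rather than the full `[N, 2]` matrix. A small standalone sketch of that convention (values are illustrative, not from the test):

```python
import numpy as np
from sklearn.metrics import roc_auc_score

y = np.array([0, 1, 1, 0])
proba = np.array([[0.9, 0.1],
                  [0.2, 0.8],
                  [0.3, 0.7],
                  [0.6, 0.4]])

# Binary case: pass only the positive-class column, shape (n_samples,)
print(roc_auc_score(y, proba[:, 1]))  # 1.0 -- every positive outranks every negative
```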
