Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Metrics] Confusion matrix class interface #4348

Merged
merged 32 commits into from
Oct 30, 2020
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
dd9c584
docs + precision + recall + f_beta + refactor
ananyahjha93 Oct 10, 2020
1ef3ef2
rebase
ananyahjha93 Oct 10, 2020
f7c5c2d
fixes
ananyahjha93 Oct 10, 2020
f810486
added missing file
ananyahjha93 Oct 10, 2020
3ca8123
docs
ananyahjha93 Oct 10, 2020
344a518
docs
ananyahjha93 Oct 10, 2020
f2f7ec9
extra import
ananyahjha93 Oct 10, 2020
385cf3b
add confusion matrix
SkafteNicki Oct 12, 2020
e83c7bc
add to docs
SkafteNicki Oct 12, 2020
0f1ed32
add test
SkafteNicki Oct 12, 2020
40501c5
pep8 + isort
SkafteNicki Oct 12, 2020
816c3a3
merge
SkafteNicki Oct 14, 2020
c8f1902
update tests
SkafteNicki Oct 14, 2020
d15b8df
Merge remote-tracking branch 'upstream/master' into metrics/confusion…
SkafteNicki Oct 25, 2020
46db09a
move util function
SkafteNicki Oct 25, 2020
a48624a
unify functional and class
SkafteNicki Oct 25, 2020
d626ac8
add to init
SkafteNicki Oct 25, 2020
8404be1
remove old implementation
SkafteNicki Oct 25, 2020
bbca887
update tests
SkafteNicki Oct 25, 2020
4cbb7f0
pep8
SkafteNicki Oct 25, 2020
ef485b8
add duplicate
SkafteNicki Oct 25, 2020
044d801
fix doctest
SkafteNicki Oct 25, 2020
c068ae7
Update pytorch_lightning/metrics/classification/confusion_matrix.py
SkafteNicki Oct 28, 2020
499ab93
Merge remote-tracking branch 'upstream/master' into metrics/confusion…
SkafteNicki Oct 28, 2020
8682c1a
changelog
SkafteNicki Oct 28, 2020
6b6909e
Merge branch 'master' into metrics/confusion_matrix
tchaton Oct 29, 2020
ccafd36
Merge branch 'master' into metrics/confusion_matrix
s-rog Oct 30, 2020
3bcc5fb
Merge branch 'master' into metrics/confusion_matrix
SkafteNicki Oct 30, 2020
c2140ad
bullet point args
SkafteNicki Oct 30, 2020
832f6c3
bullet docs
rohitgr7 Oct 30, 2020
70c7032
bullet docs
SkafteNicki Oct 30, 2020
a158720
Merge branch 'master' into metrics/confusion_matrix
SkafteNicki Oct 30, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ Fbeta
.. autoclass:: pytorch_lightning.metrics.classification.Fbeta
:noindex:

ConfusionMatrix
~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix
:noindex:

Regression Metrics
------------------

Expand Down Expand Up @@ -269,7 +275,7 @@ average_precision [func]
confusion_matrix [func]
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.classification.confusion_matrix
.. autofunction:: pytorch_lightning.metrics.functional.confusion_matrix
:noindex:


Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
Accuracy,
Precision,
Recall,
Fbeta
Fbeta,
ConfusionMatrix
)

from pytorch_lightning.metrics.regression import (
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/metrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
from pytorch_lightning.metrics.classification.accuracy import Accuracy
from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall
from pytorch_lightning.metrics.classification.f_beta import Fbeta
from pytorch_lightning.metrics.classification.confusion_matrix import ConfusionMatrix
19 changes: 2 additions & 17 deletions pytorch_lightning/metrics/classification/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch
from torch import nn
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.metrics.utils import _input_format_classification


class Accuracy(Metric):
Expand Down Expand Up @@ -60,7 +61,6 @@ class Accuracy(Metric):
tensor(0.5000)

"""

def __init__(
self,
threshold: float = 0.5,
Expand All @@ -79,21 +79,6 @@ def __init__(

self.threshold = threshold

def _input_format(self, preds: torch.Tensor, target: torch.Tensor):
if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1):
raise ValueError(
"preds and target must have same number of dimensions, or one additional dimension for preds"
)

if len(preds.shape) == len(target.shape) + 1:
# multi class probabilites
preds = torch.argmax(preds, dim=1)

if len(preds.shape) == len(target.shape) and preds.dtype == torch.float:
# binary or multilabel probablities
preds = (preds >= self.threshold).long()
return preds, target

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.
Expand All @@ -102,7 +87,7 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
preds: Predictions from model
target: Ground truth values
"""
preds, target = self._input_format(preds, target)
preds, target = _input_format_classification(preds, target, self.threshold)
assert preds.shape == target.shape

self.correct += torch.sum(preds == target)
Expand Down
108 changes: 108 additions & 0 deletions pytorch_lightning/metrics/classification/confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional

import torch

from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.metrics.functional.confusion_matrix import (
_confusion_matrix_update,
_confusion_matrix_compute
)


class ConfusionMatrix(Metric):
"""
Computes the confusion matrix. Works with binary, multiclass, and multilabel data.
Accepts logits from a model output or integer class values in prediction.
Works with multi-dimensional preds and target.

Forward accepts

- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
- ``target`` (long tensor): ``(N, ...)``

If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
This is the case for binary and multi-label logits.

If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.

Args:
num_classes: Number of classes in the dataset.
normalize: normalization mode for confusion matrix. Default is None,
meaning no normalization. Choose between normalization over the
targets (`'true`', most commonly used), the predictions (`'pred`') or
over hole matrix (`'all'`)
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
threshold:
Threshold value for binary or multi-label logits. default: 0.5
compute_on_step:
Forward only calls ``update()`` and return None if this is set to False. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)

Example:

>>> from pytorch_lightning.metrics import ConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(num_classes=2)
>>> confmat(preds, target)
tensor([[2., 0.],
[1., 1.]])

"""
def __init__(
self,
num_classes: int,
normalize: Optional[str] = None,
threshold: float = 0.5,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):

super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
)
self.num_classes = num_classes
self.normalize = normalize
self.threshold = threshold

allowed_normalize = ('true', 'pred', 'all', None)
assert self.normalize in allowed_normalize, \
f"Argument average needs to one of the following: {allowed_normalize}"

self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum")

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.

Args:
preds: Predictions from model
target: Ground truth values
"""
confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold)
self.confmat += confmat

def compute(self):
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
"""
Computes confusion matrix
"""
return _confusion_matrix_compute(self.confmat, self.normalize)
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
auc,
auroc,
average_precision,
confusion_matrix,
dice_score,
f1_score,
fbeta_score,
Expand Down Expand Up @@ -44,3 +43,4 @@
from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error
from pytorch_lightning.metrics.functional.psnr import psnr
from pytorch_lightning.metrics.functional.ssim import ssim
from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix
42 changes: 0 additions & 42 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,48 +301,6 @@ def _confmat_normalize(cm):
return cm


def confusion_matrix(
pred: torch.Tensor,
target: torch.Tensor,
normalize: bool = False,
num_classes: Optional[int] = None
) -> torch.Tensor:
"""
Computes the confusion matrix C where each entry C_{i,j} is the number of observations
in group i that were predicted in group j.

Args:
pred: estimated targets
target: ground truth labels
normalize: normalizes confusion matrix
num_classes: number of classes

Return:
Tensor, confusion matrix C [num_classes, num_classes ]

Example:

>>> x = torch.tensor([1, 2, 3])
>>> y = torch.tensor([0, 2, 3])
>>> confusion_matrix(x, y)
tensor([[0., 1., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 0., 1.]])
"""
num_classes = get_num_classes(pred, target, num_classes)

unique_labels = (target.view(-1) * num_classes + pred.view(-1)).to(torch.int)

bins = torch.bincount(unique_labels, minlength=num_classes ** 2)
cm = bins.reshape(num_classes, num_classes).squeeze().float()

if normalize:
cm = _confmat_normalize(cm)

return cm


def precision_recall(
pred: torch.Tensor,
target: torch.Tensor,
Expand Down
93 changes: 93 additions & 0 deletions pytorch_lightning/metrics/functional/confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.metrics.utils import _input_format_classification


def _confusion_matrix_update(preds: torch.Tensor,
target: torch.Tensor,
num_classes: int,
threshold: float = 0.5) -> torch.Tensor:
preds, target = _input_format_classification(preds, target, threshold)
unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long)
bins = torch.bincount(unique_mapping, minlength=num_classes ** 2)
confmat = bins.reshape(num_classes, num_classes)
return confmat


def _confusion_matrix_compute(confmat: torch.Tensor,
normalize: Optional[str] = None) -> torch.Tensor:
allowed_normalize = ('true', 'pred', 'all', None)
assert normalize in allowed_normalize, \
f"Argument average needs to one of the following: {allowed_normalize}"
confmat = confmat.float()
if normalize is not None:
if normalize == 'true':
cm = confmat / confmat.sum(axis=1, keepdim=True)
elif normalize == 'pred':
cm = confmat / confmat.sum(axis=0, keepdim=True)
elif normalize == 'all':
cm = confmat / confmat.sum()
nan_elements = cm[torch.isnan(cm)].nelement()
if nan_elements != 0:
cm[torch.isnan(cm)] = 0
rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.')
return cm
return confmat


def confusion_matrix(
preds: torch.Tensor,
target: torch.Tensor,
num_classes: int,
normalize: Optional[str] = None,
threshold: float = 0.5
) -> torch.Tensor:
"""
Computes the confusion matrix. Works with binary, multiclass, and multilabel data.
Accepts logits from a model output or integer class values in prediction.
Works with multi-dimensional preds and target.

If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
This is the case for binary and multi-label logits.

If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.

Args:
preds: (float or long tensor), Either a ``(N, ...)`` tensor with labels or
``(N, C, ...)`` where C is the number of classes, tensor with logits/probabilities
target: ``target`` (long tensor), tensor with shape ``(N, ...)`` with ground true labels
num_classes: Number of classes in the dataset.
normalize: normalization mode for confusion matrix. Default is None,
meaning no normalization. Choose between normalization over the
targets (`'true`', most commonly used), the predictions (`'pred`') or
over hole matrix (`'all'`)
threshold:
Threshold value for binary or multi-label logits. default: 0.5

Example:

>>> from pytorch_lightning.metrics.functional import confusion_matrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confusion_matrix(preds, target, num_classes=2)
tensor([[2., 0.],
[1., 1.]])
"""
confmat = _confusion_matrix_update(preds, target, num_classes, threshold)
return _confusion_matrix_compute(confmat, normalize)
28 changes: 28 additions & 0 deletions pytorch_lightning/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,31 @@ def _check_same_shape(pred: torch.Tensor, target: torch.Tensor):
""" Check that predictions and target have the same shape, else raise error """
if pred.shape != target.shape:
raise RuntimeError('Predictions and targets are expected to have the same shape')


def _input_format_classification(preds: torch.Tensor, target: torch.Tensor, threshold: float):
""" Convert preds and target tensors into label tensors

Args:
preds: either tensor with labels, tensor with probabilities/logits or
multilabel tensor
target: tensor with ground true labels
threshold: float used for thresholding multilabel input

Returns:
preds: tensor with labels
target: tensor with labels
"""
if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1):
raise ValueError(
"preds and target must have same number of dimensions, or one additional dimension for preds"
)

if len(preds.shape) == len(target.shape) + 1:
# multi class probabilites
preds = torch.argmax(preds, dim=1)

if len(preds.shape) == len(target.shape) and preds.dtype == torch.float:
# binary or multilabel probablities
preds = (preds >= threshold).long()
return preds, target
Loading