-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Metrics] Confusion matrix class interface (#4348)
* docs + precision + recall + f_beta + refactor Co-authored-by: Teddy Koker <teddy.koker@gmail.com> * rebase Co-authored-by: Teddy Koker <teddy.koker@gmail.com> * fixes Co-authored-by: Teddy Koker <teddy.koker@gmail.com> * added missing file * docs * docs * extra import * add confusion matrix * add to docs * add test * pep8 + isort * update tests * move util function * unify functional and class * add to init * remove old implementation * update tests * pep8 * add duplicate * fix doctest * Update pytorch_lightning/metrics/classification/confusion_matrix.py Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> * changelog * bullet point args * bullet docs * bullet docs Co-authored-by: ananyahjha93 <ananya@pytorchlightning.ai> Co-authored-by: Teddy Koker <teddy.koker@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: Roger Shieh <55400948+s-rog@users.noreply.github.com> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
- Loading branch information
1 parent
20a8eaa
commit e0b856c
Showing
12 changed files
with
384 additions
and
92 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
111 changes: 111 additions & 0 deletions
111
pytorch_lightning/metrics/classification/confusion_matrix.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any, Optional | ||
|
||
import torch | ||
|
||
from pytorch_lightning.metrics.metric import Metric | ||
from pytorch_lightning.metrics.functional.confusion_matrix import ( | ||
_confusion_matrix_update, | ||
_confusion_matrix_compute | ||
) | ||
|
||
|
||
class ConfusionMatrix(Metric): | ||
""" | ||
Computes the confusion matrix. Works with binary, multiclass, and multilabel data. | ||
Accepts logits from a model output or integer class values in prediction. | ||
Works with multi-dimensional preds and target. | ||
Forward accepts | ||
- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes | ||
- ``target`` (long tensor): ``(N, ...)`` | ||
If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument. | ||
This is the case for binary and multi-label logits. | ||
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. | ||
Args: | ||
num_classes: Number of classes in the dataset. | ||
normalize: Normalization mode for confusion matrix. Choose from | ||
- ``None``: no normalization (default) | ||
- ``'true'``: normalization over the targets (most commonly used) | ||
- ``'pred'``: normalization over the predictions | ||
- ``'all'``: normalization over the whole matrix | ||
threshold: | ||
Threshold value for binary or multi-label logits. default: 0.5 | ||
compute_on_step: | ||
Forward only calls ``update()`` and return None if this is set to False. default: True | ||
dist_sync_on_step: | ||
Synchronize metric state across processes at each ``forward()`` | ||
before returning the value at the step. default: False | ||
process_group: | ||
Specify the process group on which synchronization is called. default: None (which selects the entire world) | ||
Example: | ||
>>> from pytorch_lightning.metrics import ConfusionMatrix | ||
>>> target = torch.tensor([1, 1, 0, 0]) | ||
>>> preds = torch.tensor([0, 1, 0, 0]) | ||
>>> confmat = ConfusionMatrix(num_classes=2) | ||
>>> confmat(preds, target) | ||
tensor([[2., 0.], | ||
[1., 1.]]) | ||
""" | ||
def __init__( | ||
self, | ||
num_classes: int, | ||
normalize: Optional[str] = None, | ||
threshold: float = 0.5, | ||
compute_on_step: bool = True, | ||
dist_sync_on_step: bool = False, | ||
process_group: Optional[Any] = None, | ||
): | ||
|
||
super().__init__( | ||
compute_on_step=compute_on_step, | ||
dist_sync_on_step=dist_sync_on_step, | ||
process_group=process_group, | ||
) | ||
self.num_classes = num_classes | ||
self.normalize = normalize | ||
self.threshold = threshold | ||
|
||
allowed_normalize = ('true', 'pred', 'all', None) | ||
assert self.normalize in allowed_normalize, \ | ||
f"Argument average needs to one of the following: {allowed_normalize}" | ||
|
||
self.add_state("confmat", default=torch.zeros(num_classes, num_classes), dist_reduce_fx="sum") | ||
|
||
def update(self, preds: torch.Tensor, target: torch.Tensor): | ||
""" | ||
Update state with predictions and targets. | ||
Args: | ||
preds: Predictions from model | ||
target: Ground truth values | ||
""" | ||
confmat = _confusion_matrix_update(preds, target, self.num_classes, self.threshold) | ||
self.confmat += confmat | ||
|
||
def compute(self) -> torch.Tensor: | ||
""" | ||
Computes confusion matrix | ||
""" | ||
return _confusion_matrix_compute(self.confmat, self.normalize) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright The PyTorch Lightning team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Optional | ||
|
||
import torch | ||
|
||
from pytorch_lightning.utilities import rank_zero_warn | ||
from pytorch_lightning.metrics.utils import _input_format_classification | ||
|
||
|
||
def _confusion_matrix_update(preds: torch.Tensor, | ||
target: torch.Tensor, | ||
num_classes: int, | ||
threshold: float = 0.5) -> torch.Tensor: | ||
preds, target = _input_format_classification(preds, target, threshold) | ||
unique_mapping = (target.view(-1) * num_classes + preds.view(-1)).to(torch.long) | ||
bins = torch.bincount(unique_mapping, minlength=num_classes ** 2) | ||
confmat = bins.reshape(num_classes, num_classes) | ||
return confmat | ||
|
||
|
||
def _confusion_matrix_compute(confmat: torch.Tensor, | ||
normalize: Optional[str] = None) -> torch.Tensor: | ||
allowed_normalize = ('true', 'pred', 'all', None) | ||
assert normalize in allowed_normalize, \ | ||
f"Argument average needs to one of the following: {allowed_normalize}" | ||
confmat = confmat.float() | ||
if normalize is not None: | ||
if normalize == 'true': | ||
cm = confmat / confmat.sum(axis=1, keepdim=True) | ||
elif normalize == 'pred': | ||
cm = confmat / confmat.sum(axis=0, keepdim=True) | ||
elif normalize == 'all': | ||
cm = confmat / confmat.sum() | ||
nan_elements = cm[torch.isnan(cm)].nelement() | ||
if nan_elements != 0: | ||
cm[torch.isnan(cm)] = 0 | ||
rank_zero_warn(f'{nan_elements} nan values found in confusion matrix have been replaced with zeros.') | ||
return cm | ||
return confmat | ||
|
||
|
||
def confusion_matrix( | ||
preds: torch.Tensor, | ||
target: torch.Tensor, | ||
num_classes: int, | ||
normalize: Optional[str] = None, | ||
threshold: float = 0.5 | ||
) -> torch.Tensor: | ||
""" | ||
Computes the confusion matrix. Works with binary, multiclass, and multilabel data. | ||
Accepts logits from a model output or integer class values in prediction. | ||
Works with multi-dimensional preds and target. | ||
If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument. | ||
This is the case for binary and multi-label logits. | ||
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``. | ||
Args: | ||
preds: (float or long tensor), Either a ``(N, ...)`` tensor with labels or | ||
``(N, C, ...)`` where C is the number of classes, tensor with logits/probabilities | ||
target: ``target`` (long tensor), tensor with shape ``(N, ...)`` with ground true labels | ||
num_classes: Number of classes in the dataset. | ||
normalize: Normalization mode for confusion matrix. Choose from | ||
- ``None``: no normalization (default) | ||
- ``'true'``: normalization over the targets (most commonly used) | ||
- ``'pred'``: normalization over the predictions | ||
- ``'all'``: normalization over the whole matrix | ||
threshold: | ||
Threshold value for binary or multi-label logits. default: 0.5 | ||
Example: | ||
>>> from pytorch_lightning.metrics.functional import confusion_matrix | ||
>>> target = torch.tensor([1, 1, 0, 0]) | ||
>>> preds = torch.tensor([0, 1, 0, 0]) | ||
>>> confusion_matrix(preds, target, num_classes=2) | ||
tensor([[2., 0.], | ||
[1., 1.]]) | ||
""" | ||
confmat = _confusion_matrix_update(preds, target, num_classes, threshold) | ||
return _confusion_matrix_compute(confmat, normalize) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.