Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Classification metrics overhaul: precision & recall (4/n) #4842

Merged
merged 92 commits into from
Jan 18, 2021
Merged
Show file tree
Hide file tree
Changes from 87 commits
Commits
Show all changes
92 commits
Select commit Hold shift + click to select a range
6959ea0
Add stuff
tadejsv Nov 24, 2020
0679015
Change metrics documentation layout
tadejsv Nov 24, 2020
35627b5
Add stuff
tadejsv Nov 24, 2020
0282f3c
Add stat scores
tadejsv Nov 24, 2020
55fdaaf
Change testing utils
tadejsv Nov 24, 2020
35f8320
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 24, 2020
dd05912
Merge branch 'cls_metrics_input_formatting' into cls_metrics_stat_scores
tadejsv Nov 24, 2020
5cbf56a
Replace len(*.shape) with *.ndim
tadejsv Nov 24, 2020
9c33d0b
More descriptive error message for input formatting
tadejsv Nov 24, 2020
6562205
Replace movedim with permute
tadejsv Nov 24, 2020
b97aef2
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 24, 2020
74261f7
Merge branch 'cls_metrics_input_formatting' into cls_metrics_stat_scores
tadejsv Nov 24, 2020
cbbc769
PEP 8 compliance
tadejsv Nov 24, 2020
33166c5
WIP
tadejsv Nov 24, 2020
801abe8
Add reduce_scores function
tadejsv Nov 24, 2020
fb181ed
Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores
tadejsv Nov 24, 2020
fbebd34
Temporarily add back legacy class_reduce
tadejsv Nov 24, 2020
b3d1b8b
Merge branch 'cls_metrics_stat_scores' into cls_metrics_precision_recall
tadejsv Nov 24, 2020
f45fc81
Division with float
tadejsv Nov 24, 2020
3fdef40
Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores
tadejsv Nov 24, 2020
452df32
Merge branch 'cls_metrics_stat_scores' into cls_metrics_precision_recall
tadejsv Nov 24, 2020
9d44a26
PEP 8 compliance
tadejsv Nov 24, 2020
82c3460
Merge branch 'cls_metrics_stat_scores' into cls_metrics_precision_recall
tadejsv Nov 24, 2020
5ce7cd9
Remove precision recall
tadejsv Nov 24, 2020
3b70270
Replace movedim with permute
tadejsv Nov 24, 2020
f1ae7b2
Add back tests
tadejsv Nov 24, 2020
04a5066
Add empty newlines
tadejsv Nov 25, 2020
2033063
Add precision recall back
tadejsv Nov 25, 2020
9dc7bea
Add empty line
tadejsv Nov 25, 2020
6ba5566
Merge branch 'cls_metrics_stat_scores' into cls_metrics_precision_recall
tadejsv Nov 25, 2020
a9640f6
Fix permute
tadejsv Nov 25, 2020
9bc0f4c
Merge branch 'cls_metrics_stat_scores' into cls_metrics_precision_recall
tadejsv Nov 25, 2020
692392c
Fix some issues with old versions of PyTorch
tadejsv Nov 25, 2020
a04a71e
Style changes in error messages
tadejsv Nov 25, 2020
eaac5d7
More error message style improvements
tadejsv Nov 25, 2020
c1108f0
Fix typo in docs
tadejsv Nov 25, 2020
277769b
Add more descriptive variable names in utils
tadejsv Nov 25, 2020
4849298
Change internal var names
tadejsv Nov 25, 2020
22906a4
Merge remote-tracking branch 'upstream/master' into cls_metrics_input…
tadejsv Nov 25, 2020
1034a71
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 25, 2020
ebcdbeb
Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores
tadejsv Nov 25, 2020
8040d02
Merge branch 'cls_metrics_stat_scores' into cls_metrics_precision_recall
tadejsv Nov 25, 2020
96a1a70
Merge branch 'master' into cls_metrics_precision_recall
tchaton Dec 3, 2020
500d22f
Merge remote-tracking branch 'upstream/release/1.2-dev' into cls_metr…
tadejsv Dec 30, 2020
cac0b85
Revert unwanted changes
tadejsv Dec 30, 2020
d043384
Revert unwanted changes pt 2
tadejsv Dec 30, 2020
d6bc69b
Update metrics interface
tadejsv Jan 3, 2021
d6559f2
Add top_k parameter
tadejsv Jan 3, 2021
0b8a2fd
Add back reduce function
tadejsv Jan 5, 2021
777f1de
Merge branch 'release/1.2-dev' into cls_metrics_precision_recall
Borda Jan 6, 2021
0314a62
Add stuff
tadejsv Jan 8, 2021
be60aa1
Merge branch 'cls_metrics_precision_recall' of github.com:tadejsv/pyt…
tadejsv Jan 8, 2021
a9b0b93
Merge remote-tracking branch 'upstream/release/1.2-dev' into cls_metr…
tadejsv Jan 8, 2021
e96a442
PEP3
tadejsv Jan 8, 2021
b6de27a
Add depreciation
tadejsv Jan 8, 2021
24adfe8
PEP8
tadejsv Jan 8, 2021
660d4b1
Deprecate param
tadejsv Jan 8, 2021
6b018d9
PEP8
tadejsv Jan 8, 2021
9fdfcf6
Fix and simplify testing for older PT versions
tadejsv Jan 9, 2021
0ad7368
Merge remote-tracking branch 'upstream/release/1.2-dev' into cls_metr…
tadejsv Jan 9, 2021
88fd8cc
Update Changelog
tadejsv Jan 9, 2021
6a7b86f
Remove redundant import
tadejsv Jan 9, 2021
df6365a
Add tests to increase coverage
tadejsv Jan 10, 2021
17c680d
Merge remote-tracking branch 'upstream/release/1.2-dev' into cls_metr…
tadejsv Jan 10, 2021
5e0dfbd
Remove zero_division
tadejsv Jan 13, 2021
69b3305
Merge remote-tracking branch 'upstream/release/1.2-dev' into cls_metr…
tadejsv Jan 13, 2021
5658ee5
fix zero_division
tadejsv Jan 13, 2021
6ab9002
Add zero_div + edge case tests
tadejsv Jan 13, 2021
571cdd8
Reorder cls metric args
tadejsv Jan 13, 2021
fff6b8b
Add back quotes for is_multiclass
tadejsv Jan 13, 2021
46d7363
Add precision_recall and tests
tadejsv Jan 13, 2021
b6e375d
PEP8
tadejsv Jan 13, 2021
0ef081b
Fix docs
tadejsv Jan 13, 2021
3d0c985
Fix docs
tadejsv Jan 13, 2021
92f7c5f
Merge branch 'release/1.2-dev' into cls_metrics_precision_recall
tadejsv Jan 13, 2021
e69a71a
Update
tadejsv Jan 14, 2021
f5e3676
Merge branch 'cls_metrics_precision_recall' of github.com:tadejsv/pyt…
tadejsv Jan 14, 2021
b6f6576
Merge remote-tracking branch 'upstream/release/1.2-dev' into cls_metr…
tadejsv Jan 14, 2021
b2bd166
Change precision_recall output
tadejsv Jan 14, 2021
5eac1f4
PEP8/isort
tadejsv Jan 14, 2021
de4fb12
Add method _get_final_stats
tadejsv Jan 14, 2021
3142557
Fix depr test
tadejsv Jan 14, 2021
5b85787
Merge branch 'release/1.2-dev' into cls_metrics_precision_recall
tadejsv Jan 14, 2021
e517456
Merge branch 'release/1.2-dev' into cls_metrics_precision_recall
tadejsv Jan 15, 2021
371f5bb
Add comment to deprecation tests
tadejsv Jan 15, 2021
b787661
Merge branch 'cls_metrics_precision_recall' of github.com:tadejsv/pyt…
tadejsv Jan 15, 2021
937ae12
isort
tadejsv Jan 15, 2021
ee6fea5
Merge branch 'release/1.2-dev' into cls_metrics_precision_recall
Borda Jan 15, 2021
728980e
Apply suggestions from code review
tadejsv Jan 15, 2021
7a2bf23
Merge remote-tracking branch 'upstream/release/1.2-dev' into cls_metr…
tadejsv Jan 17, 2021
7511443
Add typing to test
tadejsv Jan 17, 2021
67d9702
Add matc str to pytest.raises
tadejsv Jan 17, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added missing val/test hooks in `LightningModule` ([#5467](https://github.com/PyTorchLightning/pytorch-lightning/pull/5467))


- `Recall` and `Precision` metrics (and their functional counterparts `recall` and `precision`) can now be generalized to Recall@K and Precision@K with the use of `top_k` parameter ([#4842](https://github.com/PyTorchLightning/pytorch-lightning/pull/4842))



### Changed

- Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
Expand Down
10 changes: 5 additions & 5 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,8 @@ the possible class labels are 0, 1, 2, 3, etc. Below are some examples of differ
ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])


Using the ``is_multiclass`` parameter
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Using the is_multiclass parameter
tadejsv marked this conversation as resolved.
Show resolved Hide resolved
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In some cases, you might have inputs which appear to be (multi-dimensional) multi-class
but are actually binary/multi-label - for example, if both predictions and targets are
Expand Down Expand Up @@ -602,14 +602,14 @@ roc [func]
precision [func]
~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.classification.precision
.. autofunction:: pytorch_lightning.metrics.functional.precision
:noindex:


precision_recall [func]
~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.classification.precision_recall
.. autofunction:: pytorch_lightning.metrics.functional.precision_recall
:noindex:


Expand All @@ -623,7 +623,7 @@ precision_recall_curve [func]
recall [func]
~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.classification.recall
.. autofunction:: pytorch_lightning.metrics.functional.recall
:noindex:

select_topk [func]
Expand Down
72 changes: 70 additions & 2 deletions pytorch_lightning/metrics/classification/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from typing import Tuple, Optional

import numpy as np
import torch

from pytorch_lightning.metrics.utils import to_onehot, select_topk
Expand Down Expand Up @@ -249,7 +250,7 @@ def _check_classification_inputs(
is_multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <metrics:Using the \\`\\`is_multiclass\\`\\` parameter>`
:ref:`documentation section <metrics:Using the is_multiclass parameter>`
for a more detailed explanation and examples.


Expand Down Expand Up @@ -375,7 +376,7 @@ def _input_format_classification(
is_multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be. See the parameter's
:ref:`documentation section <metrics:Using the \\`\\`is_multiclass\\`\\` parameter>`
:ref:`documentation section <metrics:Using the is_multiclass parameter>`
for a more detailed explanation and examples.


Expand Down Expand Up @@ -437,3 +438,70 @@ def _input_format_classification(
preds, target = preds.squeeze(-1), target.squeeze(-1)

return preds.int(), target.int(), case


def _reduce_stat_scores(
tadejsv marked this conversation as resolved.
Show resolved Hide resolved
numerator: torch.Tensor,
denominator: torch.Tensor,
weights: Optional[torch.Tensor],
average: str,
mdmc_average: Optional[str],
zero_division: int = 0,
) -> torch.Tensor:
"""
Reduces scores of type ``numerator/denominator`` or
``weights * (numerator/denominator)``, if ``average='weighted'``.

Args:
numerator: A tensor with numerator numbers.
denominator: A tensor with denominator numbers. If a denominator is
negative, the class will be ignored (if averaging), or its score
will be returned as ``nan`` (if ``average=None``).
If the denominator is zero, then ``zero_division`` score will be
used for those elements.
weights:
A tensor of weights to be used if ``average='weighted'``.
average:
The method to average the scores. Should be one of ``'micro'``, ``'macro'``,
``'weighted'``, ``'none'``, ``None`` or ``'samples'``. The behavior
corresponds to `sklearn averaging methods <https://scikit-learn.org/stable/modules/\
model_evaluation.html#multiclass-and-multilabel-classification>`__.
mdmc_average:
The method to average the scores if inputs were multi-dimensional multi-class.
tadejsv marked this conversation as resolved.
Show resolved Hide resolved
Should be either ``'global'`` or ``'samplewise'``. If inputs were not
multi-dimensional multi-class, it should be ``None`` (default).
zero_division:
The value to use for the score if denominator equals zero.
"""
numerator, denominator = numerator.float(), denominator.float()
zero_div_mask = denominator == 0
ignore_mask = denominator < 0

if weights is None:
weights = torch.ones_like(denominator)
else:
weights = weights.float()

numerator = torch.where(zero_div_mask, torch.tensor(float(zero_division), device=numerator.device), numerator)
denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator)
weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights)
tadejsv marked this conversation as resolved.
Show resolved Hide resolved

if average not in ["micro", "none", None]:
weights = weights / weights.sum(dim=-1, keepdim=True)

scores = weights * (numerator / denominator)
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved

# This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted'
scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores)
tadejsv marked this conversation as resolved.
Show resolved Hide resolved

if mdmc_average == "samplewise":
scores = scores.mean(dim=0)
ignore_mask = ignore_mask.sum(dim=0).bool()

if average in ["none", None]:
scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores)
tadejsv marked this conversation as resolved.
Show resolved Hide resolved
tadejsv marked this conversation as resolved.
Show resolved Hide resolved
else:
scores = scores.sum()

# raise ValueError
tadejsv marked this conversation as resolved.
Show resolved Hide resolved
return scores
Loading