Classification metrics overhaul: stat scores (3/n) #4839

Merged: 175 commits, Dec 30, 2020
Changes from 166 commits
Commits (175)
6959ea0
Add stuff
tadejsv Nov 24, 2020
0679015
Change metrics documentation layout
tadejsv Nov 24, 2020
35627b5
Add stuff
tadejsv Nov 24, 2020
0282f3c
Add stat scores
tadejsv Nov 24, 2020
55fdaaf
Change testing utils
tadejsv Nov 24, 2020
35f8320
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 24, 2020
dd05912
Merge branch 'cls_metrics_input_formatting' into cls_metrics_stat_scores
tadejsv Nov 24, 2020
5cbf56a
Replace len(*.shape) with *.ndim
tadejsv Nov 24, 2020
9c33d0b
More descriptive error message for input formatting
tadejsv Nov 24, 2020
6562205
Replace movedim with permute
tadejsv Nov 24, 2020
b97aef2
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 24, 2020
74261f7
Merge branch 'cls_metrics_input_formatting' into cls_metrics_stat_scores
tadejsv Nov 24, 2020
cbbc769
PEP 8 compliance
tadejsv Nov 24, 2020
33166c5
WIP
tadejsv Nov 24, 2020
801abe8
Add reduce_scores function
tadejsv Nov 24, 2020
fb181ed
Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores
tadejsv Nov 24, 2020
fbebd34
Temporarily add back legacy class_reduce
tadejsv Nov 24, 2020
b3d1b8b
Merge branch 'cls_metrics_stat_scores' into cls_metrics_precision_recall
tadejsv Nov 24, 2020
f45fc81
Division with float
tadejsv Nov 24, 2020
3fdef40
Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores
tadejsv Nov 24, 2020
452df32
Merge branch 'cls_metrics_stat_scores' into cls_metrics_precision_recall
tadejsv Nov 24, 2020
9d44a26
PEP 8 compliance
tadejsv Nov 24, 2020
82c3460
Merge branch 'cls_metrics_stat_scores' into cls_metrics_precision_recall
tadejsv Nov 24, 2020
5ce7cd9
Remove precision recall
tadejsv Nov 24, 2020
3b70270
Replace movedim with permute
tadejsv Nov 24, 2020
f1ae7b2
Add back tests
tadejsv Nov 24, 2020
04a5066
Add empty newlines
tadejsv Nov 25, 2020
9dc7bea
Add empty line
tadejsv Nov 25, 2020
a9640f6
Fix permute
tadejsv Nov 25, 2020
692392c
Fix some issues with old versions of PyTorch
tadejsv Nov 25, 2020
a04a71e
Style changes in error messages
tadejsv Nov 25, 2020
eaac5d7
More error message style improvements
tadejsv Nov 25, 2020
c1108f0
Fix typo in docs
tadejsv Nov 25, 2020
277769b
Add more descriptive variable names in utils
tadejsv Nov 25, 2020
4849298
Change internal var names
tadejsv Nov 25, 2020
22906a4
Merge remote-tracking branch 'upstream/master' into cls_metrics_input…
tadejsv Nov 25, 2020
1034a71
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 25, 2020
ebcdbeb
Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores
tadejsv Nov 25, 2020
02bd636
Break down error checking for inputs into separate functions
tadejsv Nov 25, 2020
f97145b
Remove the (N, ..., C) option in MD-MC
tadejsv Nov 25, 2020
536feaf
Simplify select_topk
tadejsv Nov 25, 2020
4241d7c
Remove detach for inputs
tadejsv Nov 25, 2020
99d3c81
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 25, 2020
86d6c4d
Fix typos
tadejsv Nov 25, 2020
54c98a0
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 25, 2020
bb11677
Merge branch 'master' into cls_metrics_input_formatting
teddykoker Nov 25, 2020
bdc4111
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 25, 2020
cde3997
Update pytorch_lightning/metrics/classification/utils.py
tadejsv Nov 26, 2020
05a54da
Update docs/source/metrics.rst
tadejsv Nov 26, 2020
9a43a5e
Minor error message changes
tadejsv Nov 26, 2020
3f4ad3c
Update pytorch_lightning/metrics/utils.py
tadejsv Nov 26, 2020
a654e6a
Reuse case from validation in formatting
tadejsv Nov 26, 2020
7b2ef2b
Merge branch 'cls_metrics_input_formatting' of github.com:tadejsv/pyt…
tadejsv Nov 26, 2020
16ab8f7
Refactor code in _input_format_classification
tadejsv Nov 27, 2020
558276f
Merge branch 'master' into cls_metrics_input_formatting
tchaton Nov 27, 2020
ecffe18
Small improvements
tadejsv Nov 27, 2020
a907ade
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 27, 2020
725c7dd
PEP 8
tadejsv Nov 27, 2020
41ad0b7
Update pytorch_lightning/metrics/classification/utils.py
tadejsv Nov 27, 2020
ca13e76
Update pytorch_lightning/metrics/classification/utils.py
tadejsv Nov 27, 2020
ede2c7f
Update docs/source/metrics.rst
tadejsv Nov 27, 2020
c6e4de4
Update pytorch_lightning/metrics/classification/utils.py
tadejsv Nov 27, 2020
201d0de
Apply suggestions from code review
tadejsv Nov 27, 2020
f08edbc
Alphabetical reordering of regression metrics
tadejsv Nov 27, 2020
523bae3
Merge branch 'cls_metrics_input_formatting' of github.com:tadejsv/pyt…
tadejsv Nov 27, 2020
db24fae
Merge branch 'master' into cls_metrics_input_formatting
Borda Nov 27, 2020
35e3eff
Change default value of top_k and add error checking
tadejsv Nov 28, 2020
dd6f8ea
Merge branch 'cls_metrics_input_formatting' of github.com:tadejsv/pyt…
tadejsv Nov 28, 2020
c28aadf
Extract basic validation into separate function
tadejsv Nov 28, 2020
4bfc688
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Nov 28, 2020
323285e
Update to new top_k default
tadejsv Nov 28, 2020
0cb0eac
Update desciption of parameters in input formatting
tadejsv Nov 29, 2020
28acf4c
Merge branch 'master' into cls_metrics_input_formatting
tchaton Nov 30, 2020
8e7a85a
Apply suggestions from code review
tadejsv Nov 30, 2020
829155e
Check that probabilities in preds sum to 1 (for MC)
tadejsv Nov 30, 2020
768879d
Fix coverage
tadejsv Nov 30, 2020
e4d88e2
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Dec 1, 2020
eeded45
Split accuracy and hamming loss
tadejsv Dec 1, 2020
b49cfdc
Remove old redundant accuracy
tadejsv Dec 1, 2020
15ef14d
Merge branch 'master' into cls_metrics_input_formatting
teddykoker Dec 2, 2020
479114f
Merge branch 'master' into cls_metrics_stat_scores
tchaton Dec 3, 2020
3d8f584
Merge branch 'master' into cls_metrics_accuracy
tchaton Dec 3, 2020
1568970
Merge branch 'master' into cls_metrics_input_formatting
tchaton Dec 3, 2020
a9fa730
Merge with master and resolve conflicts
tadejsv Dec 6, 2020
44ad276
Merge branch 'master' into cls_metrics_input_formatting
Borda Dec 6, 2020
96d40c8
Minor changes
tadejsv Dec 6, 2020
cca430a
Merge branch 'cls_metrics_input_formatting' of github.com:tadejsv/pyt…
tadejsv Dec 6, 2020
b0bde16
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Dec 6, 2020
627d99a
Fix imports
tadejsv Dec 6, 2020
de3defb
Improve docstring descriptions
tadejsv Dec 6, 2020
218ff56
Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores
tadejsv Dec 6, 2020
c24d47b
Fix imports
tadejsv Dec 6, 2020
f3c47f9
Fix edge case and simplify testing
tadejsv Dec 6, 2020
a7e91a9
Merge branch 'cls_metrics_input_formatting' into cls_metrics_accuracy
tadejsv Dec 6, 2020
94c1af6
Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores
tadejsv Dec 6, 2020
b7ced6e
Fix docs
tadejsv Dec 6, 2020
e91e564
PEP8
tadejsv Dec 6, 2020
798ec03
Reorder imports
tadejsv Dec 6, 2020
ccdc421
Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores
tadejsv Dec 6, 2020
658bfb1
Add top_k parameter
tadejsv Dec 6, 2020
7217924
Merge remote-tracking branch 'upstream/master' into cls_metrics_accuracy
tadejsv Dec 7, 2020
a7c143e
Update changelog
tadejsv Dec 7, 2020
531ae33
Update docstring
tadejsv Dec 7, 2020
2eba226
Merge branch 'master' into cls_metrics_accuracy
tadejsv Dec 7, 2020
a66cf31
Update docstring
tadejsv Dec 7, 2020
e93f83e
Merge branch 'cls_metrics_accuracy' of github.com:tadejsv/pytorch-lig…
tadejsv Dec 7, 2020
89b09f8
Reverse formatting changes for tests
tadejsv Dec 7, 2020
e715437
Change parameter order
tadejsv Dec 7, 2020
d5daec8
Remove formatting changes 2/2
tadejsv Dec 7, 2020
c820060
Remove formatting 3/3
tadejsv Dec 7, 2020
b576de0
.
tadejsv Dec 7, 2020
dae341b
Improve description of top_k parameter
tadejsv Dec 7, 2020
43136b2
Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores
tadejsv Dec 7, 2020
b2d2b71
Apply suggestions from code review
Borda Dec 7, 2020
9b2a399
Apply suggestions from code review
tadejsv Dec 7, 2020
0952df2
Remove unneeded assert
tadejsv Dec 7, 2020
c7fe698
Update pytorch_lightning/metrics/functional/accuracy.py
tadejsv Dec 7, 2020
e2bc0ab
Remove unneeded assert
tadejsv Dec 7, 2020
acbd1ca
Merge branch 'cls_metrics_accuracy' of github.com:tadejsv/pytorch-lig…
tadejsv Dec 7, 2020
8801f8a
Explicit checking of parameter values
tadejsv Dec 7, 2020
c32b36e
Apply suggestions from code review
Borda Dec 7, 2020
0314c7d
Apply suggestions from code review
tadejsv Dec 7, 2020
152cadf
Fix top_k checking
tadejsv Dec 7, 2020
022d6a6
PEP8
tadejsv Dec 7, 2020
9efc963
Don't check dist_sync in test
tadejsv Dec 8, 2020
d992f7d
add back check_dist_sync_on_step
tadejsv Dec 8, 2020
a726060
Make sure half-precision inputs are transformed (#5013)
tadejsv Dec 8, 2020
93c5d02
Fix typo
tadejsv Dec 8, 2020
0813055
Rename hamming loss to hamming distance
tadejsv Dec 8, 2020
6bf714b
Fix tests for half precision
tadejsv Dec 8, 2020
d12f1d6
Fix docs underline length
tadejsv Dec 8, 2020
a55cb46
Fix doc undeline length
tadejsv Dec 8, 2020
d75eec3
Merge branch 'master' into cls_metrics_accuracy
justusschock Dec 8, 2020
6b3b057
Replace mdmc_accuracy parameter with subset_accuracy
tadejsv Dec 8, 2020
6f218d4
Merge branch 'cls_metrics_accuracy' of github.com:tadejsv/pytorch-lig…
tadejsv Dec 8, 2020
98cb5f4
Update changelog
tadejsv Dec 8, 2020
778aeae
Merge branch 'cls_metrics_accuracy' into cls_metrics_stat_scores
tadejsv Dec 8, 2020
d129ccb
Merge remote-tracking branch 'upstream/release/1.2-dev' into cls_metr…
tadejsv Dec 21, 2020
7eb1457
Fix unwanted accuracy change
tadejsv Dec 21, 2020
207a762
Enable top_k for ML prob inputs
tadejsv Dec 21, 2020
3b79348
Test that default threshold is 0.5
tadejsv Dec 21, 2020
b609b35
Fix typo
tadejsv Dec 21, 2020
633e3ff
Update top_k description in helpers
tadejsv Dec 23, 2020
82879d0
Merge remote-tracking branch 'upstream/release/1.2-dev' into cls_metr…
tadejsv Dec 23, 2020
103dfc6
updates
tadejsv Dec 23, 2020
9be50aa
Update styling and add back tests
tadejsv Dec 23, 2020
d3f851c
Remove excess spaces
tadejsv Dec 23, 2020
1612139
fix torch.where for old versions
tadejsv Dec 23, 2020
ca03c4a
fix linting
tadejsv Dec 23, 2020
aea4c66
Update docstring
tadejsv Dec 23, 2020
7b4dcc1
Fix docstring
tadejsv Dec 23, 2020
9cd07a8
Apply suggestions from code review (mostly docs)
tadejsv Dec 24, 2020
a713fc7
Default threshold to None, accept only (0,1)
tadejsv Dec 24, 2020
075ed53
Change wrong threshold message
tadejsv Dec 24, 2020
c289f0c
Improve documentation and add tests
tadejsv Dec 25, 2020
aae5141
Merge branch 'tests_mprc' into cls_metrics_stat_scores
tadejsv Dec 25, 2020
e665f89
Add back ddp tests
tadejsv Dec 27, 2020
16d29bf
Change stat reduce method and default
tadejsv Dec 27, 2020
7e8fb8e
Remove DDP tests and fix doctests
tadejsv Dec 28, 2020
d1a4eff
Fix doctest
tadejsv Dec 28, 2020
01e8e63
Update changelog
tadejsv Dec 28, 2020
3e58244
Refactoring
tadejsv Dec 28, 2020
475c706
Fix typo
tadejsv Dec 28, 2020
d387eb1
Refactor
tadejsv Dec 28, 2020
d2a92e8
Increase coverage
tadejsv Dec 28, 2020
c178cb6
Fix linting
tadejsv Dec 28, 2020
8bf6cf1
Consistent use of backticks
tadejsv Dec 29, 2020
b2fcd55
Merge remote-tracking branch 'upstream/release/1.2-dev' into cls_metr…
tadejsv Dec 29, 2020
169fc7c
Fix too long line in docs
tadejsv Dec 29, 2020
21551f1
Apply suggestions from code review
tadejsv Dec 29, 2020
e52fa9c
Fix deprecation test
tadejsv Dec 29, 2020
85d6e3a
Fix deprecation test
tadejsv Dec 29, 2020
3461159
Default threshold back to 0.5
tadejsv Dec 29, 2020
fe48912
Minor documentation fixes
tadejsv Dec 30, 2020
c2c45f1
Add types to tests
tadejsv Dec 30, 2020
4 changes: 4 additions & 0 deletions CHANGELOG.md

@@ -15,11 +15,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- `HammingDistance` metric to compute the hamming distance (loss) ([#4838](https://github.com/PyTorchLightning/pytorch-lightning/pull/4838))

- `StatScores` metric to compute the number of true positives, false positives, true negatives and false negatives ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))

### Changed

- `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))

### Deprecated

- `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))

### Removed

68 changes: 62 additions & 6 deletions docs/source/metrics.rst

@@ -251,13 +251,62 @@ the possible class labels are 0, 1, 2, 3, etc. Below are some examples of differ
ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])

In some rare cases, you might have inputs which appear to be (multi-dimensional) multi-class
but are actually binary/multi-label. For example, if both predictions and targets are 1d
binary tensors. Or it could be the other way around, you want to treat binary/multi-label
inputs as 2-class (multi-dimensional) multi-class inputs.

Using the ``is_multiclass`` parameter
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In some cases, you might have inputs which appear to be (multi-dimensional) multi-class
but are actually binary/multi-label - for example, if both predictions and targets are
integer (binary) tensors. Or it could be the other way around, you want to treat
binary/multi-label inputs as 2-class (multi-dimensional) multi-class inputs.

For these cases, the metrics where this distinction would make a difference expose the
``is_multiclass`` argument.
``is_multiclass`` argument. Let's see how it is used, taking the
:class:`~pytorch_lightning.metrics.classification.StatScores` metric as an example.

First, let's consider the case with label predictions with 2 classes, which we want to
treat as binary.

.. testcode::

from pytorch_lightning.metrics.functional import stat_scores

# These inputs are supposed to be binary, but appear as multi-class
preds = torch.tensor([0, 1, 0])
target = torch.tensor([1, 1, 0])

As you can see below, by default the inputs are treated
as multi-class. We can set ``is_multiclass=False`` to treat the inputs as binary -
which is the same as converting the predictions to float beforehand.

.. doctest::

>>> stat_scores(preds, target, reduce='macro', num_classes=2)
tensor([[1, 1, 1, 0, 1],
[1, 0, 1, 1, 2]])
>>> stat_scores(preds, target, reduce='macro', num_classes=1, is_multiclass=False)
tensor([[1, 0, 1, 1, 2]])
>>> stat_scores(preds.float(), target, reduce='macro', num_classes=1)
tensor([[1, 0, 1, 1, 2]])

Next, consider the opposite example: inputs are binary (as predictions are probabilities),
but we would like to treat them as 2-class multi-class, to obtain the metric for both classes.

.. testcode::

preds = torch.tensor([0.2, 0.7, 0.3])
target = torch.tensor([1, 1, 0])

In this case we can set ``is_multiclass=True``, to treat the inputs as multi-class.

.. doctest::

>>> stat_scores(preds, target, reduce='macro', num_classes=1)
tensor([[1, 0, 1, 1, 2]])
>>> stat_scores(preds, target, reduce='macro', num_classes=2, is_multiclass=True)
tensor([[1, 1, 1, 0, 1],
[1, 0, 1, 1, 2]])
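
The same options are available on the class-based interface. As a minimal sketch (assuming
the :class:`~pytorch_lightning.metrics.classification.StatScores` class accepts the same
``reduce``, ``num_classes`` and ``is_multiclass`` arguments as the functional version above):

.. testcode::

    from pytorch_lightning.metrics import StatScores

    # Accumulates rows of [tp, fp, tn, fn, support], here one row per class
    stat_scores_metric = StatScores(reduce='macro', num_classes=2, is_multiclass=True)
    result = stat_scores_metric(preds, target)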


Class Metrics (Classification)
------------------------------
@@ -323,6 +372,13 @@ ROC
:noindex:


StatScores
~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.StatScores
:noindex:


Functional Metrics (Classification)
-----------------------------------

@@ -444,7 +500,7 @@ select_topk [func]
stat_scores [func]
~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.classification.stat_scores
.. autofunction:: pytorch_lightning.metrics.functional.stat_scores
:noindex:


1 change: 1 addition & 0 deletions pytorch_lightning/metrics/__init__.py

@@ -24,6 +24,7 @@
ROC,
FBeta,
F1,
StatScores
)

from pytorch_lightning.metrics.regression import ( # noqa: F401
1 change: 1 addition & 0 deletions pytorch_lightning/metrics/classification/__init__.py

@@ -19,3 +19,4 @@
from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall # noqa: F401
from pytorch_lightning.metrics.classification.precision_recall_curve import PrecisionRecallCurve # noqa: F401
from pytorch_lightning.metrics.classification.roc import ROC # noqa: F401
from pytorch_lightning.metrics.classification.stat_scores import StatScores # noqa: F401
38 changes: 20 additions & 18 deletions pytorch_lightning/metrics/classification/accuracy.py

@@ -21,7 +21,7 @@

class Accuracy(Metric):
r"""
Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`_:
Computes `Accuracy <https://en.wikipedia.org/wiki/Accuracy_and_precision>`__:

.. math::
\text{Accuracy} = \frac{1}{N}\sum_i^N 1(y_i = \hat{y}_i)
@@ -43,7 +43,8 @@ class Accuracy(Metric):
Args:
threshold:
Threshold probability value for transforming probability predictions to binary
`(0,1)` predictions, in the case of binary or multi-label inputs.
(0,1) predictions, in the case of binary or multi-label inputs. If not set, it
defaults to 0.5.
top_k:
Number of highest probability predictions considered to find the correct label, relevant
only for (multi-dimensional) multi-class inputs with probability predictions. The
@@ -54,17 +55,18 @@ class Accuracy(Metric):
Whether to compute subset accuracy for multi-label and multi-dimensional
multi-class inputs (has no effect for other input types).

For multi-label inputs, if the parameter is set to `True`, then all labels for
each sample must be correctly predicted for the sample to count as correct. If it
is set to `False`, then all labels are counted separately - this is equivalent to
flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``).

For multi-dimensional multi-class inputs, if the parameter is set to `True`, then all
sub-samples (on the extra axis) must be correct for the sample to be counted as correct.
If it is set to `False`, then all sub-samples are counted separately - this is equivalent,
in the case of label predictions, to flattening the inputs beforehand (i.e.
``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter
still applies in both cases, if set.
- For multi-label inputs, if the parameter is set to ``True``, then all labels for
each sample must be correctly predicted for the sample to count as correct. If it
is set to ``False``, then all labels are counted separately - this is equivalent to
flattening inputs beforehand (i.e. ``preds = preds.flatten()`` and same for ``target``).

- For multi-dimensional multi-class inputs, if the parameter is set to ``True``, then all
sub-samples (on the extra axis) must be correct for the sample to be counted as correct.
If it is set to ``False``, then all sub-samples are counted separately - this is equivalent,
in the case of label predictions, to flattening the inputs beforehand (i.e.
``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter
still applies in both cases, if set.

compute_on_step:
Forward only calls ``update()`` and return None if this is set to False.
dist_sync_on_step:
@@ -95,7 +97,7 @@ class Accuracy(Metric):

def __init__(
self,
threshold: float = 0.5,
threshold: Optional[float] = None,
top_k: Optional[int] = None,
subset_accuracy: bool = False,
compute_on_step: bool = True,
Expand All @@ -113,11 +115,11 @@ def __init__(
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

if not 0 <= threshold <= 1:
raise ValueError("The `threshold` should lie in the [0,1] interval.")
if threshold is not None and not 0 < threshold < 1:
raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}")

if top_k is not None and top_k <= 0:
raise ValueError("The `top_k` should be an integer larger than 1.")
if top_k is not None and (not isinstance(top_k, int) or top_k <= 0):
raise ValueError(f"The `top_k` should be an integer larger than 0, got {top_k}")

self.threshold = threshold
self.top_k = top_k
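
As a side illustration (not from the diff), a minimal sketch of what the ``subset_accuracy``
flag documented above computes for multi-label label predictions, done by hand with plain
``torch``:

    import torch

    preds = torch.tensor([[1, 1], [0, 1], [0, 0]])   # multi-label predictions (N, num_labels)
    target = torch.tensor([[1, 1], [1, 1], [0, 0]])

    # subset_accuracy=True: a sample counts only if all of its labels are correct
    print((preds == target).all(dim=1).float().mean())   # tensor(0.6667)

    # subset_accuracy=False: each label is counted separately (flattened inputs)
    print((preds == target).float().mean())              # tensor(0.8333)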
56 changes: 34 additions & 22 deletions pytorch_lightning/metrics/classification/helpers.py

@@ -39,8 +39,8 @@ def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold
if preds_float and (preds.min() < 0 or preds.max() > 1):
raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.")

if threshold > 1 or threshold < 0:
raise ValueError("The `threshold` should be a probability in [0,1].")
if not 0 < threshold < 1:
raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}")

if is_multiclass is False and target.max() > 1:
raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.")
@@ -181,13 +181,19 @@ def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes


def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool):
if "multi-class" not in case or not preds_float:
raise ValueError(
"You have set `top_k` above 1, but your data is not (multi-dimensional) multi-class"
" with probability predictions."
)
if case == "binary":
raise ValueError("You can not use `top_k` parameter with binary data.")
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError("The `top_k` has to be an integer larger than 0.")
if not preds_float:
raise ValueError("You have set `top_k`, but you do not have probability predictions.")
if is_multiclass is False:
raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.")
if case == "multi-label" and is_multiclass:
raise ValueError(
"If you want to transform multi-label data to 2 class multi-dimensional"
"multi-class data using `is_multiclass=True`, you can not use `top_k`."
)
if top_k >= implied_classes:
raise ValueError("The `top_k` has to be strictly smaller than the `C` dimension of `preds`.")

@@ -216,9 +222,9 @@ def _check_classification_inputs(
When ``num_classes`` is not specified in these cases, consistency of the highest target
value against ``C`` dimension is checked for (multi-dimensional) multi-class cases.

If ``top_k`` is set (not None) for inputs which are not (multi-dimensional) multi class
with probabilities, then an error is raised. Similarly if ``top_k`` is set to a number
that is higher than or equal to the ``C`` dimension of ``preds``.
If ``top_k`` is set (not None) for inputs which do not have probability predictions (and
are not binary), then an error is raised. Similarly if ``top_k`` is set to a number that
is higher than or equal to the ``C`` dimension of ``preds``.

Preds and target tensors are expected to be squeezed already - all dimensions should be
greater than 1, except perhaps the first one (N).
Expand All @@ -228,17 +234,18 @@ def _check_classification_inputs(
target: Tensor with ground truth labels, always integers (labels)
threshold:
Threshold probability value for transforming probability predictions to binary
(0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5
(0,1) predictions, in the case of binary or multi-label inputs.
num_classes:
Number of classes. If not explicitly set, the number of classes will be inferred
either from the shape of inputs, or the maximum label in the ``target`` and ``preds``
tensor, where applicable.
top_k:
Number of highest probability entries for each sample to convert to 1s - relevant
only for (multi-dimensional) multi-class inputs with probability predictions. The
default value (``None``) will be interpreted as 1 for these inputs.
only for inputs with probability predictions. The default value (``None``) will be
interpreted as 1 for these inputs. If this parameter is set for multi-label inputs,
it will take precedence over ``threshold``.

Should be left unset (``None``) for all other types of inputs.
Should be left unset (``None``) for inputs with label predictions.
is_multiclass:
Used only in certain special cases, where you want to treat inputs as a different type
than what they appear to be (see :ref:`metrics: Input types` documentation section for
@@ -294,7 +301,7 @@ def _check_classification_inputs(
_check_num_classes_ml(num_classes, is_multiclass, implied_classes)

# Check that top_k is consistent
if top_k:
if top_k is not None:
_check_top_k(top_k, case, implied_classes, is_multiclass, preds.is_floating_point())

return case
@@ -364,7 +371,8 @@ def _input_format_classification(
target: Tensor with ground truth labels, always integers (labels)
threshold:
Threshold probability value for transforming probability predictions to binary
(0,1) predictions, in the case of binary or multi-label inputs. Default: 0.5
(0,1) predictions, in the case of binary or multi-label inputs. If not set, it
defaults to 0.5.
num_classes:
Number of classes. If not explicitly set, the number of classes will be inferred
either from the shape of inputs, or the maximum label in the ``target`` and ``preds``
@@ -410,6 +418,9 @@ def _input_format_classification(
if preds.dtype == torch.float16:
preds = preds.float()

# Let threshold default to 0.5 if not set
threshold = 0.5 if threshold is None else threshold

case = _check_classification_inputs(
preds,
target,
@@ -419,21 +430,22 @@
top_k=top_k,
)

top_k = top_k if top_k else 1

if case in ["binary", "multi-label"]:
if case in ["binary", "multi-label"] and not top_k:
preds = (preds >= threshold).int()
num_classes = num_classes if not is_multiclass else 2

if case == "multi-label" and top_k:
preds = select_topk(preds, top_k)

if "multi-class" in case or is_multiclass:
if preds.is_floating_point():
num_classes = preds.shape[1]
preds = select_topk(preds, top_k)
preds = select_topk(preds, top_k or 1)
else:
num_classes = num_classes if num_classes else max(preds.max(), target.max()) + 1
preds = to_onehot(preds, max(2,num_classes))
preds = to_onehot(preds, max(2, num_classes))

target = to_onehot(target, max(2,num_classes))
target = to_onehot(target, max(2, num_classes))

if is_multiclass is False:
preds, target = preds[:, 1, ...], target[:, 1, ...]
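
To make the control flow above concrete, a minimal, self-contained sketch of the two
transformation paths (thresholding for binary/multi-label inputs, top-k one-hot selection
for multi-class probabilities), with a plain ``torch`` stand-in for ``select_topk``:

    import torch

    def select_topk_sketch(prob_tensor: torch.Tensor, topk: int = 1) -> torch.Tensor:
        # Set the top-k entries of each row to 1, all others to 0
        zeros = torch.zeros_like(prob_tensor, dtype=torch.int)
        return zeros.scatter(1, prob_tensor.topk(k=topk, dim=1).indices, 1)

    # Binary / multi-label path: probabilities are thresholded to 0/1 labels
    preds = torch.tensor([0.2, 0.7, 0.3])
    print((preds >= 0.5).int())                  # tensor([0, 1, 0], dtype=torch.int32)

    # Multi-class path: (N, C) probabilities become one-hot int tensors
    mc_preds = torch.tensor([[0.1, 0.8, 0.1], [0.6, 0.3, 0.1]])
    print(select_topk_sketch(mc_preds, topk=1))  # tensor([[0, 1, 0], [1, 0, 0]], ...)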