Skip to content

Commit

Permalink
[Metrics] Unification of FBeta (#4656)
Browse files Browse the repository at this point in the history
* implementation

* init files

* more stable reduction

* add tests

* docs

* remove old implementation

* pep8

* changelog

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

(cherry picked from commit 6831ba9)
  • Loading branch information
SkafteNicki authored and Borda committed Nov 23, 2020
1 parent c8afcab commit 278b9a9
Show file tree
Hide file tree
Showing 13 changed files with 421 additions and 216 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added casting to python types for numpy scalars when logging hparams ([#4647](https://github.com/PyTorchLightning/pytorch-lightning/pull/4647))


- Added `F1` class metric ([#4656](https://github.com/PyTorchLightning/pytorch-lightning/pull/4656))


### Changed

- Consistently use `step=trainer.global_step` in `LearningRateMonitor` independently of `logging_interval` ([#4376](https://github.com/PyTorchLightning/pytorch-lightning/pull/4376))
Expand All @@ -20,6 +23,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Metric states are no longer as default added to `state_dict` ([#4685](https://github.com/PyTorchLightning/pytorch-lightning/pull/))


- Renamed class metric `Fbeta` >> `FBeta` ([#4656](https://github.com/PyTorchLightning/pytorch-lightning/pull/4656))


### Deprecated


Expand Down
18 changes: 12 additions & 6 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,16 @@ Recall
.. autoclass:: pytorch_lightning.metrics.classification.Recall
:noindex:

Fbeta
FBeta
~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.Fbeta
.. autoclass:: pytorch_lightning.metrics.classification.FBeta
:noindex:

F1
~~

.. autoclass:: pytorch_lightning.metrics.classification.F1
:noindex:

Regression Metrics
Expand Down Expand Up @@ -325,17 +331,17 @@ dice_score [func]
:noindex:


f1_score [func]
f1 [func]
~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.classification.f1_score
.. autofunction:: pytorch_lightning.metrics.functional.f1
:noindex:


fbeta_score [func]
fbeta [func]
~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.classification.fbeta_score
.. autofunction:: pytorch_lightning.metrics.functional.fbeta
:noindex:


Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
Accuracy,
Precision,
Recall,
Fbeta
F1,
FBeta,
)

from pytorch_lightning.metrics.regression import (
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
# limitations under the License.
from pytorch_lightning.metrics.classification.accuracy import Accuracy
from pytorch_lightning.metrics.classification.precision_recall import Precision, Recall
from pytorch_lightning.metrics.classification.f_beta import Fbeta
from pytorch_lightning.metrics.classification.f_beta import FBeta, F1
136 changes: 100 additions & 36 deletions pytorch_lightning/metrics/classification/f_beta.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import functools
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Union
from collections.abc import Mapping, Sequence
from collections import namedtuple
from typing import Any, Optional

import torch
from torch import nn

from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.metrics.functional.reduction import class_reduce
from pytorch_lightning.metrics.classification.precision_recall import _input_format
from pytorch_lightning.metrics.functional.f_beta import (
_fbeta_update,
_fbeta_compute
)


class Fbeta(Metric):
class FBeta(Metric):
"""
Computes f_beta metric.
Expand All @@ -51,7 +48,10 @@ class Fbeta(Metric):
average:
* `'micro'` computes metric globally
* `'macro'` computes metric for each class and then takes the mean
* `'macro'` computes metric for each class and uniformly averages them
* `'weighted'` computes metric for each class and does a weighted-average,
where each class is weighted by their support (accounts for class imbalance)
* `None` computes and returns the metric per class
multilabel: If predictions are from multilabel classification.
compute_on_step:
Expand All @@ -64,29 +64,28 @@ class Fbeta(Metric):
Example:
>>> from pytorch_lightning.metrics import Fbeta
>>> from pytorch_lightning.metrics import FBeta
>>> target = torch.tensor([0, 1, 2, 0, 1, 2])
>>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
>>> f_beta = Fbeta(num_classes=3, beta=0.5)
>>> f_beta = FBeta(num_classes=3, beta=0.5)
>>> f_beta(preds, target)
tensor(0.3333)
"""

def __init__(
self,
num_classes: int = 1,
beta: float = 1.,
num_classes: int,
beta: float = 1.0,
threshold: float = 0.5,
average: str = 'micro',
average: str = "micro",
multilabel: bool = False,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group,
)

self.num_classes = num_classes
Expand All @@ -95,8 +94,10 @@ def __init__(
self.average = average
self.multilabel = multilabel

assert self.average in ('micro', 'macro'), \
"average passed to the function must be either `micro` or `macro`"
allowed_average = ("micro", "macro", "weighted", None)
if self.average not in allowed_average:
raise ValueError('Argument `average` expected to be one of the following:'
f' {allowed_average} but got {self.average}')

self.add_state("true_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")
Expand All @@ -110,25 +111,88 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
preds: Predictions from model
target: Ground truth values
"""
preds, target = _input_format(self.num_classes, preds, target, self.threshold, self.multilabel)
true_positives, predicted_positives, actual_positives = _fbeta_update(
preds, target, self.num_classes, self.threshold, self.multilabel
)

self.true_positives += torch.sum(preds * target, dim=1)
self.predicted_positives += torch.sum(preds, dim=1)
self.actual_positives += torch.sum(target, dim=1)
self.true_positives += true_positives
self.predicted_positives += predicted_positives
self.actual_positives += actual_positives

def compute(self):
def compute(self) -> torch.Tensor:
"""
Computes accuracy over state.
Computes fbeta over state.
"""
if self.average == 'micro':
precision = self.true_positives.sum().float() / (self.predicted_positives.sum())
recall = self.true_positives.sum().float() / (self.actual_positives.sum())
return _fbeta_compute(self.true_positives, self.predicted_positives,
self.actual_positives, self.beta, self.average)


class F1(FBeta):
"""
Computes F1 metric. F1 metrics correspond to a equally weighted average of the
precision and recall scores.
Works with binary, multiclass, and multilabel data.
Accepts logits from a model output or integer class values in prediction.
Works with multi-dimensional preds and target.
Forward accepts
- ``preds`` (float or long tensor): ``(N, ...)`` or ``(N, C, ...)`` where C is the number of classes
- ``target`` (long tensor): ``(N, ...)``
If preds and target are the same shape and preds is a float tensor, we use the ``self.threshold`` argument.
This is the case for binary and multi-label logits.
If preds has an extra dimension as in the case of multi-class scores we perform an argmax on ``dim=1``.
Args:
num_classes: Number of classes in the dataset.
threshold:
Threshold value for binary or multi-label logits. default: 0.5
elif self.average == 'macro':
precision = self.true_positives.float() / (self.predicted_positives)
recall = self.true_positives.float() / (self.actual_positives)
average:
* `'micro'` computes metric globally
* `'macro'` computes metric for each class and uniformly averages them
* `'weighted'` computes metric for each class and does a weighted-average,
where each class is weighted by their support (accounts for class imbalance)
* `None` computes and returns the metric per class
num = (1 + self.beta ** 2) * precision * recall
denom = self.beta ** 2 * precision + recall
multilabel: If predictions are from multilabel classification.
compute_on_step:
Forward only calls ``update()`` and returns None if this is set to False. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
Example:
>>> from pytorch_lightning.metrics import F1
>>> target = torch.tensor([0, 1, 2, 0, 1, 2])
>>> preds = torch.tensor([0, 2, 1, 0, 0, 1])
>>> f1 = F1(num_classes=3)
>>> f1(preds, target)
tensor(0.3333)
"""

return class_reduce(num=num, denom=denom, weights=None, class_reduction='macro')
def __init__(
self,
num_classes: int = 1,
beta: float = 1.0,
threshold: float = 0.5,
average: str = "micro",
multilabel: bool = False,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
super().__init__(
num_classes=num_classes,
beta=1.0,
threshold=threshold,
average=average,
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
)
37 changes: 7 additions & 30 deletions pytorch_lightning/metrics/classification/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,34 +21,7 @@
import torch
from torch import nn
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.metrics.utils import to_onehot, METRIC_EPS


def _input_format(num_classes: int, preds: torch.Tensor, target: torch.Tensor, threshold=0.5, multilabel=False):
if not (len(preds.shape) == len(target.shape) or len(preds.shape) == len(target.shape) + 1):
raise ValueError(
"preds and target must have same number of dimensions, or one additional dimension for preds"
)

if len(preds.shape) == len(target.shape) + 1:
# multi class probabilites
preds = torch.argmax(preds, dim=1)

if len(preds.shape) == len(target.shape) and preds.dtype == torch.long and num_classes > 1 and not multilabel:
# multi-class
preds = to_onehot(preds, num_classes=num_classes)
target = to_onehot(target, num_classes=num_classes)

elif len(preds.shape) == len(target.shape) and preds.dtype == torch.float:
# binary or multilabel probablities
preds = (preds >= threshold).long()

# transpose class as first dim and reshape
if len(preds.shape) > 1:
preds = preds.transpose(1, 0)
target = target.transpose(1, 0)

return preds.reshape(num_classes, -1), target.reshape(num_classes, -1)
from pytorch_lightning.metrics.utils import to_onehot, METRIC_EPS, _input_format_classification_one_hot


class Precision(Metric):
Expand Down Expand Up @@ -126,7 +99,9 @@ def __init__(
self.add_state("predicted_positives", default=torch.zeros(num_classes), dist_reduce_fx="sum")

def update(self, preds: torch.Tensor, target: torch.Tensor):
preds, target = _input_format(self.num_classes, preds, target, self.threshold, self.multilabel)
preds, target = _input_format_classification_one_hot(
self.num_classes, preds, target, self.threshold, self.multilabel
)

# multiply because we are counting (1, 1) pair for true positives
self.true_positives += torch.sum(preds * target, dim=1)
Expand Down Expand Up @@ -221,7 +196,9 @@ def update(self, preds: torch.Tensor, target: torch.Tensor):
preds: Predictions from model
target: Ground truth values
"""
preds, target = _input_format(self.num_classes, preds, target, self.threshold, self.multilabel)
preds, target = _input_format_classification_one_hot(
self.num_classes, preds, target, self.threshold, self.multilabel
)

# multiply because we are counting (1, 1) pair for true positives
self.true_positives += torch.sum(preds * target, dim=1)
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/metrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
average_precision,
confusion_matrix,
dice_score,
f1_score,
fbeta_score,
multiclass_precision_recall_curve,
multiclass_roc,
precision,
Expand All @@ -44,3 +42,4 @@
from pytorch_lightning.metrics.functional.mean_squared_log_error import mean_squared_log_error
from pytorch_lightning.metrics.functional.psnr import psnr
from pytorch_lightning.metrics.functional.ssim import ssim
from pytorch_lightning.metrics.functional.f_beta import fbeta, f1
Loading

0 comments on commit 278b9a9

Please sign in to comment.