Corrected f_beta computation (#4183)
* Update f_beta.py

Added METRIC_EPS in the denominator to avoid nan values in f_beta score.

* Update f_beta.py

Made changes flake8 compliant

* Update f_beta.py

Makes use of class_reduce for macro f_beta computation to avoid nans

* Update f_beta.py

Made flake8 compliant

* Corrected F beta computation

* Removed offset to make the computation precise
abhinavg97 authored Oct 21, 2020
1 parent a4fa7f8 commit 5d1583d
Showing 3 changed files with 19 additions and 7 deletions.
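To see why the offset-based approach mentioned in the earlier commits was not enough, consider the F-beta formula itself: F_beta = (1 + beta^2) * precision * recall / (beta^2 * precision + recall). When a class has no true positives, precision and recall are both 0, so the expression is 0/0 and evaluates to NaN even if METRIC_EPS keeps the precision and recall divisions finite. A minimal PyTorch sketch of the failure, with illustrative values not taken from the repository:

    import torch

    beta = 1.0
    true_positives = torch.tensor(0.0)
    predicted_positives = torch.tensor(3.0)  # some predictions were made, all of them wrong
    actual_positives = torch.tensor(5.0)     # some positives exist, none were found

    precision = true_positives / predicted_positives  # 0.0
    recall = true_positives / actual_positives        # 0.0

    fbeta = (1 + beta ** 2) * precision * recall / (beta ** 2 * precision + recall)
    print(fbeta)  # tensor(nan): 0/0, even though precision and recall are finite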
16 changes: 9 additions & 7 deletions pytorch_lightning/metrics/classification/f_beta.py
@@ -21,8 +21,8 @@
 import torch
 from torch import nn
 from pytorch_lightning.metrics.metric import Metric
+from pytorch_lightning.metrics.functional.reduction import class_reduce
 from pytorch_lightning.metrics.classification.precision_recall import _input_format
-from pytorch_lightning.metrics.utils import METRIC_EPS


 class Fbeta(Metric):
@@ -121,12 +121,14 @@ def compute(self):
         Computes accuracy over state.
         """
         if self.average == 'micro':
-            precision = self.true_positives.sum().float() / (self.predicted_positives.sum() + METRIC_EPS)
-            recall = self.true_positives.sum().float() / (self.actual_positives.sum() + METRIC_EPS)
+            precision = self.true_positives.sum().float() / (self.predicted_positives.sum())
+            recall = self.true_positives.sum().float() / (self.actual_positives.sum())

             return (1 + self.beta ** 2) * (precision * recall) / (self.beta ** 2 * precision + recall)
         elif self.average == 'macro':
-            precision = self.true_positives.float() / (self.predicted_positives + METRIC_EPS)
-            recall = self.true_positives.float() / (self.actual_positives + METRIC_EPS)
+            precision = self.true_positives.float() / (self.predicted_positives)
+            recall = self.true_positives.float() / (self.actual_positives)

-            return ((1 + self.beta ** 2) * (precision * recall) / (self.beta ** 2 * precision + recall)).mean()
+            num = (1 + self.beta ** 2) * precision * recall
+            denom = self.beta ** 2 * precision + recall
+
+            return class_reduce(num=num, denom=denom, weights=None, class_reduction='macro')
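class_reduce itself is not part of this diff; it lives in pytorch_lightning.metrics.functional.reduction. The macro branch now hands it the per-class numerator and denominator, and the behaviour being relied on is that a class whose denominator is 0 contributes 0 rather than NaN to the macro average. A rough stand-in illustrating that assumed behaviour (not the library implementation):

    import torch

    def safe_macro_reduce(num: torch.Tensor, denom: torch.Tensor) -> torch.Tensor:
        # Illustrative stand-in for class_reduce(..., class_reduction='macro'):
        # per-class num / denom, with 0/0 classes mapped to 0, then averaged.
        scores = num / denom
        scores[torch.isnan(scores)] = 0.0  # a class with num = denom = 0 counts as 0
        return scores.mean()

    # Example with two classes; the first has precision = recall = 0.
    num = torch.tensor([0.0, 1.8])    # (1 + beta**2) * precision * recall per class
    denom = torch.tensor([0.0, 1.9])  # beta**2 * precision + recall per class
    print(safe_macro_reduce(num, denom))  # tensor(0.4737) instead of tensor(nan)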
8 changes: 8 additions & 0 deletions tests/metrics/classification/inputs.py
@@ -35,6 +35,14 @@
     target=torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
 )

+# Generate edge multilabel edge case, where nothing matches (scores are undefined)
+__temp_preds = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
+__temp_target = abs(__temp_preds - 1)
+
+_multilabel_inputs_no_match = Input(
+    preds=__temp_preds,
+    target=__temp_target
+)

 _multiclass_prob_inputs = Input(
     preds=torch.rand(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES),
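The new _multilabel_inputs_no_match fixture builds targets as the complement of random 0/1 predictions, so no prediction is ever correct: every class in every batch has zero true positives, which is exactly the precision = recall = 0 situation sketched above. A quick check of that property (the shape constants here are placeholder values, not the ones defined in tests/metrics/utils.py):

    import torch

    NUM_BATCHES, BATCH_SIZE, NUM_CLASSES = 10, 32, 5  # placeholder sizes

    preds = torch.randint(high=2, size=(NUM_BATCHES, BATCH_SIZE, NUM_CLASSES))
    target = abs(preds - 1)  # same construction as __temp_target above

    # Wherever preds is 1, target is 0 (and vice versa), so there are no true positives.
    assert (preds * target).sum() == 0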
2 changes: 2 additions & 0 deletions tests/metrics/classification/test_f_beta.py
@@ -14,6 +14,7 @@
     _multidim_multiclass_inputs,
     _multidim_multiclass_prob_inputs,
     _multilabel_inputs,
+    _multilabel_inputs_no_match,
     _multilabel_prob_inputs,
 )
 from tests.metrics.utils import NUM_CLASSES, THRESHOLD, MetricTester
@@ -87,6 +88,7 @@ def _sk_fbeta_multidim_multiclass(preds, target, average='micro', beta=1.0):
     (_binary_inputs.preds, _binary_inputs.target, _sk_fbeta_binary, 1, False),
     (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_fbeta_multilabel_prob, NUM_CLASSES, True),
     (_multilabel_inputs.preds, _multilabel_inputs.target, _sk_fbeta_multilabel, NUM_CLASSES, True),
+    (_multilabel_inputs_no_match.preds, _multilabel_inputs_no_match.target, _sk_fbeta_multilabel, NUM_CLASSES, True),
     (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_fbeta_multiclass_prob, NUM_CLASSES, False),
     (_multiclass_inputs.preds, _multiclass_inputs.target, _sk_fbeta_multiclass, NUM_CLASSES, False),
     (
