Commit

[Fix] Use zero as default value of thrs in metrics. (open-mmlab#341)
* Use zero as the default value of `thrs` in metrics. It now accepts a number
  instead of only a float.

* Fix unit test comment

* Don't pass `thrs` to the metric functions when it is not given.
mzr1996 authored Jul 18, 2021
1 parent 679fc52 commit 18e6ffb
Showing 4 changed files with 38 additions and 34 deletions.
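For orientation, here is a minimal usage sketch of what the change enables (not part of the commit): `thrs` may now be an int or a tuple of numbers instead of a float only. The import path is taken from the changed file below; the toy data is made up.

import numpy as np

from mmcls.core.evaluation.eval_metrics import precision_recall_f1

# Toy scores for 3 samples over 3 classes, plus ground-truth labels.
pred = np.array([[0.7, 0.2, 0.1],
                 [0.1, 0.8, 0.1],
                 [0.3, 0.3, 0.4]])
target = np.array([0, 1, 2])

# `thrs` now defaults to 0. and accepts any Number, so an int is fine ...
p, r, f1 = precision_recall_f1(pred, target, average_mode='macro', thrs=0)

# ... as is a tuple of thresholds, which yields one result per threshold.
p_list, r_list, f1_list = precision_recall_f1(
    pred, target, average_mode='macro', thrs=(0., 0.5))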
mmcls/core/evaluation/eval_metrics.py (32 changes: 16 additions, 16 deletions)
@@ -1,3 +1,5 @@
+from numbers import Number
+
 import numpy as np
 import torch
 
@@ -36,7 +38,7 @@ def calculate_confusion_matrix(pred, target):
     return confusion_matrix
 
 
-def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
+def precision_recall_f1(pred, target, average_mode='macro', thrs=0.):
     """Calculate precision, recall and f1 score according to the prediction and
     target.
 
@@ -49,8 +51,8 @@ def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
             class are returned. If 'macro', calculate metrics for each class,
             and find their unweighted mean.
             Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.
 
     Returns:
         tuple: tuple containing precision, recall, f1 score.
@@ -78,16 +80,14 @@ class are returned. If 'macro', calculate metrics for each class,
         (f'pred and target should be torch.Tensor or np.ndarray, '
          f'but got {type(pred)} and {type(target)}.')
 
-    if thrs is None:
-        thrs = 0.0
-    if isinstance(thrs, float):
+    if isinstance(thrs, Number):
         thrs = (thrs, )
         return_single = True
     elif isinstance(thrs, tuple):
         return_single = False
     else:
         raise TypeError(
-            f'thrs should be float or tuple, but got {type(thrs)}.')
+            f'thrs should be a number or tuple, but got {type(thrs)}.')
 
     label = np.indices(pred.shape)[1]
     pred_label = np.argsort(pred, axis=1)[:, -1]
@@ -123,7 +123,7 @@ class are returned. If 'macro', calculate metrics for each class,
     return precisions, recalls, f1_scores
 
 
-def precision(pred, target, average_mode='macro', thrs=None):
+def precision(pred, target, average_mode='macro', thrs=0.):
     """Calculate precision according to the prediction and target.
 
     Args:
@@ -135,8 +135,8 @@ def precision(pred, target, average_mode='macro', thrs=None):
            class are returned. If 'macro', calculate metrics for each class,
            and find their unweighted mean.
            Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.
 
     Returns:
         float | np.array | list[float | np.array]: Precision.
@@ -153,7 +153,7 @@ class are returned. If 'macro', calculate metrics for each class,
     return precisions
 
 
-def recall(pred, target, average_mode='macro', thrs=None):
+def recall(pred, target, average_mode='macro', thrs=0.):
     """Calculate recall according to the prediction and target.
 
     Args:
@@ -165,8 +165,8 @@ def recall(pred, target, average_mode='macro', thrs=None):
            class are returned. If 'macro', calculate metrics for each class,
            and find their unweighted mean.
            Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.
 
     Returns:
         float | np.array | list[float | np.array]: Recall.
@@ -183,7 +183,7 @@ class are returned. If 'macro', calculate metrics for each class,
    return recalls
 
 
-def f1_score(pred, target, average_mode='macro', thrs=None):
+def f1_score(pred, target, average_mode='macro', thrs=0.):
     """Calculate F1 score according to the prediction and target.
 
     Args:
@@ -195,8 +195,8 @@ def f1_score(pred, target, average_mode='macro', thrs=None):
            class are returned. If 'macro', calculate metrics for each class,
            and find their unweighted mean.
            Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.
 
     Returns:
         float | np.array | list[float | np.array]: F1 score.
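The core of the change above is switching the type check to `numbers.Number`, which covers int, float and NumPy scalar types, so an integer threshold no longer raises a TypeError. Below is a self-contained sketch of that check in isolation; `normalize_thrs` is a hypothetical helper name, not code from the diff.

from numbers import Number


def normalize_thrs(thrs):
    # Mirrors the check above: wrap a single number, keep tuples, reject the rest.
    if isinstance(thrs, Number):  # True for int, float and NumPy scalars
        return (thrs, ), True     # True: the caller unwraps the single result
    elif isinstance(thrs, tuple):
        return thrs, False
    else:
        raise TypeError(
            f'thrs should be a number or tuple, but got {type(thrs)}.')


print(normalize_thrs(0))          # ((0,), True)
print(normalize_thrs(0.5))        # ((0.5,), True)
print(normalize_thrs((0., 0.5)))  # ((0.0, 0.5), False)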
mmcls/datasets/base_dataset.py (13 changes: 10 additions, 3 deletions)
@@ -157,7 +157,10 @@ def evaluate(self,
         average_mode = metric_options.get('average_mode', 'macro')
 
         if 'accuracy' in metrics:
-            acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
+            if thrs is not None:
+                acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
+            else:
+                acc = accuracy(results, gt_labels, topk=topk)
             if isinstance(topk, tuple):
                 eval_results_ = {
                     f'accuracy_top-{k}': a
@@ -183,8 +186,12 @@
 
         precision_recall_f1_keys = ['precision', 'recall', 'f1_score']
         if len(set(metrics) & set(precision_recall_f1_keys)) != 0:
-            precision_recall_f1_values = precision_recall_f1(
-                results, gt_labels, average_mode=average_mode, thrs=thrs)
+            if thrs is not None:
+                precision_recall_f1_values = precision_recall_f1(
+                    results, gt_labels, average_mode=average_mode, thrs=thrs)
+            else:
+                precision_recall_f1_values = precision_recall_f1(
+                    results, gt_labels, average_mode=average_mode)
             for key, values in zip(precision_recall_f1_keys,
                                    precision_recall_f1_values):
                 if key in metrics:
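This change implements the third commit bullet ('Don't pass `thrs` when it is not given') and keeps `evaluate` backward compatible: the threshold is forwarded only when the caller supplied one, so the metric functions otherwise fall back to their new default of 0. Below is a self-contained sketch of that dispatch pattern; `fake_metric` and the simplified `evaluate` are stand-ins, not the real mmcls code.

def fake_metric(scores, labels, thrs=0.):
    # Stand-in for accuracy / precision_recall_f1; just reports the threshold.
    return f'evaluated with thrs={thrs}'


def evaluate(scores, labels, metric_options=None):
    metric_options = metric_options or {}
    thrs = metric_options.get('thrs')  # None when the key is absent
    if thrs is not None:
        return fake_metric(scores, labels, thrs=thrs)
    else:
        return fake_metric(scores, labels)  # callee's own default (0.) applies


print(evaluate([0.9, 0.1], [0]))                                # thrs=0.0
print(evaluate([0.9, 0.1], [0], metric_options={'thrs': 0.3}))  # thrs=0.3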
mmcls/models/losses/accuracy.py (25 changes: 11 additions, 14 deletions)
@@ -1,19 +1,19 @@
+from numbers import Number
+
 import numpy as np
 import torch
 import torch.nn as nn
 
 
-def accuracy_numpy(pred, target, topk=1, thrs=None):
-    if thrs is None:
-        thrs = 0.0
-    if isinstance(thrs, float):
+def accuracy_numpy(pred, target, topk=1, thrs=0.):
+    if isinstance(thrs, Number):
         thrs = (thrs, )
         res_single = True
     elif isinstance(thrs, tuple):
         res_single = False
     else:
         raise TypeError(
-            f'thrs should be float or tuple, but got {type(thrs)}.')
+            f'thrs should be a number or tuple, but got {type(thrs)}.')
 
     res = []
     maxk = max(topk)
@@ -36,17 +36,15 @@ def accuracy_numpy(pred, target, topk=1, thrs=None):
     return res
 
 
-def accuracy_torch(pred, target, topk=1, thrs=None):
-    if thrs is None:
-        thrs = 0.0
-    if isinstance(thrs, float):
+def accuracy_torch(pred, target, topk=1, thrs=0.):
+    if isinstance(thrs, Number):
         thrs = (thrs, )
         res_single = True
     elif isinstance(thrs, tuple):
         res_single = False
     else:
         raise TypeError(
-            f'thrs should be float or tuple, but got {type(thrs)}.')
+            f'thrs should be a number or tuple, but got {type(thrs)}.')
 
     res = []
     maxk = max(topk)
@@ -68,7 +66,7 @@ def accuracy_torch(pred, target, topk=1, thrs=None):
     return res
 
 
-def accuracy(pred, target, topk=1, thrs=None):
+def accuracy(pred, target, topk=1, thrs=0.):
     """Calculate accuracy according to the prediction and target.
 
     Args:
@@ -77,9 +75,8 @@ def accuracy(pred, target, topk=1, thrs=None):
         topk (int | tuple[int]): If the predictions in ``topk``
             matches the target, the predictions will be regarded as
             correct ones. Defaults to 1.
-        thrs (float, optional): thrs (float | tuple[float], optional):
-            Predictions with scores under the thresholds are considered
-            negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.
 
     Returns:
         float | list[float] | list[list[float]]: Accuracy
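For reference, a hedged sketch of calling `accuracy` after this change: with the default threshold a single `topk` gives one value (in percent, as the test below shows), while tuples for both `topk` and `thrs` give one result per (k, threshold) pair, matching the `list[float] | list[list[float]]` cases in the docstring. The import path mirrors the changed file; the tensors are made up.

import torch

from mmcls.models.losses.accuracy import accuracy

pred = torch.tensor([[0.6, 0.3, 0.1],
                     [0.2, 0.5, 0.3],
                     [0.4, 0.4, 0.2]])
target = torch.tensor([0, 1, 2])

# Single k with the default thrs=0. -> one accuracy value.
top1 = accuracy(pred, target, topk=1)

# Tuples for k and thrs -> a list (per k) of lists (per threshold).
per_k_per_thr = accuracy(pred, target, topk=(1, 2), thrs=(0., 0.5))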
tests/test_data/test_datasets/test_common.py (2 changes: 1 addition, 1 deletion)
@@ -132,7 +132,7 @@ def test_dataset_evaluation():
     assert eval_results['f1_score'] == pytest.approx(
         (1 / 2 + 0 + 1 / 2) / 3 * 100.0)
     assert eval_results['accuracy'] == pytest.approx(2 / 6 * 100)
-    # thrs must be a float, tuple or None
+    # thrs must be a number or tuple
     with pytest.raises(TypeError):
         eval_results = dataset.evaluate(
             fake_results,
