[Fix] Use zero as default value of thrs in metrics. #341

Merged
merged 4 commits on Jul 18, 2021
32 changes: 16 additions & 16 deletions mmcls/core/evaluation/eval_metrics.py
@@ -1,3 +1,5 @@
+from numbers import Number
+
import numpy as np
import torch

@@ -36,7 +38,7 @@ def calculate_confusion_matrix(pred, target):
return confusion_matrix


-def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
+def precision_recall_f1(pred, target, average_mode='macro', thrs=0.):
"""Calculate precision, recall and f1 score according to the prediction and
target.

@@ -49,8 +51,8 @@ def precision_recall_f1(pred, target, average_mode='macro', thrs=None):
class are returned. If 'macro', calculate metrics for each class,
and find their unweighted mean.
Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.

Returns:
float | np.array | list[float | np.array]: Precision, recall, f1 score.
@@ -74,16 +76,14 @@ class are returned. If 'macro', calculate metrics for each class,
(f'pred and target should be torch.Tensor or np.ndarray, '
f'but got {type(pred)} and {type(target)}.')

-    if thrs is None:
-        thrs = 0.0
-    if isinstance(thrs, float):
+    if isinstance(thrs, Number):
thrs = (thrs, )
return_single = True
elif isinstance(thrs, tuple):
return_single = False
else:
raise TypeError(
-            f'thrs should be float or tuple, but got {type(thrs)}.')
+            f'thrs should be a number or tuple, but got {type(thrs)}.')

label = np.indices(pred.shape)[1]
pred_label = np.argsort(pred, axis=1)[:, -1]
@@ -119,7 +119,7 @@ class are returned. If 'macro', calculate metrics for each class,
return precisions, recalls, f1_scores


-def precision(pred, target, average_mode='macro', thrs=None):
+def precision(pred, target, average_mode='macro', thrs=0.):
"""Calculate precision according to the prediction and target.

Args:
@@ -131,8 +131,8 @@ def precision(pred, target, average_mode='macro', thrs=None):
class are returned. If 'macro', calculate metrics for each class,
and find their unweighted mean.
Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.

Returns:
float | np.array | list[float | np.array]: Precision.
@@ -147,7 +147,7 @@ class are returned. If 'macro', calculate metrics for each class,
return precisions


-def recall(pred, target, average_mode='macro', thrs=None):
+def recall(pred, target, average_mode='macro', thrs=0.):
"""Calculate recall according to the prediction and target.

Args:
@@ -159,8 +159,8 @@ def recall(pred, target, average_mode='macro', thrs=None):
class are returned. If 'macro', calculate metrics for each class,
and find their unweighted mean.
Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.

Returns:
float | np.array | list[float | np.array]: Recall.
@@ -175,7 +175,7 @@ class are returned. If 'macro', calculate metrics for each class,
return recalls


-def f1_score(pred, target, average_mode='macro', thrs=None):
+def f1_score(pred, target, average_mode='macro', thrs=0.):
"""Calculate F1 score according to the prediction and target.

Args:
@@ -187,8 +187,8 @@ def f1_score(pred, target, average_mode='macro', thrs=None):
class are returned. If 'macro', calculate metrics for each class,
and find their unweighted mean.
Defaults to 'macro'.
-        thrs (float | tuple[float], optional): Predictions with scores under
-            the thresholds are considered negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.

Returns:
float | np.array | list[float | np.array]: F1 score.
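For reference (not part of the diff): a minimal usage sketch of the new default, assuming mmcls with this patch installed and importing `precision_recall_f1` from the module path shown above. With a scalar `thrs` each metric comes back as a single value; with a tuple, as a list with one entry per threshold.

```python
import numpy as np
from mmcls.core.evaluation.eval_metrics import precision_recall_f1

# Three samples, three classes; rows are per-class scores (e.g. softmax outputs).
pred = np.array([[0.7, 0.2, 0.1],
                 [0.1, 0.6, 0.3],
                 [0.2, 0.3, 0.5]])
target = np.array([0, 2, 2])

# Scalar thrs (the new default 0.): one value per metric.
p, r, f1 = precision_recall_f1(pred, target, average_mode='macro')

# Tuple of thresholds: each metric becomes a list, one entry per threshold.
p_list, r_list, f1_list = precision_recall_f1(
    pred, target, average_mode='macro', thrs=(0., 0.5))
```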
13 changes: 10 additions & 3 deletions mmcls/datasets/base_dataset.py
@@ -156,7 +156,10 @@ def evaluate(self,
average_mode = metric_options.get('average_mode', 'macro')

if 'accuracy' in metrics:
-            acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
+            if thrs is not None:
+                acc = accuracy(results, gt_labels, topk=topk, thrs=thrs)
+            else:
+                acc = accuracy(results, gt_labels, topk=topk)
if isinstance(topk, tuple):
eval_results_ = {
f'accuracy_top-{k}': a
@@ -182,8 +185,12 @@

precision_recall_f1_keys = ['precision', 'recall', 'f1_score']
if len(set(metrics) & set(precision_recall_f1_keys)) != 0:
-            precision_recall_f1_values = precision_recall_f1(
-                results, gt_labels, average_mode=average_mode, thrs=thrs)
+            if thrs is not None:
+                precision_recall_f1_values = precision_recall_f1(
+                    results, gt_labels, average_mode=average_mode, thrs=thrs)
+            else:
+                precision_recall_f1_values = precision_recall_f1(
+                    results, gt_labels, average_mode=average_mode)
for key, values in zip(precision_recall_f1_keys,
precision_recall_f1_values):
if key in metrics:
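The `evaluate()` change above keeps `thrs=None` meaning "use the metric's own default" (now 0.): the keyword is forwarded only when the caller actually set it. A standalone sketch of that forwarding pattern; `call_metric` is a hypothetical helper, not part of mmcls.

```python
def call_metric(metric_fn, *args, thrs=None, **kwargs):
    """Forward ``thrs`` only when the caller supplied one, so the metric
    function's own default (0. after this PR) applies otherwise."""
    if thrs is not None:
        return metric_fn(*args, thrs=thrs, **kwargs)
    return metric_fn(*args, **kwargs)

# Usage: call_metric(accuracy, results, gt_labels, topk=topk, thrs=thrs)
```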
25 changes: 11 additions & 14 deletions mmcls/models/losses/accuracy.py
@@ -1,19 +1,19 @@
+from numbers import Number
+
import numpy as np
import torch
import torch.nn as nn


-def accuracy_numpy(pred, target, topk=1, thrs=None):
-    if thrs is None:
-        thrs = 0.0
-    if isinstance(thrs, float):
+def accuracy_numpy(pred, target, topk=1, thrs=0.):
+    if isinstance(thrs, Number):
thrs = (thrs, )
res_single = True
elif isinstance(thrs, tuple):
res_single = False
else:
raise TypeError(
-            f'thrs should be float or tuple, but got {type(thrs)}.')
+            f'thrs should be a number or tuple, but got {type(thrs)}.')

res = []
maxk = max(topk)
@@ -36,17 +36,15 @@ def accuracy_numpy(pred, target, topk=1, thrs=None):
return res


-def accuracy_torch(pred, target, topk=1, thrs=None):
-    if thrs is None:
-        thrs = 0.0
-    if isinstance(thrs, float):
+def accuracy_torch(pred, target, topk=1, thrs=0.):
+    if isinstance(thrs, Number):
thrs = (thrs, )
res_single = True
elif isinstance(thrs, tuple):
res_single = False
else:
raise TypeError(
-            f'thrs should be float or tuple, but got {type(thrs)}.')
+            f'thrs should be a number or tuple, but got {type(thrs)}.')

res = []
maxk = max(topk)
@@ -68,7 +66,7 @@ def accuracy_torch(pred, target, topk=1, thrs=None):
return res


-def accuracy(pred, target, topk=1, thrs=None):
+def accuracy(pred, target, topk=1, thrs=0.):
"""Calculate accuracy according to the prediction and target.

Args:
@@ -77,9 +75,8 @@ def accuracy(pred, target, topk=1, thrs=None):
topk (int | tuple[int]): If the predictions in ``topk``
matches the target, the predictions will be regarded as
correct ones. Defaults to 1.
-        thrs (float, optional): thrs (float | tuple[float], optional):
-            Predictions with scores under the thresholds are considered
-            negative. Default to None.
+        thrs (Number | tuple[Number], optional): Predictions with scores under
+            the thresholds are considered negative. Default to 0.

Returns:
float | list[float] | list[list[float]]: If the input ``topk`` is a
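Both `accuracy_numpy` and `accuracy_torch` now share the same normalization: switching the check from `float` to `numbers.Number` is what lets an integer threshold such as `thrs=0` through. A self-contained sketch of that logic (the helper name is illustrative only):

```python
from numbers import Number


def normalize_thrs(thrs=0.):
    """Mirror the thrs handling above: a single number becomes a one-element
    tuple, a tuple passes through, anything else raises TypeError."""
    if isinstance(thrs, Number):
        return (thrs, ), True   # True: caller unwraps the single result
    if isinstance(thrs, tuple):
        return thrs, False
    raise TypeError(f'thrs should be a number or tuple, but got {type(thrs)}.')


print(normalize_thrs(0))          # ((0,), True)  -- ints are accepted now
print(normalize_thrs((0., 0.5)))  # ((0.0, 0.5), False)
```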
2 changes: 1 addition & 1 deletion tests/test_dataset.py
@@ -143,7 +143,7 @@ def test_dataset_evaluation():
assert eval_results['f1_score'] == pytest.approx(
(1 / 2 + 0 + 1 / 2) / 3 * 100.0)
assert eval_results['accuracy'] == pytest.approx(2 / 6 * 100)
-    # thrs must be a float, tuple or None
+    # thrs must be a number or tuple
with pytest.raises(TypeError):
eval_results = dataset.evaluate(
fake_results,
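The updated comment matches the new type check. For completeness, a hedged sketch of a direct unit test against `accuracy_numpy` (test name and values are illustrative; it assumes the module path shown in this PR):

```python
import numpy as np
import pytest

from mmcls.models.losses.accuracy import accuracy_numpy


def test_thrs_type_check():
    pred = np.random.rand(4, 3)
    target = np.random.randint(0, 3, (4, ))
    # A string is neither a Number nor a tuple, so the patched check rejects it.
    with pytest.raises(TypeError):
        accuracy_numpy(pred, target, topk=(1, ), thrs='0.5')
```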