
Commit

Merge e6d49dc into 8352951
zzc98 authored Feb 21, 2023
2 parents 8352951 + e6d49dc commit d2c51c6
Showing 7 changed files with 341 additions and 8 deletions.
5 changes: 4 additions & 1 deletion configs/_base_/datasets/inshop_bs32_448.py
@@ -55,7 +55,10 @@
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_dataloader = query_dataloader
val_evaluator = dict(type='RetrievalRecall', topk=1)
val_evaluator = [
dict(type='RetrievalRecall', topk=1),
dict(type='RetrievalAveragePrecision', topk=10),
]

test_dataloader = val_dataloader
test_evaluator = val_evaluator
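For context, here is a minimal sketch of how this evaluator list could be adapted in a user config. The metric names and arguments come from this PR; the particular `topk` values below are illustrative assumptions, not part of the committed config.

```python
# Hypothetical user config built on the metrics added in this PR.
# `RetrievalRecall` accepts an int or a tuple of ints for `topk`;
# `RetrievalAveragePrecision` takes a single int `topk` and a `mode`
# of either 'IR' (default) or 'integrate'.
val_evaluator = [
    dict(type='RetrievalRecall', topk=(1, 5)),         # Recall@1 and Recall@5
    dict(type='RetrievalAveragePrecision', topk=10),   # mAP@10, 'IR' mode by default
]
test_evaluator = val_evaluator
```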
6 changes: 3 additions & 3 deletions configs/arcface/README.md
@@ -16,9 +16,9 @@ Recently, a popular line of research in face recognition is adopting margins in

### InShop

| Model | Pretrain | Params(M) | Flops(G) | Recall@1 | Config | Download |
| :------------: | :------------------------------------------------: | :-------: | :------: | :---: | :-----: | :-----: | :-----------------------------------------------: | :-------------------------------------------------: |
| Resnet50-ArcFace | [ImageNet-21k-mill](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth) | 31.69 | 16.48 | 90.18 | [config](./resnet50-arcface_8xb32_inshop.py) | [model](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.pth) | [log](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.log) |
| Model | Pretrain | Params(M) | Flops(G) | Recall@1 | mAP@10 | Config | Download |
| :--------------: | :----------------------------------------------------: | :-------: | :------: | :------: | :----: | :------------------------------------------: | :-----------------------------------------------------: |
| Resnet50-ArcFace | [ImageNet-21k-mill](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_3rdparty-mill_in21k_20220331-faac000b.pth) | 31.69 | 16.48 | 90.18 | 69.30 | [config](./resnet50-arcface_8xb32_inshop.py) | [model](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.pth) \| [log](https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.log) |

## Citation

1 change: 1 addition & 0 deletions configs/arcface/metafile.yml
@@ -22,6 +22,7 @@ Models:
- Dataset: InShop
Metrics:
Recall@1: 90.18
mAP@10: 69.30
Task: Metric Learning
Weights: https://download.openmmlab.com/mmclassification/v0/arcface/resnet50-arcface_inshop_20230202-b766fe7f.pth
Config: configs/arcface/resnet50-arcface_8xb32_inshop.py
1 change: 1 addition & 0 deletions docs/en/api/evaluation.rst
@@ -44,3 +44,4 @@ Retrieval Metric
:template: classtemplate.rst

RetrievalRecall
RetrievalAveragePrecision
4 changes: 2 additions & 2 deletions mmcls/evaluation/metrics/__init__.py
@@ -1,12 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .multi_label import AveragePrecision, MultiLabelMetric
from .multi_task import MultiTasksMetric
from .retrieval import RetrievalRecall
from .retrieval import RetrievalAveragePrecision, RetrievalRecall
from .single_label import Accuracy, ConfusionMatrix, SingleLabelMetric
from .voc_multi_label import VOCAveragePrecision, VOCMultiLabelMetric

__all__ = [
'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision',
'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
'ConfusionMatrix', 'RetrievalRecall'
'ConfusionMatrix', 'RetrievalRecall', 'RetrievalAveragePrecision'
]
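Once exported from `mmcls.evaluation`, the metric can also be used directly through its static `calculate` method. The sketch below simply reuses the index-format example from the new docstring (the data and expected value come from that example, not from a new experiment).

```python
import torch
from mmcls.evaluation import RetrievalAveragePrecision

# Retrieval results as ranked index lists: each query returns indices 0..99.
pred = [torch.Tensor([idx for idx in range(100)])] * 3
# Ground-truth gallery indices for each of the three queries.
target = [[0, 3, 6, 8, 35], [1, 2, 54, 105], [2, 42, 205]]

# topk=10, pred_indices=True, target_indices=True -> mAP over the 3 queries.
ap = RetrievalAveragePrecision.calculate(pred, target, 10, True, True)
print(ap)  # ~29.25 according to the docstring example
```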
219 changes: 218 additions & 1 deletion mmcls/evaluation/metrics/retrieval.py
@@ -65,7 +65,8 @@ class RetrievalRecall(BaseMetric):
.. code:: python
val/test_evaluator = dict(type='RetrievalRecall', topk=(1, 5))
val_evaluator = dict(type='RetrievalRecall', topk=(1, 5))
test_evaluator = val_evaluator
"""
default_prefix: Optional[str] = 'retrieval'

@@ -185,6 +186,222 @@ def calculate(pred: Union[np.ndarray, torch.Tensor],
return results


@METRICS.register_module()
class RetrievalAveragePrecision(BaseMetric):
r"""Calculate the average precision for image retrieval.
Args:
        topk (int, optional): Predictions with the k highest scores are
            considered as positive.
mode (str, optional): The mode to calculate AP, choose from
'IR'(information retrieval) and 'integrate'. Defaults to 'IR'.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Defaults to None.
    Note:
        If ``mode`` is set to 'IR', the standard AP calculation of
        information retrieval is used, as described on the Wikipedia
        page [1]; if set to 'integrate', the method integrates over the
        precision-recall curve by averaging two adjacent precision points
        and then multiplying by the recall step, like mAP in detection
        tasks. The latter is the convention for the Revisited Oxford/Paris
        datasets [2].
References:
[1] `Wikipedia entry for the Average precision <https://en.wikipedia.
org/wiki/Evaluation_measures_(information_retrieval)#Average_precision>`_
[2] `The Oxford Buildings Dataset
<https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/>`_
Examples:
Use in code:
>>> import torch
>>> import numpy as np
>>> from mmcls.evaluation import RetrievalAveragePrecision
>>> # using index format inputs
>>> pred = [ torch.Tensor([idx for idx in range(100)]) ] * 3
>>> target = [[0, 3, 6, 8, 35], [1, 2, 54, 105], [2, 42, 205]]
>>> RetrievalAveragePrecision.calculate(pred, target, 10, True, True)
29.246031746031747
>>> # using tensor format inputs
>>> pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
>>> target = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1]] * 2)
>>> RetrievalAveragePrecision.calculate(pred, target, 10)
62.222222222222214
Use in OpenMMLab config files:
.. code:: python
val_evaluator = dict(type='RetrievalAveragePrecision', topk=100)
test_evaluator = val_evaluator
"""

default_prefix: Optional[str] = 'retrieval'

def __init__(self,
topk: Optional[int] = None,
mode: Optional[str] = 'IR',
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
if topk is None or (isinstance(topk, int) and topk <= 0):
            raise ValueError('`topk` must be an integer larger than 0.')

mode_options = ['IR', 'integrate']
assert mode in mode_options, \
f'Invalid `mode` argument, please specify from {mode_options}.'

self.topk = topk
self.mode = mode
super().__init__(collect_device=collect_device, prefix=prefix)

def process(self, data_batch: Sequence[dict],
data_samples: Sequence[dict]):
"""Process one batch of data and predictions.
The processed results should be stored in ``self.results``, which will
be used to computed the metrics when all batches have been processed.
Args:
data_batch (Sequence[dict]): A batch of data from the dataloader.
predictions (Sequence[dict]): A batch of outputs from the model.
"""
for data_sample in data_samples:
pred_label = data_sample['pred_label']
gt_label = data_sample['gt_label']

pred = pred_label['score'].clone()
if 'score' in gt_label:
target = gt_label['score'].clone()
else:
num_classes = pred_label['score'].size()[-1]
target = LabelData.label_to_onehot(gt_label['label'],
num_classes)

            # Because the retrieval output logit vector will be much larger
            # than in normal classification, to save resources the evaluation
            # results are computed per batch here and all results are reduced
            # at the end.
result = RetrievalAveragePrecision.calculate(
pred.unsqueeze(0),
target.unsqueeze(0),
self.topk,
mode=self.mode)
self.results.append(result)

def compute_metrics(self, results: List):
"""Compute the metrics from processed results.
Args:
results (list): The processed results of each batch.
Returns:
Dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""
result_metrics = dict()
        result_metrics[f'mAP@{self.topk}'] = np.mean(results).item()

return result_metrics

@staticmethod
def calculate(pred: Union[np.ndarray, torch.Tensor],
target: Union[np.ndarray, torch.Tensor],
topk: Optional[int] = None,
                  pred_indices: bool = False,
                  target_indices: bool = False,
mode: str = 'IR') -> float:
"""Calculate the average precision.
Args:
pred (torch.Tensor | np.ndarray | Sequence): The prediction
results. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
shape ``(N, M)`` or a sequence of index/onehot
format labels.
            target (torch.Tensor | np.ndarray | Sequence): The retrieval
                targets. A :obj:`torch.Tensor` or :obj:`np.ndarray` with
                shape ``(N, M)`` or a sequence of index/onehot
                format labels.
            topk (int, optional): Predictions with the k highest scores
                are considered as positive.
pred_indices (bool): Whether the ``pred`` is a sequence of
category index labels. Defaults to False.
target_indices (bool): Whether the ``target`` is a sequence of
category index labels. Defaults to False.
mode (Optional[str]): The mode to calculate AP, choose from
'IR'(information retrieval) and 'integrate'. Defaults to 'IR'.
        Note:
            If ``mode`` is set to 'IR', the standard AP calculation of
            information retrieval is used, as described on the Wikipedia
            page; if set to 'integrate', the method integrates over the
            precision-recall curve by averaging two adjacent precision
            points and then multiplying by the recall step, like mAP in
            detection tasks. The latter is the convention for the Revisited
            Oxford/Paris datasets.
Returns:
float: the average precision of the query image.
References:
[1] `Wikipedia entry for Average precision(information_retrieval)
<https://en.wikipedia.org/wiki/Evaluation_measures_
(information_retrieval)#Average_precision>`_
            [2] `The Oxford Buildings Dataset
            <https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/>`_
"""
if topk is None or (isinstance(topk, int) and topk <= 0):
            raise ValueError('`topk` must be an integer larger than 0.')

mode_options = ['IR', 'integrate']
assert mode in mode_options, \
f'Invalid `mode` argument, please specify from {mode_options}.'

pred = _format_pred(pred, topk, pred_indices)
target = _format_target(target, target_indices)

assert len(pred) == len(target), (
f'Length of `pred`({len(pred)}) and `target` ({len(target)}) '
f'must be the same.')

num_samples = len(pred)
aps = np.zeros(num_samples)
for i, (sample_pred, sample_target) in enumerate(zip(pred, target)):
aps[i] = _calculateAp_for_sample(sample_pred, sample_target, mode)

return aps.mean()


def _calculateAp_for_sample(pred, target, mode):
pred = np.array(to_tensor(pred).cpu())
target = np.array(to_tensor(target).cpu())

num_preds = len(pred)

# TODO: use ``torch.isin`` in torch1.10.
positive_ranks = np.arange(num_preds)[np.in1d(pred, target)]

ap = 0
for i, rank in enumerate(positive_ranks):
if mode == 'IR':
precision = (i + 1) / (rank + 1)
ap += precision
elif mode == 'integrate':
            # Code modified from https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/compute_ap.cpp  # noqa:
            old_precision = i / rank if rank > 0 else 1
            cur_precision = (i + 1) / (rank + 1)
            # average the two adjacent precision points (trapezoidal rule)
            ap += (old_precision + cur_precision) / 2
ap = ap / len(target)

return ap * 100


def _format_pred(label, topk=None, is_indices=False):
"""format various label to List[indices]."""
if is_indices:
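To make the difference between the two AP modes concrete, here is a small sketch reusing the tensor-format example from the new docstring. The 'IR' value matches the documented ~62.22; the 'integrate' value is simply printed for comparison, no specific number is claimed for it here.

```python
import numpy as np
import torch
from mmcls.evaluation import RetrievalAveragePrecision

# Two identical queries with descending scores over 10 gallery items.
pred = np.array([np.linspace(0.95, 0.05, 10)] * 2)
# One-hot relevance: items 0, 2, 5, 8 and 9 are relevant.
target = torch.Tensor([[1, 0, 1, 0, 0, 1, 0, 0, 1, 1]] * 2)

# Standard information-retrieval AP (the docstring reports ~62.22).
print(RetrievalAveragePrecision.calculate(pred, target, 10, mode='IR'))
# Integrated precision-recall AP, the Revisited Oxford/Paris convention.
print(RetrievalAveragePrecision.calculate(pred, target, 10, mode='integrate'))
```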

0 comments on commit d2c51c6
