Recsys Metric : MRR #2843

Closed
wants to merge 1 commit into from
1 change: 1 addition & 0 deletions docs/source/metrics.rst
@@ -329,6 +329,7 @@ Complete list of metrics
Frequency
Loss
MeanAbsoluteError
MeanReciprocalRank
MeanPairwiseDistance
MeanSquaredError
metric.Metric
1 change: 1 addition & 0 deletions ignite/metrics/__init__.py
@@ -9,6 +9,7 @@
from ignite.metrics.gan.inception_score import InceptionScore
from ignite.metrics.loss import Loss
from ignite.metrics.mean_absolute_error import MeanAbsoluteError
from ignite.metrics.mean_reciprocal_rank import MeanReciprocalRank
from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
from ignite.metrics.mean_squared_error import MeanSquaredError
from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
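Once this export lands, the metric is importable from the package root like the other metrics. A minimal sketch, assuming a build that includes this PR branch rather than a released ignite version:

```python
from ignite.metrics import MeanReciprocalRank

# k and device mirror the constructor arguments defined later in this PR.
mrr = MeanReciprocalRank(k=5, device="cpu")
```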
84 changes: 84 additions & 0 deletions ignite/metrics/mean_reciprocal_rank.py
@@ -0,0 +1,84 @@
from typing import Callable, Sequence, Union

import torch

from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["MeanReciprocalRank"]

class MeanReciprocalRank(Metric):
r"""Calculate `the mean reciprocal rank (MRR) <https://en.wikipedia.org/wiki/Mean_reciprocal_rank>`_.

.. math:: \text{MRR} = \frac{1}{\lvert Q \rvert} \sum_{i=1}^{\lvert Q \rvert} \frac{1}{rank_{i}}

where :math:`rank_{i}` refers to the rank position of the first relevant document for the i-th query.

Args:
k: the k in “top-k”.
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.

Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in the format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_transform`` can be added
to the metric to transform the output into the form expected by the metric.

``y_pred`` and ``y`` should have the same shape.

For more information on how the metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

.. include:: defaults.rst
:start-after: :orphan:

.. testcode::

metric = MeanReciprocalRank(k=4)
metric.attach(default_evaluator, 'mrr')
preds = torch.tensor([
[1, 2, 4, 1],
[2, 3, 1, 5],
[1, 3, 5, 1],
[1, 5, 1, 11]
])
target = preds * 0.75
state = default_evaluator.run([[preds, target]])
print(state.metrics['mrr'])

.. testoutput::

1.0
"""

def __init__(
self,
k: int = 5,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu")
):
super(MeanReciprocalRank, self).__init__(output_transform=output_transform, device=device)
self._k = k

@reinit__is_reduced
def reset(self) -> None:
self._relevance = torch.empty(0, device=self._device)

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
y_pred, y = output[0].detach(), output[1].detach()
# Relevance of the top-k ranked predictions for each query, in ranked order.
_, topk_idx = y_pred.topk(self._k, dim=-1)
relevance = y.take_along_dim(topk_idx, dim=-1).to(self._relevance)
self._relevance = relevance if self._relevance.numel() == 0 else torch.cat([self._relevance, relevance], dim=-1)

@sync_all_reduce("_sum", "_num_examples")
@vfdev-5 (Collaborator) commented on Feb 3, 2023:
This line is incorrect, but let's fix it later once we have a good understanding of how metric is computed

def compute(self) -> float:
# 1-based rank of the most relevant item among the top-k predictions for each query.
first_relevant_positions = self._relevance.argmax(dim=-1) + 1
# Queries with no relevant item among the top-k contribute a reciprocal rank of 0.
valid_mask = self._relevance.sum(dim=-1) > 0
reciprocal_ranks = valid_mask.float() / first_relevant_positions
return reciprocal_ranks.mean().item()
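As a sanity check on the formula in the docstring, here is a small standalone sketch (plain PyTorch, not part of this PR) that computes MRR for a made-up batch of four queries with binary relevance labels: it ranks items by score, takes the 1-based position of the first relevant item per query, and averages the reciprocals.

```python
import torch

# Hypothetical scores for 4 queries over 4 items, plus binary relevance labels.
scores = torch.tensor([
    [0.9, 0.2, 0.1, 0.4],
    [0.1, 0.8, 0.7, 0.3],
    [0.2, 0.1, 0.9, 0.6],
    [0.3, 0.4, 0.2, 0.1],
])
relevant = torch.tensor([
    [0, 0, 0, 1],  # relevant item is ranked 2nd by the scores
    [0, 1, 0, 0],  # ranked 1st
    [1, 0, 0, 0],  # ranked 3rd
    [0, 0, 0, 1],  # ranked 4th
])

# Item indices sorted by descending score, then relevance read off in ranked order.
ranking = scores.argsort(dim=-1, descending=True)
ranked_rel = relevant.take_along_dim(ranking, dim=-1)

# rank_i: 1-based position of the first relevant item for query i.
first_rank = ranked_rel.argmax(dim=-1) + 1
mrr = (1.0 / first_rank.float()).mean()
print(mrr)  # (1/2 + 1/1 + 1/3 + 1/4) / 4 ≈ 0.5208
```

The `MeanReciprocalRank` class in this PR follows the same idea, except that the per-query relevance is derived from ``y`` at the top-k indices of ``y_pred``.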