Merge pull request #590 from guijiql/0.2.x
FIX: bugs in evaluator
chenyushuo authored Dec 24, 2020
2 parents 948d9d6 + 634bad6 commit 94d5bd8
Showing 5 changed files with 173 additions and 29 deletions.
8 changes: 6 additions & 2 deletions recbole/evaluator/abstract_evaluator.py
@@ -4,7 +4,7 @@
# @email : tsotfsk@outlook.com

# UPDATE
# @Time : 2020/10/21, 2020/12/9
# @Time : 2020/10/21, 2020/12/18
# @Author : Kaiyuan Li, Zhichao Feng
# @email : tsotfsk@outlook.com, fzcbupt@gmail.com

@@ -99,7 +99,7 @@ class IndividualEvaluator(BaseEvaluator):
"""
def __init__(self, config, metrics):
super().__init__(config, metrics)
pass
self._check_args()

def sample_collect(self, true_scores, pred_scores):
"""It is called when evaluation sample distribution is `uniform` or `popularity`.
@@ -127,3 +127,7 @@ def get_score_matrix(self, true_scores, pred_scores):
scores_matrix = self.sample_collect(true_scores, pred_scores)

return scores_matrix

def _check_args(self):
if self.full:
raise NotImplementedError('full sort can\'t use IndividualEvaluator')
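As a minimal, hypothetical sketch (not part of the commit) of what the new `_check_args` guard buys: with full-sort evaluation configured, construction now fails immediately instead of failing later during evaluation. The toy class below only mimics the guard; the real `IndividualEvaluator` takes `self.full` from the RecBole config.

# Toy stand-in for the new guard; `ToyIndividualEvaluator` is hypothetical.
class ToyIndividualEvaluator:
    def __init__(self, full):
        self.full = full            # True when full-sort evaluation is configured
        self._check_args()

    def _check_args(self):
        if self.full:
            raise NotImplementedError("full sort can't use IndividualEvaluator")

ToyIndividualEvaluator(full=False)  # sampled evaluation: constructed as usual
try:
    ToyIndividualEvaluator(full=True)
except NotImplementedError as err:
    print(err)                      # fails fast instead of breaking later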
59 changes: 42 additions & 17 deletions recbole/evaluator/evaluators.py
@@ -4,7 +4,7 @@
# @email : tsotfsk@outlook.com

# UPDATE
# @Time : 2020/08/04, 2020/08/11, 2020/12/9
# @Time : 2020/08/04, 2020/08/11, 2020/12/18
# @Author : Kaiyuan Li, Yupeng Hou, Zhichao Feng
# @email : tsotfsk@outlook.com, houyupeng@ruc.edu.cn, fzcbupt@gmail.com

@@ -158,22 +158,40 @@ def get_user_pos_len_list(self, interaction, scores_tensor):
user_len_list = interaction.user_len_list
return pos_len_list, user_len_list

def get_pos_index(self, scores_tensor, pos_len_list, user_len_list):
"""get the index of positive items
def average_rank(self, scores):
"""Get the ranking of an ordered tensor, and take the average of the ranking for positions with equal values.
Args:
scores_tensor (tensor): the tensor of model output with size of `(N, )`
pos_len_list(list): number of positive items
user_len_list(list): number of all items
Args:
scores(tensor): an ordered tensor, with size of `(N, )`
Returns:
torch.Tensor: average_rank
Returns:
tensor: a matrix indicating whether the corresponding item is positive
Example:
>>> average_rank(tensor([[1,2,2,2,3,3,6],[2,2,2,2,4,4,5]]))
tensor([[1.0000, 3.0000, 3.0000, 3.0000, 5.5000, 5.5000, 7.0000],
[2.5000, 2.5000, 2.5000, 2.5000, 5.5000, 5.5000, 7.0000]])
"""
scores_matrix = self.get_score_matrix(scores_tensor, user_len_list)
_, n_index = torch.sort(scores_matrix, dim=-1, descending=True)
pos_index = (n_index < pos_len_list.reshape(-1, 1))
return pos_index
Reference:
https://github.com/scipy/scipy/blob/v0.17.1/scipy/stats/stats.py#L5262-L5352
"""
length, width = scores.shape
device = scores.device
true_tensor = torch.full((length, 1), True, dtype=torch.bool, device=device)

obs = torch.cat([true_tensor, scores[:, 1:] != scores[:, :-1]], dim=1)
# bias added to dense
bias = torch.arange(0, length, device=device).repeat(width).reshape(width, -1). \
transpose(1, 0).reshape(-1)
dense = obs.view(-1).cumsum(0) + bias

# cumulative counts of each unique value
count = torch.where(torch.cat([obs, true_tensor], dim=1))[1]
# get average rank
avg_rank = .5 * (count[dense] + count[dense - 1] + 1).view(length, -1)

return avg_rank
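For readers checking the docstring example, here is a standalone sketch (hypothetical name `average_rank_sketch`, not part of the commit) that mirrors the tensor operations above outside the evaluator class.

import torch

def average_rank_sketch(scores):
    # scores: each row already sorted, shape (length, width)
    length, width = scores.shape
    true_tensor = torch.full((length, 1), True, dtype=torch.bool)
    # mark positions where a new (distinct) value starts within each row
    obs = torch.cat([true_tensor, scores[:, 1:] != scores[:, :-1]], dim=1)
    # per-row bias so the flattened cumulative sums of different rows do not mix
    bias = torch.arange(0, length).repeat(width).reshape(width, -1).transpose(1, 0).reshape(-1)
    dense = obs.view(-1).cumsum(0) + bias
    # column indices of group boundaries act as cumulative counts of unique values
    count = torch.where(torch.cat([obs, true_tensor], dim=1))[1]
    # average of the first and last 1-based rank inside each tie group
    return .5 * (count[dense] + count[dense - 1] + 1).view(length, -1)

print(average_rank_sketch(torch.tensor([[1, 2, 2, 2, 3, 3, 6],
                                        [2, 2, 2, 2, 4, 4, 5]])))
# tensor([[1.0000, 3.0000, 3.0000, 3.0000, 5.5000, 5.5000, 7.0000],
#         [2.5000, 2.5000, 2.5000, 2.5000, 5.5000, 5.5000, 7.0000]])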

def collect(self, interaction, scores_tensor):
"""collect the rank intermediate result of one batch, this function mainly implements ranking
@@ -185,10 +203,16 @@ def collect(self, interaction, scores_tensor):
"""
pos_len_list, user_len_list = self.get_user_pos_len_list(interaction, scores_tensor)
pos_index = self.get_pos_index(scores_tensor, pos_len_list, user_len_list)
index_list = torch.arange(1, pos_index.shape[1] + 1).to(pos_index.device)
pos_rank_sum = torch.where(pos_index, index_list, torch.zeros_like(index_list)). \
scores_matrix = self.get_score_matrix(scores_tensor, user_len_list)
desc_scores, desc_index = torch.sort(scores_matrix, dim=-1, descending=True)

# get the index of positive items in the ranking list
pos_index = (desc_index < pos_len_list.reshape(-1, 1))

avg_rank = self.average_rank(desc_scores)
pos_rank_sum = torch.where(pos_index, avg_rank, torch.zeros_like(avg_rank)). \
sum(axis=-1).reshape(-1, 1)

return pos_rank_sum
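A hand-worked instance (not from the commit) of the new `pos_rank_sum`, assuming a single user whose first `pos_len` = 2 scores are the positive items, which is how the score matrix lays items out for sampled evaluation:

# All numbers are made up for illustration; one user, pos_len = 2.
scores       = [0.3, 0.3, 0.4, 0.2, 0.1]          # positives first, then negatives
desc_scores  = [0.4, 0.3, 0.3, 0.2, 0.1]          # scores sorted in descending order
desc_index   = [2,   0,   1,   3,   4]            # original position of each sorted score
pos_index    = [i < 2 for i in desc_index]        # [False, True, True, False, False]
avg_rank     = [1.0, 2.5, 2.5, 4.0, 5.0]          # the tied 0.3 scores share rank (2 + 3) / 2
pos_rank_sum = sum(r for r, p in zip(avg_rank, pos_index) if p)
print(pos_rank_sum)                               # 5.0 = 2.5 + 2.5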

def evaluate(self, batch_matrix_list, eval_data):
@@ -205,6 +229,7 @@ def evaluate(self, batch_matrix_list, eval_data):
pos_len_list = eval_data.get_pos_len_list()
user_len_list = eval_data.get_user_len_list()
pos_rank_sum = torch.cat(batch_matrix_list, dim=0).cpu().numpy()
assert len(pos_len_list) == len(pos_rank_sum)

# get metrics
metric_dict = {}
57 changes: 49 additions & 8 deletions recbole/evaluator/metrics.py
@@ -4,7 +4,7 @@
# @email : tsotfsk@outlook.com

# UPDATE
# @Time : 2020/08/12, 2020/12/9, 2020/9/16
# @Time : 2020/08/12, 2020/12/21, 2020/9/16
# @Author : Kaiyuan Li, Zhichao Feng, Xingyu Pan
# @email : tsotfsk@outlook.com, fzcbupt@gmail.com, panxy@ruc.edu.cn

@@ -23,7 +23,6 @@

# TopK Metrics #


def hit_(pos_index, pos_len):
r"""Hit_ (also known as hit ratio at :math:`N`) is a way of calculating how many 'hits' you have
in an n-sized list of ranked items.
@@ -129,7 +128,6 @@ def ndcg_(pos_index, pos_len):
:math:`U^{te}` is for all users in the test set.
"""

len_rank = np.full_like(pos_len, pos_index.shape[1])
idcg_len = np.where(pos_len > len_rank, len_rank, pos_len)

@@ -166,11 +164,54 @@ def precision_(pos_index, pos_len):


def gauc_(user_len_list, pos_len_list, pos_rank_sum):
frac = user_len_list - (pos_len_list - 1) / 2 - (1 / pos_len_list) * np.squeeze(pos_rank_sum)
neg_item_num = user_len_list - pos_len_list
user_auc = frac / neg_item_num
r"""GAUC_ (also known as Group Area Under Curve) is used to evaluate the two-class model, referring to
the area under the ROC curve grouped by user.
.. _GAUC: https://dl.acm.org/doi/10.1145/3219819.3219823
Note:
It calculates the AUC score of each user, and finally obtains GAUC by weighting the user AUC.
It is also not limited to k. Due to our padding for `scores_tensor` in `RankEvaluator` with
`-np.inf`, the padding value will influence the ranks of origin items. Therefore, we use
descending sort here and make an identity transformation to the formula of `AUC`, which is
shown in `auc_` function. For readability, we didn't do simplification in the code.
.. math::
\mathrm {GAUC} = \frac {{{M} \times {(M+N+1)} - \frac{M \times (M+1)}{2}} -
\sum\limits_{i=1}^M rank_{i}} {{M} \times {N}}
:math:`M` is the number of positive samples.
:math:`N` is the number of negative samples.
:math:`rank_i` is the descending rank of the ith positive sample.
"""
neg_len_list = user_len_list - pos_len_list

# check positive and negative samples
any_without_pos = np.any(pos_len_list == 0)
any_without_neg = np.any(neg_len_list == 0)
non_zero_idx = np.full(len(user_len_list), True, dtype=bool)
if any_without_pos:
logger = getLogger()
logger.warning("No positive samples in some users, "
"true positive value should be meaningless, "
"these users have been removed from GAUC calculation")
non_zero_idx *= (pos_len_list != 0)
if any_without_neg:
logger = getLogger()
logger.warning("No negative samples in some users, "
"false positive value should be meaningless, "
"these users have been removed from GAUC calculation")
non_zero_idx *= (neg_len_list != 0)
if any_without_pos or any_without_neg:
item_list = user_len_list, neg_len_list, pos_len_list, pos_rank_sum
user_len_list, neg_len_list, pos_len_list, pos_rank_sum = \
map(lambda x: x[non_zero_idx], item_list)

pair_num = (user_len_list + 1) * pos_len_list - pos_len_list * (pos_len_list + 1) / 2 - np.squeeze(pos_rank_sum)
user_auc = pair_num / (neg_len_list * pos_len_list)
result = (user_auc * pos_len_list).sum() / pos_len_list.sum()

return result
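As a quick sanity check (not part of the commit), the descending-rank formula above can be compared against a per-user AUC computed directly; the data below are made up, and scikit-learn is assumed to be available for the comparison.

import numpy as np
from sklearn.metrics import roc_auc_score         # assumed available for the comparison

labels = np.array([1, 0, 1, 0, 0])                # 1 marks a positive item
scores = np.array([0.9, 0.7, 0.4, 0.2, 0.1])

M, N = labels.sum(), (1 - labels).sum()           # 2 positives, 3 negatives
desc_rank = np.argsort(np.argsort(-scores)) + 1   # 1-based descending ranks
rank_sum = desc_rank[labels == 1].sum()           # 1 + 3 = 4

formula_auc = ((M + N + 1) * M - M * (M + 1) / 2 - rank_sum) / (M * N)
print(formula_auc)                                # 0.8333...
print(roc_auc_score(labels, scores))              # 0.8333..., the same value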


@@ -188,11 +229,11 @@ def auc_(trues, preds):
.. math::
\mathrm {AUC} = \frac{\sum\limits_{i=1}^M rank_{i}
- {{M} \times {(M+1)}}} {{M} \times {N}}
- \frac {{M} \times {(M+1)}}{2}} {{{M} \times {N}}}
:math:`M` is the number of positive samples.
:math:`N` is the number of negative samples.
:math:`rank_i` is the rank of the ith positive sample.
:math:`rank_i` is the ascending rank of the ith positive sample.
"""
fps, tps = _binary_clf_curve(trues, preds)
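A companion check (again with made-up data, not from the commit) for the ascending-rank form of the AUC given in the docstring above:

import numpy as np

labels = np.array([1, 0, 1, 0, 0])                # 1 marks a positive sample
scores = np.array([0.9, 0.7, 0.4, 0.2, 0.1])
M, N = labels.sum(), (1 - labels).sum()

asc_rank = np.argsort(np.argsort(scores)) + 1     # 1-based ascending ranks (no ties here)
auc = (asc_rank[labels == 1].sum() - M * (M + 1) / 2) / (M * N)
print(auc)   # 0.8333..., i.e. 5 of the 6 positive/negative pairs are ordered correctly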
4 changes: 2 additions & 2 deletions recbole/evaluator/proxy_evaluator.py
@@ -32,9 +32,9 @@ def build(self):
"""
evaluator_list = []
metrics_set = {metric.lower() for metric in self.metrics}
metrics_list = [metric.lower() for metric in self.metrics]
for metrics, evaluator in metric_eval_bind:
used_metrics = list(metrics_set.intersection(set(metrics.keys())))
used_metrics = [metric for metric in metrics_list if metric in metrics]
if used_metrics:
evaluator_list.append(evaluator(self.config, used_metrics))
return evaluator_list
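A small illustration (the reason is inferred, not stated in the commit) of why the list comprehension is preferable to the earlier set intersection: it keeps the metrics in the order the user configured them, whereas set iteration order is arbitrary.

# Made-up metric names; `supported` stands in for metrics.keys() of one evaluator.
configured = ['NDCG', 'Recall', 'Hit']                        # as a user might set config['metrics']
supported = {'hit': None, 'ndcg': None, 'precision': None, 'recall': None}

metrics_list = [m.lower() for m in configured]
print([m for m in metrics_list if m in supported])            # ['ndcg', 'recall', 'hit'], config order preserved
print(list({m.lower() for m in configured} & set(supported))) # same metrics, but in arbitrary set order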
74 changes: 74 additions & 0 deletions tests/metrics/test_rank_metrics.py
@@ -0,0 +1,74 @@
# -*- encoding: utf-8 -*-
# @Time : 2020/12/21
# @Author : Zhichao Feng
# @email : fzcbupt@gmail.com


import os
import sys
import unittest

sys.path.append(os.getcwd())
import numpy as np
import torch
from recbole.config import Config
from recbole.data.interaction import Interaction
from recbole.evaluator.metrics import metrics_dict
from recbole.evaluator.evaluators import RankEvaluator

parameters_dict = {
'model': 'BPR',
'eval_setting': 'RO_RS,uni100',
}


class MetricsTestCases(object):
user_len_list0 = np.array([2, 3, 5])
pos_len_list0 = np.array([1, 2, 3])
pos_rank_sum0 = np.array([1, 4, 9])

user_len_list1 = np.array([3, 6, 4])
pos_len_list1 = np.array([1, 0, 4])
pos_rank_sum1 = np.array([3, 0, 6])


class CollectTestCases(object):
interaction0 = Interaction({}, [0, 2, 3, 4], [2, 3, 4, 5])
scores_tensor0 = torch.Tensor([0.1, 0.2,
0.1, 0.1, 0.2,
0.2, 0.2, 0.2, 0.2,
0.3, 0.2, 0.1, 0.4, 0.3])


def get_metric_result(name, case=0):
func = metrics_dict[name]
return func(getattr(MetricsTestCases, f'user_len_list{case}'),
getattr(MetricsTestCases, f'pos_len_list{case}'),
getattr(MetricsTestCases, f'pos_rank_sum{case}'))


def get_collect_result(evaluator, case=0):
func = evaluator.collect
return func(getattr(CollectTestCases, f'interaction{case}'),
getattr(CollectTestCases, f'scores_tensor{case}'))


class TestRankMetrics(unittest.TestCase):
def test_gauc(self):
name = 'gauc'
self.assertEqual(get_metric_result(name, case=0), (1 * ((2 - (1 - 1) / 2 - 1 / 1) / (2 - 1)) +
2 * ((3 - (2 - 1) / 2 - 4 / 2) / (3 - 2)) +
3 * ((5 - (3 - 1) / 2 - 9 / 3) / (5 - 3)))
/ (1 + 2 + 3))
self.assertEqual(get_metric_result(name, case=1), (3 - 0 - 3 / 1) / (3 - 1))

def test_collect(self):
config = Config('BPR', 'ml-100k', config_dict=parameters_dict)
metrics = ['GAUC']
rank_evaluator = RankEvaluator(config, metrics)
self.assertEqual(get_collect_result(rank_evaluator, case=0).squeeze().cpu().numpy().tolist(),
np.array([0, (2 + 3) / 2 * 2, (1 + 2 + 3 + 4) / 4 * 3, 1 + (2 + 3) / 2 + 4 + 5]).tolist())


if __name__ == "__main__":
unittest.main()
