diff --git a/recbole/evaluator/abstract_evaluator.py b/recbole/evaluator/abstract_evaluator.py
index 2570fe8ad..1ab81d476 100644
--- a/recbole/evaluator/abstract_evaluator.py
+++ b/recbole/evaluator/abstract_evaluator.py
@@ -4,7 +4,7 @@
 # @email  : tsotfsk@outlook.com

 # UPDATE
-# @Time   : 2020/10/21, 2020/12/9
+# @Time   : 2020/10/21, 2020/12/18
 # @Author : Kaiyuan Li, Zhichao Feng
 # @email  : tsotfsk@outlook.com, fzcbupt@gmail.com
@@ -99,7 +99,7 @@ class IndividualEvaluator(BaseEvaluator):
     """
     def __init__(self, config, metrics):
         super().__init__(config, metrics)
-        pass
+        self._check_args()

     def sample_collect(self, true_scores, pred_scores):
         """It is called when evaluation sample distribution is `uniform` or `popularity`.
@@ -127,3 +127,7 @@ def get_score_matrix(self, true_scores, pred_scores):
             scores_matrix = self.sample_collect(true_scores, pred_scores)

         return scores_matrix
+
+    def _check_args(self):
+        if self.full:
+            raise NotImplementedError("full sort can't be used with IndividualEvaluator")
\ No newline at end of file
diff --git a/recbole/evaluator/evaluators.py b/recbole/evaluator/evaluators.py
index 10a41f715..d20d73307 100644
--- a/recbole/evaluator/evaluators.py
+++ b/recbole/evaluator/evaluators.py
@@ -4,7 +4,7 @@
 # @email  : tsotfsk@outlook.com

 # UPDATE
-# @Time   : 2020/08/04, 2020/08/11, 2020/12/9
+# @Time   : 2020/08/04, 2020/08/11, 2020/12/18
 # @Author : Kaiyuan Li, Yupeng Hou, Zhichao Feng
 # @email  : tsotfsk@outlook.com, houyupeng@ruc.edu.cn, fzcbupt@gmail.com
@@ -158,22 +158,40 @@ def get_user_pos_len_list(self, interaction, scores_tensor):
         user_len_list = interaction.user_len_list
         return pos_len_list, user_len_list

-    def get_pos_index(self, scores_tensor, pos_len_list, user_len_list):
-        """get the index of positive items
+    def average_rank(self, scores):
+        """Get the ranking of an ordered tensor, where tied values are assigned the average of their ordinal positions.

-        Args:
-            scores_tensor (tensor): the tensor of model output with size of `(N, )`
-            pos_len_list(list): number of positive items
-            user_len_list(list): number of all items
+        Args:
+            scores(tensor): an ordered tensor, with size of `(N, K)`
+
+        Returns:
+            torch.Tensor: average_rank

-        Returns:
-            tensor: a matrix indicating whether the corresponding item is positive
+        Example:
+            >>> average_rank(tensor([[1,2,2,2,3,3,6],[2,2,2,2,4,4,5]]))
+            tensor([[1.0000, 3.0000, 3.0000, 3.0000, 5.5000, 5.5000, 7.0000],
+                    [2.5000, 2.5000, 2.5000, 2.5000, 5.5000, 5.5000, 7.0000]])

-        """
-        scores_matrix = self.get_score_matrix(scores_tensor, user_len_list)
-        _, n_index = torch.sort(scores_matrix, dim=-1, descending=True)
-        pos_index = (n_index < pos_len_list.reshape(-1, 1))
-        return pos_index
+        Reference:
+            https://github.com/scipy/scipy/blob/v0.17.1/scipy/stats/stats.py#L5262-L5352
+
+        """
+        length, width = scores.shape
+        device = scores.device
+        true_tensor = torch.full((length, 1), True, dtype=torch.bool, device=device)
+
+        obs = torch.cat([true_tensor, scores[:, 1:] != scores[:, :-1]], dim=1)
+        # bias added to dense
+        bias = torch.arange(0, length, device=device).repeat(width).reshape(width, -1). \
+            transpose(1, 0).reshape(-1)
+        dense = obs.view(-1).cumsum(0) + bias
+
+        # cumulative counts of each unique value
+        count = torch.where(torch.cat([obs, true_tensor], dim=1))[1]
+        # get average rank
+        avg_rank = .5 * (count[dense] + count[dense - 1] + 1).view(length, -1)
+
+        return avg_rank

     def collect(self, interaction, scores_tensor):
         """collect the rank intermediate result of one batch, this function mainly implements ranking
@@ -185,10 +203,16 @@
         """
         pos_len_list, user_len_list = self.get_user_pos_len_list(interaction, scores_tensor)
-        pos_index = self.get_pos_index(scores_tensor, pos_len_list, user_len_list)
-        index_list = torch.arange(1, pos_index.shape[1] + 1).to(pos_index.device)
-        pos_rank_sum = torch.where(pos_index, index_list, torch.zeros_like(index_list)). \
+        scores_matrix = self.get_score_matrix(scores_tensor, user_len_list)
+        desc_scores, desc_index = torch.sort(scores_matrix, dim=-1, descending=True)
+
+        # get the index of positive items in the ranking list
+        pos_index = (desc_index < pos_len_list.reshape(-1, 1))
+
+        avg_rank = self.average_rank(desc_scores)
+        pos_rank_sum = torch.where(pos_index, avg_rank, torch.zeros_like(avg_rank)). \
             sum(axis=-1).reshape(-1, 1)
+
         return pos_rank_sum

     def evaluate(self, batch_matrix_list, eval_data):
@@ -205,6 +229,7 @@ def evaluate(self, batch_matrix_list, eval_data):
         pos_len_list = eval_data.get_pos_len_list()
         user_len_list = eval_data.get_user_len_list()
        pos_rank_sum = torch.cat(batch_matrix_list, dim=0).cpu().numpy()
+        assert len(pos_len_list) == len(pos_rank_sum)

         # get metrics
         metric_dict = {}
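A quick sanity check of the tie handling in `average_rank`: the snippet below inlines the method body (a standalone sketch with `self` and the device plumbing removed; it is not part of the patch) and reproduces the docstring example.

    import torch

    def average_rank(scores):
        # standalone copy of RankEvaluator.average_rank above
        length, width = scores.shape
        true_tensor = torch.full((length, 1), True, dtype=torch.bool)
        # True marks the first element of each run of equal values
        obs = torch.cat([true_tensor, scores[:, 1:] != scores[:, :-1]], dim=1)
        # shift row i by i: `count` below gains one sentinel column per row
        bias = torch.arange(0, length).repeat(width).reshape(width, -1).transpose(1, 0).reshape(-1)
        dense = obs.view(-1).cumsum(0) + bias
        # column indices of run boundaries, i.e. cumulative counts per row
        count = torch.where(torch.cat([obs, true_tensor], dim=1))[1]
        # average of each run's first and last ordinal position
        return .5 * (count[dense] + count[dense - 1] + 1).view(length, -1)

    print(average_rank(torch.tensor([[1., 2., 2., 2., 3., 3., 6.],
                                     [2., 2., 2., 2., 4., 4., 5.]])))
    # tensor([[1.0000, 3.0000, 3.0000, 3.0000, 5.5000, 5.5000, 7.0000],
    #         [2.5000, 2.5000, 2.5000, 2.5000, 5.5000, 5.5000, 7.0000]])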
diff --git a/recbole/evaluator/metrics.py b/recbole/evaluator/metrics.py
index cb31e1456..483fd3f21 100644
--- a/recbole/evaluator/metrics.py
+++ b/recbole/evaluator/metrics.py
@@ -4,7 +4,7 @@
 # @email  : tsotfsk@outlook.com

 # UPDATE
-# @Time   : 2020/08/12, 2020/12/9, 2020/9/16
+# @Time   : 2020/08/12, 2020/12/21, 2020/9/16
 # @Author : Kaiyuan Li, Zhichao Feng, Xingyu Pan
 # @email  : tsotfsk@outlook.com, fzcbupt@gmail.com, panxy@ruc.edu.cn
@@ -23,7 +23,6 @@
 # TopK Metrics #

-
 def hit_(pos_index, pos_len):
     r"""Hit_ (also known as hit ratio at :math:`N`) is a way of calculating how many 'hits'
     you have in an n-sized list of ranked items.
@@ -129,7 +128,6 @@ def ndcg_(pos_index, pos_len):
     :math:`U^{te}` is for all users in the test set.

     """
-
     len_rank = np.full_like(pos_len, pos_index.shape[1])
     idcg_len = np.where(pos_len > len_rank, len_rank, pos_len)
@@ -166,11 +164,54 @@ def precision_(pos_index, pos_len):


 def gauc_(user_len_list, pos_len_list, pos_rank_sum):
-    frac = user_len_list - (pos_len_list - 1) / 2 - (1 / pos_len_list) * np.squeeze(pos_rank_sum)
-    neg_item_num = user_len_list - pos_len_list
-    user_auc = frac / neg_item_num
+    r"""GAUC_ (also known as Group Area Under Curve) evaluates a two-class model by the area
+    under the ROC curve, grouped by user.
+
+    .. _GAUC: https://dl.acm.org/doi/10.1145/3219819.3219823
+
+    Note:
+        It calculates the AUC score of each user, and finally obtains GAUC by weighting the user AUC.
+        Unlike the top-k metrics, it is not limited to a cutoff k. Because `RankEvaluator` pads
+        `scores_tensor` with `-np.inf`, the padding values would influence the ranks of the original
+        items under an ascending sort. Therefore, we sort in descending order here and apply an
+        identity transformation to the `AUC` formula, which is shown in the `auc_` function. For
+        readability, we did not simplify the formula in the code.
+
+    .. math::
+        \mathrm{GAUC} = \frac{M \times (M+N+1) - \frac{M \times (M+1)}{2}
+        - \sum\limits_{i=1}^M rank_{i}}{M \times N}
+
+    :math:`M` is the number of positive samples.
+    :math:`N` is the number of negative samples.
+    :math:`rank_i` is the descending rank of the ith positive sample.
+
+    """
+    neg_len_list = user_len_list - pos_len_list
+
+    # check positive and negative samples
+    any_without_pos = np.any(pos_len_list == 0)
+    any_without_neg = np.any(neg_len_list == 0)
+    non_zero_idx = np.full(len(user_len_list), True, dtype=bool)
+    if any_without_pos:
+        logger = getLogger()
+        logger.warning("Some users have no positive samples; "
+                       "their true positive rate is meaningless, "
+                       "so they have been removed from the GAUC calculation")
+        non_zero_idx *= (pos_len_list != 0)
+    if any_without_neg:
+        logger = getLogger()
+        logger.warning("Some users have no negative samples; "
+                       "their false positive rate is meaningless, "
+                       "so they have been removed from the GAUC calculation")
+        non_zero_idx *= (neg_len_list != 0)
+    if any_without_pos or any_without_neg:
+        item_list = user_len_list, neg_len_list, pos_len_list, pos_rank_sum
+        user_len_list, neg_len_list, pos_len_list, pos_rank_sum = \
+            map(lambda x: x[non_zero_idx], item_list)
+
+    pair_num = (user_len_list + 1) * pos_len_list - pos_len_list * (pos_len_list + 1) / 2 - np.squeeze(pos_rank_sum)
+    user_auc = pair_num / (neg_len_list * pos_len_list)

     result = (user_auc * pos_len_list).sum() / pos_len_list.sum()
+
     return result
@@ -188,11 +229,11 @@

     .. math::
         \mathrm {AUC} = \frac{\sum\limits_{i=1}^M rank_{i}
-        - {{M} \times {(M+1)}}} {{M} \times {N}}
+        - \frac{M \times (M+1)}{2}}{M \times N}

     :math:`M` is the number of positive samples.
     :math:`N` is the number of negative samples.
-    :math:`rank_i` is the rank of the ith positive sample.
+    :math:`rank_i` is the ascending rank of the ith positive sample.

     """
     fps, tps = _binary_clf_curve(trues, preds)
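The descending-rank GAUC formula can be cross-checked numerically against a reference AUC implementation. A minimal single-user sketch (sklearn is used only for comparison and is not a dependency of this patch):

    import numpy as np
    from sklearn.metrics import roc_auc_score

    # one user: M = 2 positives scored [0.9, 0.4], N = 3 negatives scored [0.8, 0.3, 0.1]
    trues = np.array([1, 1, 0, 0, 0])
    preds = np.array([0.9, 0.4, 0.8, 0.3, 0.1])

    # descending ranks over all 5 scores: 0.9 -> 1, 0.8 -> 2, 0.4 -> 3, 0.3 -> 4, 0.1 -> 5
    M, N, rank_sum = 2, 3, 1 + 3  # the positives hold descending ranks 1 and 3
    auc = (M * (M + N + 1) - M * (M + 1) / 2 - rank_sum) / (M * N)

    assert np.isclose(auc, roc_auc_score(trues, preds))  # both equal 5/6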
diff --git a/recbole/evaluator/proxy_evaluator.py b/recbole/evaluator/proxy_evaluator.py
index fc91f5d48..f5faaebd6 100644
--- a/recbole/evaluator/proxy_evaluator.py
+++ b/recbole/evaluator/proxy_evaluator.py
@@ -32,9 +32,9 @@ def build(self):
         """
         evaluator_list = []
-        metrics_set = {metric.lower() for metric in self.metrics}
+        metrics_list = [metric.lower() for metric in self.metrics]
         for metrics, evaluator in metric_eval_bind:
-            used_metrics = list(metrics_set.intersection(set(metrics.keys())))
+            used_metrics = [metric for metric in metrics_list if metric in metrics]
             if used_metrics:
                 evaluator_list.append(evaluator(self.config, used_metrics))
         return evaluator_list
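The switch from `metrics_set` to `metrics_list` preserves the order in which the user listed the metrics, which a set intersection would discard. A minimal illustration (hypothetical metric names, unrelated to RecBole's registry):

    metrics = ['NDCG', 'GAUC', 'Hit']
    print([m.lower() for m in metrics])  # ['ndcg', 'gauc', 'hit'] -- input order kept
    print({m.lower() for m in metrics})  # set iteration order is arbitrary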
""" fps, tps = _binary_clf_curve(trues, preds) diff --git a/recbole/evaluator/proxy_evaluator.py b/recbole/evaluator/proxy_evaluator.py index fc91f5d48..f5faaebd6 100644 --- a/recbole/evaluator/proxy_evaluator.py +++ b/recbole/evaluator/proxy_evaluator.py @@ -32,9 +32,9 @@ def build(self): """ evaluator_list = [] - metrics_set = {metric.lower() for metric in self.metrics} + metrics_list = [metric.lower() for metric in self.metrics] for metrics, evaluator in metric_eval_bind: - used_metrics = list(metrics_set.intersection(set(metrics.keys()))) + used_metrics = [metric for metric in metrics_list if metric in metrics] if used_metrics: evaluator_list.append(evaluator(self.config, used_metrics)) return evaluator_list diff --git a/tests/metrics/test_rank_metrics.py b/tests/metrics/test_rank_metrics.py new file mode 100644 index 000000000..c60a2c485 --- /dev/null +++ b/tests/metrics/test_rank_metrics.py @@ -0,0 +1,74 @@ +# -*- encoding: utf-8 -*- +# @Time : 2020/12/21 +# @Author : Zhichao Feng +# @email : fzcbupt@gmail.com + + +import os +import sys +import unittest + +sys.path.append(os.getcwd()) +import numpy as np +import torch +from recbole.config import Config +from recbole.data.interaction import Interaction +from recbole.evaluator.metrics import metrics_dict +from recbole.evaluator.evaluators import RankEvaluator + +parameters_dict = { + 'model': 'BPR', + 'eval_setting': 'RO_RS,uni100', +} + + +class MetricsTestCases(object): + user_len_list0 = np.array([2, 3, 5]) + pos_len_list0 = np.array([1, 2, 3]) + pos_rank_sum0 = np.array([1, 4, 9]) + + user_len_list1 = np.array([3, 6, 4]) + pos_len_list1 = np.array([1, 0, 4]) + pos_rank_sum1 = np.array([3, 0, 6]) + + +class CollectTestCases(object): + interaction0 = Interaction({}, [0, 2, 3, 4], [2, 3, 4, 5]) + scores_tensor0 = torch.Tensor([0.1, 0.2, + 0.1, 0.1, 0.2, + 0.2, 0.2, 0.2, 0.2, + 0.3, 0.2, 0.1, 0.4, 0.3]) + + +def get_metric_result(name, case=0): + func = metrics_dict[name] + return func(getattr(MetricsTestCases, f'user_len_list{case}'), + getattr(MetricsTestCases, f'pos_len_list{case}'), + getattr(MetricsTestCases, f'pos_rank_sum{case}')) + + +def get_collect_result(evaluator, case=0): + func = evaluator.collect + return func(getattr(CollectTestCases, f'interaction{case}'), + getattr(CollectTestCases, f'scores_tensor{case}')) + + +class TestRankMetrics(unittest.TestCase): + def test_gauc(self): + name = 'gauc' + self.assertEqual(get_metric_result(name, case=0), (1 * ((2 - (1 - 1) / 2 - 1 / 1) / (2 - 1)) + + 2 * ((3 - (2 - 1) / 2 - 4 / 2) / (3 - 2)) + + 3 * ((5 - (3 - 1) / 2 - 9 / 3) / (5 - 3))) + / (1 + 2 + 3)) + self.assertEqual(get_metric_result(name, case=1), (3 - 0 - 3 / 1) / (3 - 1)) + + def test_collect(self): + config = Config('BPR', 'ml-100k', config_dict=parameters_dict) + metrics = ['GAUC'] + rank_evaluator = RankEvaluator(config, metrics) + self.assertEqual(get_collect_result(rank_evaluator, case=0).squeeze().cpu().numpy().tolist(), + np.array([0, (2 + 3) / 2 * 2, (1 + 2 + 3 + 4) / 4 * 3, 1 + (2 + 3) / 2 + 4 + 5]).tolist()) + + +if __name__ == "__main__": + unittest.main()