FIX: bugs in evaluator #590

Merged: 27 commits, Dec 24, 2020

Commits
0772227
FEA: add config['benchmark_filename'] to load pre-split dataset; incr…
chenyushuo Dec 18, 2020
3769b2d
FIX: Increased the robustness of GeneralFullDataLoader, which can han…
chenyushuo Dec 18, 2020
dd157a5
FIX: can't raise error in IndividualEvaluator
guijiql Dec 18, 2020
9ced718
FIX: metrics disorder
guijiql Dec 18, 2020
c7cbd34
FIX: GAUC calculation error
guijiql Dec 18, 2020
5c1c147
FIX: rename & comment format
guijiql Dec 18, 2020
2dcac28
REVERT: revert modify in data.utils
chenyushuo Dec 18, 2020
e8db062
Merge pull request #588 from chenyushuo/0.2.x
2017pxy Dec 18, 2020
fd86870
update notes
guijiql Dec 19, 2020
e84aeb7
FEA: Increased the robustness of trainer.evaluate
chenyushuo Dec 20, 2020
50bf9e8
FIX: bug fix in GeneralFullDataLoader.
chenyushuo Dec 20, 2020
27dad3f
Merge pull request #596 from chenyushuo/0.2.x
2017pxy Dec 20, 2020
4b4b9a8
FIX: optimize update_attentive_A function in KGAT
Dec 21, 2020
da3972c
Merge pull request #597 from ShanleiMu/0.2.x
chenyushuo Dec 21, 2020
83f514e
FIX: can't raise error in IndividualEvaluator
guijiql Dec 18, 2020
10243bb
FIX: metrics disorder
guijiql Dec 18, 2020
ab95863
FIX: GAUC calculation error
guijiql Dec 18, 2020
48f1078
FIX: rename & comment format
guijiql Dec 18, 2020
eb84ef3
FEA: add parameters check in gauc
guijiql Dec 21, 2020
aed05ee
FEA: add GAUC check & GAUC test
guijiql Dec 21, 2020
d6ea8e2
update metrics.py
guijiql Dec 22, 2020
02b3d1c
update metrics.py
guijiql Dec 22, 2020
3620841
Update evaluators.py
guijiql Dec 22, 2020
bd40a3a
FEA: add ranking metric test
guijiql Dec 22, 2020
b7c56bb
Merge branch '0.2.x' of https://github.com/guijiql/RecBole into 0.2.x
guijiql Dec 22, 2020
6e9a9c6
FIX: rename bool variable in GAUC & remove keys in build
guijiql Dec 22, 2020
634bad6
FEA: add RankEvaluator collect test
guijiql Dec 24, 2020
16 changes: 8 additions & 8 deletions recbole/data/dataloader/general_dataloader.py
@@ -220,12 +220,11 @@ def __init__(self, config, dataset, sampler, neg_sample_args,

dataset.sort(by=uid_field, ascending=True)
last_uid = None
positive_item = None
positive_item = set()
uid2used_item = sampler.used_ids
for uid, iid in zip(dataset.inter_feat[uid_field].numpy(), dataset.inter_feat[iid_field].numpy()):
if uid != last_uid:
if last_uid is not None:
self._set_user_property(last_uid, uid2used_item[last_uid], positive_item)
self._set_user_property(last_uid, uid2used_item[last_uid], positive_item)
last_uid = uid
self.uid_list.append(uid)
positive_item = set()
@@ -238,6 +237,8 @@ def __init__(self, config, dataset, sampler, neg_sample_args,
batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)

def _set_user_property(self, uid, used_item, positive_item):
if uid is None:
return
history_item = used_item - positive_item
positive_item_num = len(positive_item)
self.uid2items_num[uid] = positive_item_num
@@ -260,17 +261,16 @@ def _shuffle(self):
self.logger.warning('GeneralFullDataLoader can\'t shuffle')

def _next_batch_data(self):
index = slice(self.pr, self.pr + self.step)
user_df = self.user_df[index]
pos_len_list = self.uid2items_num[self.uid_list[index]]
user_len_list = np.full(len(user_df), self.item_num)
user_df.set_additional_info(pos_len_list, user_len_list)
user_df = self.user_df[self.pr: self.pr + self.step]
cur_data = self._neg_sampling(user_df)
self.pr += self.step
return cur_data

def _neg_sampling(self, user_df):
uid_list = list(user_df[self.dataset.uid_field])
pos_len_list = self.uid2items_num[uid_list]
user_len_list = np.full(len(uid_list), self.item_num)
user_df.set_additional_info(pos_len_list, user_len_list)

history_item = self.uid2history_item[uid_list]
history_row = torch.cat([torch.full_like(hist_iid, i) for i, hist_iid in enumerate(history_item)])
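A minimal, self-contained sketch of the patched bookkeeping above (the `positive_item = set()` initialization plus the early-return guard in `_set_user_property`); the function and variable names here are illustrative, not RecBole's actual API:

```python
import numpy as np

def build_user_property(uids, iids, used_ids):
    """Group interactions (pre-sorted by uid) into per-user positive/history sets."""
    uid2positive, uid2history = {}, {}

    def set_user_property(uid, used_item, positive_item):
        if uid is None:  # the new guard: skip the flush on the very first iteration
            return
        uid2positive[uid] = positive_item
        uid2history[uid] = used_item - positive_item  # history = used minus positive

    last_uid, positive_item = None, set()  # set() instead of None, as in the patch
    for uid, iid in zip(uids, iids):
        if uid != last_uid:
            # flush the previous user unconditionally; the guard above replaces
            # the old `if last_uid is not None:` nesting
            set_user_property(last_uid, used_ids.get(last_uid, set()), positive_item)
            last_uid = uid
            positive_item = set()
        positive_item.add(iid)
    set_user_property(last_uid, used_ids.get(last_uid, set()), positive_item)
    return uid2positive, uid2history

# toy check: user 1 has one extra "used" item (13) seen during training
uids, iids = np.array([1, 1, 2]), np.array([10, 11, 12])
pos, hist = build_user_property(uids, iids, {1: {10, 11, 13}, 2: {12}})
assert pos == {1: {10, 11}, 2: {12}} and hist == {1: {13}, 2: set()}
```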
5 changes: 5 additions & 0 deletions recbole/data/dataset/dataset.py
@@ -1340,6 +1340,11 @@ def build(self, eval_setting):
Returns:
list: List of built :class:`Dataset`.
"""
if self.benchmark_filename_list is not None:
cumsum = list(np.cumsum(self.file_size_list))
datasets = [self.copy(self.inter_feat[start: end]) for start, end in zip([0] + cumsum[:-1], cumsum)]
return datasets

ordering_args = eval_setting.ordering_args
if ordering_args['strategy'] == 'shuffle':
self.shuffle()
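The new `benchmark_filename` branch in `build` is easiest to see with numbers; a sketch of the cumsum slicing, with a hypothetical `file_size_list`:

```python
import numpy as np

# row counts of the user-provided pre-split files, e.g. train/valid/test
file_size_list = [8, 1, 1]
cumsum = list(np.cumsum(file_size_list))          # [8, 9, 10]
bounds = list(zip([0] + cumsum[:-1], cumsum))     # half-open [start, end) ranges
assert bounds == [(0, 8), (8, 9), (9, 10)]
# each (start, end) pair then becomes self.copy(self.inter_feat[start:end])
```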
2 changes: 1 addition & 1 deletion recbole/data/utils.py
@@ -121,7 +121,7 @@ def data_preparation(config, dataset, save=False):
getattr(es, es_str[1])()
if 'sampler' not in locals():
sampler = Sampler(phases, builded_datasets, es.neg_sample_args['distribution'])
sampler.set_distribution(es.neg_sample_args['distribution'])
sampler.set_distribution(es.neg_sample_args['distribution'])
kwargs['sampler'] = [sampler.set_phase('valid'), sampler.set_phase('test')]
kwargs['neg_sample_args'] = copy.deepcopy(es.neg_sample_args)
valid_data, test_data = dataloader_construct(
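The dedent matters because `data_preparation` may reuse a sampler built for an earlier phase; a sketch of the control flow with a stand-in `Sampler`:

```python
class SamplerSketch:  # stand-in for recbole.sampler.Sampler
    def __init__(self, distribution):
        self.distribution = distribution

    def set_distribution(self, distribution):
        self.distribution = distribution
        return self

sampler = None
for distribution in ['uniform', 'popularity']:  # e.g. train then eval settings
    if sampler is None:                         # mirrors `if 'sampler' not in locals()`
        sampler = SamplerSketch(distribution)
    sampler.set_distribution(distribution)      # now runs for a reused sampler too

assert sampler.distribution == 'popularity'     # before the fix it stayed 'uniform'
```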
8 changes: 6 additions & 2 deletions recbole/evaluator/abstract_evaluator.py
@@ -4,7 +4,7 @@
# @email : tsotfsk@outlook.com

# UPDATE
# @Time : 2020/10/21, 2020/12/9
# @Time : 2020/10/21, 2020/12/18
# @Author : Kaiyuan Li, Zhichao Feng
# @email : tsotfsk@outlook.com, fzcbupt@gmail.com

@@ -99,7 +99,7 @@ class IndividualEvaluator(BaseEvaluator):
"""
def __init__(self, config, metrics):
super().__init__(config, metrics)
pass
self._check_args()

def sample_collect(self, true_scores, pred_scores):
"""It is called when evaluation sample distribution is `uniform` or `popularity`.
Expand Down Expand Up @@ -127,3 +127,7 @@ def get_score_matrix(self, true_scores, pred_scores):
scores_matrix = self.sample_collect(true_scores, pred_scores)

return scores_matrix

def _check_args(self):
if self.full:
raise NotImplementedError('full sort can\'t use IndividualEvaluator')
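A minimal sketch of what the new `_check_args` buys: metrics evaluated per interaction cannot be computed under full sort, so the evaluator now fails fast at construction instead of producing meaningless scores later (class and flag names simplified here):

```python
class IndividualEvaluatorSketch:
    def __init__(self, full_sort: bool):
        self.full = full_sort
        self._check_args()  # runs at init, as in the patch

    def _check_args(self):
        if self.full:
            raise NotImplementedError("full sort can't use IndividualEvaluator")

IndividualEvaluatorSketch(full_sort=False)      # fine
try:
    IndividualEvaluatorSketch(full_sort=True)   # raises immediately
except NotImplementedError as err:
    print(err)
```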
59 changes: 42 additions & 17 deletions recbole/evaluator/evaluators.py
@@ -4,7 +4,7 @@
# @email : tsotfsk@outlook.com

# UPDATE
# @Time : 2020/08/04, 2020/08/11, 2020/12/9
# @Time : 2020/08/04, 2020/08/11, 2020/12/18
# @Author : Kaiyuan Li, Yupeng Hou, Zhichao Feng
# @email : tsotfsk@outlook.com, houyupeng@ruc.edu.cn, fzcbupt@gmail.com

@@ -158,22 +158,40 @@ def get_user_pos_len_list(self, interaction, scores_tensor):
user_len_list = interaction.user_len_list
return pos_len_list, user_len_list

def get_pos_index(self, scores_tensor, pos_len_list, user_len_list):
"""get the index of positive items
def average_rank(self, scores):
"""Get the ranking of an ordered tensor, and take the average of the ranking for positions with equal values.

Args:
scores_tensor (tensor): the tensor of model output with size of `(N, )`
pos_len_list(list): number of positive items
user_len_list(list): number of all items
Args:
scores(tensor): an ordered tensor, with size of `(N, )`

Returns:
torch.Tensor: average_rank

Returns:
tensor: a matrix indicating whether the corresponding item is positive
Example:
>>> average_rank(tensor([[1,2,2,2,3,3,6],[2,2,2,2,4,4,5]]))
tensor([[1.0000, 3.0000, 3.0000, 3.0000, 5.5000, 5.5000, 7.0000],
[2.5000, 2.5000, 2.5000, 2.5000, 5.5000, 5.5000, 7.0000]])

"""
scores_matrix = self.get_score_matrix(scores_tensor, user_len_list)
_, n_index = torch.sort(scores_matrix, dim=-1, descending=True)
pos_index = (n_index < pos_len_list.reshape(-1, 1))
return pos_index
Reference:
https://github.com/scipy/scipy/blob/v0.17.1/scipy/stats/stats.py#L5262-L5352

"""
length, width = scores.shape
device = scores.device
true_tensor = torch.full((length, 1), True, dtype=torch.bool, device=device)

obs = torch.cat([true_tensor, scores[:, 1:] != scores[:, :-1]], dim=1)
# bias added to dense
bias = torch.arange(0, length, device=device).repeat(width).reshape(width, -1). \
transpose(1, 0).reshape(-1)
dense = obs.view(-1).cumsum(0) + bias

# cumulative counts of each unique value
count = torch.where(torch.cat([obs, true_tensor], dim=1))[1]
# get average rank
avg_rank = .5 * (count[dense] + count[dense - 1] + 1).view(length, -1)

return avg_rank

def collect(self, interaction, scores_tensor):
"""collect the rank intermediate result of one batch, this function mainly implements ranking
@@ -185,10 +203,16 @@ def collect(self, interaction, scores_tensor):

"""
pos_len_list, user_len_list = self.get_user_pos_len_list(interaction, scores_tensor)
pos_index = self.get_pos_index(scores_tensor, pos_len_list, user_len_list)
index_list = torch.arange(1, pos_index.shape[1] + 1).to(pos_index.device)
pos_rank_sum = torch.where(pos_index, index_list, torch.zeros_like(index_list)). \
scores_matrix = self.get_score_matrix(scores_tensor, user_len_list)
desc_scores, desc_index = torch.sort(scores_matrix, dim=-1, descending=True)

# get the index of positive items in the ranking list
pos_index = (desc_index < pos_len_list.reshape(-1, 1))

avg_rank = self.average_rank(desc_scores)
pos_rank_sum = torch.where(pos_index, avg_rank, torch.zeros_like(avg_rank)). \
sum(axis=-1).reshape(-1, 1)

return pos_rank_sum

def evaluate(self, batch_matrix_list, eval_data):
@@ -205,6 +229,7 @@ def evaluate(self, batch_matrix_list, eval_data):
pos_len_list = eval_data.get_pos_len_list()
user_len_list = eval_data.get_user_len_list()
pos_rank_sum = torch.cat(batch_matrix_list, dim=0).cpu().numpy()
assert len(pos_len_list) == len(pos_rank_sum)

# get metrics
metric_dict = {}
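The `average_rank` trick above (a cumsum over change points plus a per-row bias) is compact but subtle; a hedged cross-check against SciPy's tie-aware ranking, copying the PR's logic into a free function:

```python
import numpy as np
import torch
from scipy.stats import rankdata

def average_rank(scores):
    # verbatim logic from RankEvaluator.average_rank; rows must be sorted
    length, width = scores.shape
    device = scores.device
    true_tensor = torch.full((length, 1), True, dtype=torch.bool, device=device)
    obs = torch.cat([true_tensor, scores[:, 1:] != scores[:, :-1]], dim=1)
    bias = torch.arange(0, length, device=device).repeat(width).reshape(width, -1) \
        .transpose(1, 0).reshape(-1)
    dense = obs.view(-1).cumsum(0) + bias
    count = torch.where(torch.cat([obs, true_tensor], dim=1))[1]
    return .5 * (count[dense] + count[dense - 1] + 1).view(length, -1)

scores = torch.tensor([[1., 2., 2., 2., 3., 3., 6.],
                       [2., 2., 2., 2., 4., 4., 5.]])
# rankdata(method='average') gives tied values the mean of their ordinal ranks
expected = np.stack([rankdata(row, method='average') for row in scores.numpy()])
assert torch.allclose(average_rank(scores), torch.from_numpy(expected).float())
```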
57 changes: 49 additions & 8 deletions recbole/evaluator/metrics.py
@@ -4,7 +4,7 @@
# @email : tsotfsk@outlook.com

# UPDATE
# @Time : 2020/08/12, 2020/12/9, 2020/9/16
# @Time : 2020/08/12, 2020/12/21, 2020/9/16
# @Author : Kaiyuan Li, Zhichao Feng, Xingyu Pan
# @email : tsotfsk@outlook.com, fzcbupt@gmail.com, panxy@ruc.edu.cn

@@ -23,7 +23,6 @@

# TopK Metrics #


def hit_(pos_index, pos_len):
r"""Hit_ (also known as hit ratio at :math:`N`) is a way of calculating how many 'hits' you have
in an n-sized list of ranked items.
@@ -129,7 +128,6 @@ def ndcg_(pos_index, pos_len):
:math:`U^{te}` is for all users in the test set.

"""

len_rank = np.full_like(pos_len, pos_index.shape[1])
idcg_len = np.where(pos_len > len_rank, len_rank, pos_len)

@@ -166,11 +164,54 @@ def precision_(pos_index, pos_len):


def gauc_(user_len_list, pos_len_list, pos_rank_sum):
frac = user_len_list - (pos_len_list - 1) / 2 - (1 / pos_len_list) * np.squeeze(pos_rank_sum)
neg_item_num = user_len_list - pos_len_list
user_auc = frac / neg_item_num
r"""GAUC_ (also known as Group Area Under Curve) is used to evaluate the two-class model, referring to
the area under the ROC curve grouped by user.

.. _GAUC: https://dl.acm.org/doi/10.1145/3219819.3219823

Note:
It calculates the AUC score of each user, and finally obtains GAUC by weighting the user AUC.
It is also not limited to k. Due to our padding for `scores_tensor` in `RankEvaluator` with
`-np.inf`, the padding value will influence the ranks of origin items. Therefore, we use
descending sort here and make an identity transformation to the formula of `AUC`, which is
shown in `auc_` function. For readability, we didn't do simplification in the code.

.. math::
\mathrm {GAUC} = \frac {{{M} \times {(M+N+1)} - \frac{M \times (M+1)}{2}} -
\sum\limits_{i=1}^M rank_{i}} {{M} \times {N}}

:math:`M` is the number of positive samples.
:math:`N` is the number of negative samples.
:math:`rank_i` is the descending rank of the ith positive sample.

"""
neg_len_list = user_len_list - pos_len_list

# check positive and negative samples
any_without_pos = np.any(pos_len_list == 0)
any_without_neg = np.any(neg_len_list == 0)
non_zero_idx = np.full(len(user_len_list), True, dtype=bool)
if any_without_pos:
logger = getLogger()
logger.warning("No positive samples in some users, "
"true positive value should be meaningless, "
"these users have been removed from GAUC calculation")
non_zero_idx *= (pos_len_list != 0)
if any_without_neg:
logger = getLogger()
logger.warning("No negative samples in some users, "
"false positive value should be meaningless, "
"these users have been removed from GAUC calculation")
non_zero_idx *= (neg_len_list != 0)
if any_without_pos or any_without_neg:
item_list = user_len_list, neg_len_list, pos_len_list, pos_rank_sum
user_len_list, neg_len_list, pos_len_list, pos_rank_sum = \
map(lambda x: x[non_zero_idx], item_list)

pair_num = (user_len_list + 1) * pos_len_list - pos_len_list * (pos_len_list + 1) / 2 - np.squeeze(pos_rank_sum)
user_auc = pair_num / (neg_len_list * pos_len_list)
result = (user_auc * pos_len_list).sum() / pos_len_list.sum()

return result


@@ -188,11 +229,11 @@ def auc_(trues, preds):

.. math::
\mathrm {AUC} = \frac{\sum\limits_{i=1}^M rank_{i}
- {{M} \times {(M+1)}}} {{M} \times {N}}
- \frac {{M} \times {(M+1)}}{2}} {{{M} \times {N}}}

:math:`M` is the number of positive samples.
:math:`N` is the number of negative samples.
:math:`rank_i` is the rank of the ith positive sample.
:math:`rank_i` is the ascending rank of the ith positive sample.

"""
fps, tps = _binary_clf_curve(trues, preds)
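A small numeric sanity check of the new `gauc_` arithmetic, under the same input conventions (descending, tie-averaged ranks summed per user); the sanity filters are omitted here:

```python
import numpy as np

def gauc(user_len_list, pos_len_list, pos_rank_sum):
    # identical arithmetic to the patched gauc_, without the zero-sample filters
    pair_num = (user_len_list + 1) * pos_len_list \
        - pos_len_list * (pos_len_list + 1) / 2 - np.squeeze(pos_rank_sum)
    user_auc = pair_num / ((user_len_list - pos_len_list) * pos_len_list)
    return (user_auc * pos_len_list).sum() / pos_len_list.sum()

# one user, 1 positive among 2 items; positive ranked first (descending rank 1):
# pair_num = (2+1)*1 - 1 - 1 = 1, so AUC = 1/(1*1) = 1.0
assert gauc(np.array([2]), np.array([1]), np.array([1])) == 1.0
# positive ranked last (descending rank 2): pair_num = 3 - 1 - 2 = 0, so AUC = 0.0
assert gauc(np.array([2]), np.array([1]), np.array([2])) == 0.0
```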
4 changes: 2 additions & 2 deletions recbole/evaluator/proxy_evaluator.py
@@ -32,9 +32,9 @@ def build(self):
"""
evaluator_list = []
metrics_set = {metric.lower() for metric in self.metrics}
metrics_list = [metric.lower() for metric in self.metrics]
for metrics, evaluator in metric_eval_bind:
used_metrics = list(metrics_set.intersection(set(metrics.keys())))
used_metrics = [metric for metric in metrics_list if metric in metrics]
if used_metrics:
evaluator_list.append(evaluator(self.config, used_metrics))
return evaluator_list
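This one-line change fixes the "metrics disorder" bug: a set intersection returns metrics in arbitrary hash order, while the list comprehension preserves the order the user configured. A quick illustration with stand-in names:

```python
supported = {'recall': None, 'ndcg': None, 'gauc': None}  # stand-in for metric_eval_bind keys
requested = ['NDCG', 'GAUC', 'Recall']                    # user-configured order

metrics_set = {m.lower() for m in requested}
arbitrary = list(metrics_set.intersection(set(supported)))  # order not guaranteed

metrics_list = [m.lower() for m in requested]
stable = [m for m in metrics_list if m in supported]
assert stable == ['ndcg', 'gauc', 'recall']                 # always follows `requested`
```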
2 changes: 1 addition & 1 deletion recbole/model/knowledge_aware_recommender/kgat.py
@@ -268,7 +268,7 @@ def update_attentive_A(self):
# Current PyTorch version does not support softmax on SparseCUDA, temporarily move to CPU to calculate softmax
A_in = torch.sparse.FloatTensor(indices, kg_score, self.matrix_size).cpu()
A_in = torch.sparse.softmax(A_in, dim=1).to(self.device)
self.A_in = copy.copy(A_in)
self.A_in = A_in

def predict(self, interaction):
user = interaction[self.USER_ID]
7 changes: 6 additions & 1 deletion recbole/trainer/trainer.py
@@ -324,6 +324,9 @@ def evaluate(self, eval_data, load_best_model=True, model_file=None):
Returns:
dict: eval result, key is the eval metric and value is the corresponding metric value
"""
if not eval_data:
return

if load_best_model:
if model_file:
checkpoint_file = model_file
@@ -441,7 +444,9 @@ def _train_epoch(self, train_data, epoch_idx, loss_func=None):
kg_total_loss = super()._train_epoch(train_data, epoch_idx, self.model.calculate_kg_loss)

# update A
self.model.update_attentive_A()
self.model.eval()
with torch.no_grad():
self.model.update_attentive_A()

return rs_total_loss, kg_total_loss

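The wrapper around `update_attentive_A` reflects a general pattern: side-effect updates to model buffers should run outside autograd and in eval mode. A sketch under those assumptions (the helper name here is ours, not RecBole's):

```python
import torch

def run_buffer_update(model: torch.nn.Module, update_fn) -> None:
    """Run update_fn without tracking gradients or train-mode randomness."""
    was_training = model.training
    model.eval()              # e.g. disable dropout during the update
    with torch.no_grad():     # keep the recomputed matrix out of the autograd graph
        update_fn()
    if was_training:
        model.train()         # restore train mode for the next epoch

# usage inside KGATTrainer._train_epoch would look like:
# run_buffer_update(self.model, self.model.update_attentive_A)
```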
74 changes: 74 additions & 0 deletions tests/metrics/test_rank_metrics.py
@@ -0,0 +1,74 @@
# -*- encoding: utf-8 -*-
# @Time : 2020/12/21
# @Author : Zhichao Feng
# @email : fzcbupt@gmail.com


import os
import sys
import unittest

sys.path.append(os.getcwd())
import numpy as np
import torch
from recbole.config import Config
from recbole.data.interaction import Interaction
from recbole.evaluator.metrics import metrics_dict
from recbole.evaluator.evaluators import RankEvaluator

parameters_dict = {
'model': 'BPR',
'eval_setting': 'RO_RS,uni100',
}


class MetricsTestCases(object):
user_len_list0 = np.array([2, 3, 5])
pos_len_list0 = np.array([1, 2, 3])
pos_rank_sum0 = np.array([1, 4, 9])

user_len_list1 = np.array([3, 6, 4])
pos_len_list1 = np.array([1, 0, 4])
pos_rank_sum1 = np.array([3, 0, 6])


class CollectTestCases(object):
interaction0 = Interaction({}, [0, 2, 3, 4], [2, 3, 4, 5])
scores_tensor0 = torch.Tensor([0.1, 0.2,
0.1, 0.1, 0.2,
0.2, 0.2, 0.2, 0.2,
0.3, 0.2, 0.1, 0.4, 0.3])


def get_metric_result(name, case=0):
func = metrics_dict[name]
return func(getattr(MetricsTestCases, f'user_len_list{case}'),
getattr(MetricsTestCases, f'pos_len_list{case}'),
getattr(MetricsTestCases, f'pos_rank_sum{case}'))


def get_collect_result(evaluator, case=0):
func = evaluator.collect
return func(getattr(CollectTestCases, f'interaction{case}'),
getattr(CollectTestCases, f'scores_tensor{case}'))


class TestRankMetrics(unittest.TestCase):
def test_gauc(self):
name = 'gauc'
self.assertEqual(get_metric_result(name, case=0), (1 * ((2 - (1 - 1) / 2 - 1 / 1) / (2 - 1)) +
2 * ((3 - (2 - 1) / 2 - 4 / 2) / (3 - 2)) +
3 * ((5 - (3 - 1) / 2 - 9 / 3) / (5 - 3)))
/ (1 + 2 + 3))
self.assertEqual(get_metric_result(name, case=1), (3 - 0 - 3 / 1) / (3 - 1))

def test_collect(self):
config = Config('BPR', 'ml-100k', config_dict=parameters_dict)
metrics = ['GAUC']
rank_evaluator = RankEvaluator(config, metrics)
self.assertEqual(get_collect_result(rank_evaluator, case=0).squeeze().cpu().numpy().tolist(),
np.array([0, (2 + 3) / 2 * 2, (1 + 2 + 3 + 4) / 4 * 3, 1 + (2 + 3) / 2 + 4 + 5]).tolist())


if __name__ == "__main__":
unittest.main()