diff --git a/recbole/data/dataloader/abstract_dataloader.py b/recbole/data/dataloader/abstract_dataloader.py index 34e5788d5..d2ee8ee4c 100644 --- a/recbole/data/dataloader/abstract_dataloader.py +++ b/recbole/data/dataloader/abstract_dataloader.py @@ -117,17 +117,18 @@ def _set_neg_sample_args(self, config, dataset, dl_format, neg_sample_args): self.iid_field = dataset.iid_field self.dl_format = dl_format self.neg_sample_args = neg_sample_args + self.times = 1 if self.neg_sample_args['strategy'] == 'by': - self.neg_sample_by = self.neg_sample_args['by'] + self.neg_sample_num = self.neg_sample_args['by'] if self.dl_format == InputType.POINTWISE: - self.times = 1 + self.neg_sample_by + self.times = 1 + self.neg_sample_num self.sampling_func = self._neg_sample_by_point_wise_sampling self.label_field = config['LABEL_FIELD'] dataset.set_field_property(self.label_field, FeatureType.FLOAT, FeatureSource.INTERACTION, 1) elif self.dl_format == InputType.PAIRWISE: - self.times = self.neg_sample_by + self.times = self.neg_sample_num self.sampling_func = self._neg_sample_by_pair_wise_sampling self.neg_prefix = config['NEG_PREFIX'] @@ -147,7 +148,7 @@ def _neg_sampling(self, inter_feat): if self.neg_sample_args['strategy'] == 'by': user_ids = inter_feat[self.uid_field] item_ids = inter_feat[self.iid_field] - neg_item_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_by) + neg_item_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_num) return self.sampling_func(inter_feat, neg_item_ids) else: return inter_feat diff --git a/recbole/data/dataloader/general_dataloader.py b/recbole/data/dataloader/general_dataloader.py index 7e605dba0..04d5e3c49 100644 --- a/recbole/data/dataloader/general_dataloader.py +++ b/recbole/data/dataloader/general_dataloader.py @@ -120,30 +120,24 @@ def _shuffle(self): def _next_batch_data(self): uid_list = self.uid_list[self.pr:self.pr + self.step] data_list = [] - for uid in uid_list: + idx_list = [] + positive_u = [] + positive_i = torch.tensor([], dtype=torch.int64) + + for idx, uid in enumerate(uid_list): index = self.uid2index[uid] data_list.append(self._neg_sampling(self.dataset[index])) - cur_data = cat_interactions(data_list) - if self.neg_sample_args['strategy'] == 'by': - pos_len_list = self.uid2items_num[uid_list] - user_len_list = pos_len_list * self.times - cur_data.set_additional_info(list(pos_len_list), list(user_len_list)) - self.pr += self.step - return cur_data + idx_list += [idx for i in range(self.uid2items_num[uid] * self.times)] + positive_u += [idx for i in range(self.uid2items_num[uid])] + positive_i = torch.cat((positive_i, self.dataset[index][self.iid_field]), 0) - def get_pos_len_list(self): - """ - Returns: - numpy.ndarray: Number of positive item for each user in a training/evaluating epoch. - """ - return self.uid2items_num[self.uid_list] + cur_data = cat_interactions(data_list) + idx_list = torch.from_numpy(np.array(idx_list)) + positive_u = torch.from_numpy(np.array(positive_u)) - def get_user_len_list(self): - """ - Returns: - numpy.ndarray: Number of all item for each user in a training/evaluating epoch. - """ - return self.uid2items_num[self.uid_list] * self.times + self.pr += self.step + + return cur_data, idx_list, positive_u, positive_i class FullSortEvalDataLoader(AbstractDataLoader): @@ -167,8 +161,7 @@ def __init__(self, config, dataset, sampler, shuffle=False): user_num = dataset.user_num self.uid_list = [] self.uid2items_num = np.zeros(user_num, dtype=np.int64) - self.uid2swap_idx = np.array([None] * user_num) - self.uid2rev_swap_idx = np.array([None] * user_num) + self.uid2positive_item = np.array([None] * user_num) self.uid2history_item = np.array([None] * user_num) dataset.sort(by=self.uid_field, ascending=True) @@ -192,11 +185,8 @@ def _set_user_property(self, uid, used_item, positive_item): if uid is None: return history_item = used_item - positive_item - positive_item_num = len(positive_item) - self.uid2items_num[uid] = positive_item_num - swap_idx = torch.tensor(sorted(set(range(positive_item_num)) ^ positive_item)) - self.uid2swap_idx[uid] = swap_idx - self.uid2rev_swap_idx[uid] = swap_idx.flip(0) + self.uid2positive_item[uid] = torch.tensor(list(positive_item), dtype=torch.int64) + self.uid2items_num[uid] = len(positive_item) self.uid2history_item[uid] = torch.tensor(list(history_item), dtype=torch.int64) def _init_batch_size_and_step(self): @@ -223,51 +213,24 @@ def _shuffle(self): def _next_batch_data(self): if not self.is_sequential: user_df = self.user_df[self.pr:self.pr + self.step] - uid_list = list(user_df[self.dataset.uid_field]) - pos_len_list = self.uid2items_num[uid_list] - user_len_list = np.full(len(uid_list), self.dataset.item_num) - user_df.set_additional_info(pos_len_list, user_len_list) - + uid_list = list(user_df[self.uid_field]) + history_item = self.uid2history_item[uid_list] - history_row = torch.cat([torch.full_like(hist_iid, i) for i, hist_iid in enumerate(history_item)]) - history_col = torch.cat(list(history_item)) + positive_item = self.uid2positive_item[uid_list] + + history_u = torch.cat([torch.full_like(hist_iid, i) for i, hist_iid in enumerate(history_item)]) + history_i = torch.cat(list(history_item)) - swap_idx = self.uid2swap_idx[uid_list] - rev_swap_idx = self.uid2rev_swap_idx[uid_list] - swap_row = torch.cat([torch.full_like(swap, i) for i, swap in enumerate(swap_idx)]) - swap_col_after = torch.cat(list(swap_idx)) - swap_col_before = torch.cat(list(rev_swap_idx)) + positive_u = torch.cat([torch.full_like(pos_iid, i) for i, pos_iid in enumerate(positive_item)]) + positive_i = torch.cat(list(positive_item)) self.pr += self.step - return user_df, (history_row, history_col), swap_row, swap_col_after, swap_col_before + return user_df, (history_u, history_i), positive_u, positive_i else: interaction = self.dataset[self.pr:self.pr + self.step] inter_num = len(interaction) - pos_len_list = np.ones(inter_num, dtype=np.int64) - user_len_list = np.full(inter_num, self.dataset.item_num) - interaction.set_additional_info(pos_len_list, user_len_list) - scores_row = torch.arange(inter_num).repeat(2) - padding_idx = torch.zeros(inter_num, dtype=torch.int64) - positive_idx = interaction[self.iid_field] - scores_col_after = torch.cat((padding_idx, positive_idx)) - scores_col_before = torch.cat((positive_idx, padding_idx)) + positive_u = torch.arange(inter_num) + positive_i = interaction[self.iid_field] self.pr += self.step - return interaction, None, scores_row, scores_col_after, scores_col_before - - def get_pos_len_list(self): - """ - Returns: - numpy.ndarray: Number of positive item for each user in a training/evaluating epoch. - """ - if not self.is_sequential: - return self.uid2items_num[self.uid_list] - else: - return np.ones(self.pr_end, dtype=np.int64) - - def get_user_len_list(self): - """ - Returns: - numpy.ndarray: Number of all item for each user in a training/evaluating epoch. - """ - return np.full(self.pr_end, self.dataset.item_num) + return interaction, None, positive_u, positive_i diff --git a/recbole/evaluator/abstract_evaluator.py b/recbole/evaluator/abstract_evaluator.py deleted file mode 100644 index f49ef7b49..000000000 --- a/recbole/evaluator/abstract_evaluator.py +++ /dev/null @@ -1,136 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Time : 2020/10/21 -# @Author : Kaiyuan Li -# @email : tsotfsk@outlook.com - -# UPDATE -# @Time : 2020/10/21, 2020/12/18, 2021/7/1 -# @Author : Kaiyuan Li, Zhichao Feng, Xingyu Pan -# @email : tsotfsk@outlook.com, fzcbupt@gmail.com, xy_oan@foxmail.com - -""" -recbole.evaluator.abstract_evaluator -##################################### -""" - -import numpy as np -import torch -from torch.nn.utils.rnn import pad_sequence - - -class BaseEvaluator(object): - """:class:`BaseEvaluator` is an object which supports - the evaluation of the model. It is called by :class:`Trainer`. - - Note: - If you want to inherit this class and implement your own evaluator class, - you must implement the following functions. - - Args: - config (Config): The config of evaluator. - - """ - - def __init__(self, config, metrics): - self.metrics = metrics - self.full = ('full' in config['eval_args']['mode']) - self.precision = config['metric_decimal_place'] - - def collect(self, *args): - """get the intermediate results for each batch, it is called at the end of each batch""" - raise NotImplementedError - - def evaluate(self, *args): - """calculate the metrics of all batches, it is called at the end of each epoch""" - raise NotImplementedError - - def _calculate_metrics(self, *args): - """ to calculate the metrics""" - raise NotImplementedError - - -class GroupedEvaluator(BaseEvaluator): - """:class:`GroupedEvaluator` is an object which supports the evaluation of the model. - - Note: - If you want to implement a new group-based metric, - you may need to inherit this class - - """ - - def __init__(self, config, metrics): - super().__init__(config, metrics) - pass - - def sample_collect(self, scores_tensor, user_len_list): - """padding scores_tensor. It is called when evaluation sample distribution is `uniform` or `popularity`. - - """ - scores_list = torch.split(scores_tensor, user_len_list, dim=0) - padding_score = pad_sequence(scores_list, batch_first=True, padding_value=-np.inf) # n_users x items - return padding_score - - def full_sort_collect(self, scores_tensor, user_len_list): - """it is called when evaluation sample distribution is `full`. - - """ - return scores_tensor.view(len(user_len_list), -1) - - def get_score_matrix(self, scores_tensor, user_len_list): - """get score matrix. - - Args: - scores_tensor (tensor): the tensor of model output with size of `(N, )` - user_len_list(list): number of all items - - """ - if self.full: - scores_matrix = self.full_sort_collect(scores_tensor, user_len_list) - else: - scores_matrix = self.sample_collect(scores_tensor, user_len_list) - return scores_matrix - - -class IndividualEvaluator(BaseEvaluator): - """:class:`IndividualEvaluator` is an object which supports the evaluation of the model. - - Note: - If you want to implement a new non-group-based metric, - you may need to inherit this class - - """ - - def __init__(self, config, metrics): - super().__init__(config, metrics) - self._check_args() - - def sample_collect(self, true_scores, pred_scores): - """It is called when evaluation sample distribution is `uniform` or `popularity`. - - """ - return torch.stack((true_scores, pred_scores.detach()), dim=1) - - def full_sort_collect(self, true_scores, pred_scores): - """it is called when evaluation sample distribution is `full`. - - """ - raise NotImplementedError('full sort can\'t use IndividualEvaluator') - - def get_score_matrix(self, true_scores, pred_scores): - """get score matrix - - Args: - true_scores (tensor): the label of predicted items - pred_scores (tensor): the tensor of model output with a size of `(N, )` - - """ - if self.full: - scores_matrix = self.full_sort_collect(true_scores, pred_scores) - else: - scores_matrix = self.sample_collect(true_scores, pred_scores) - - return scores_matrix - - def _check_args(self): - if self.full: - raise NotImplementedError('full sort can\'t use IndividualEvaluator') diff --git a/recbole/evaluator/base_metric.py b/recbole/evaluator/base_metric.py index f5e2369e1..2c5e5260b 100644 --- a/recbole/evaluator/base_metric.py +++ b/recbole/evaluator/base_metric.py @@ -1,12 +1,11 @@ -# -*- encoding: utf-8 -*- -# @Time : 2020/10/21 -# @Author : Kaiyuan Li -# @email : tsotfsk@outlook.com +# @Time : 2020/10/21 +# @Author : Kaiyuan Li +# @email : tsotfsk@outlook.com # UPDATE -# @Time : 2020/10/21, 2021/6/25 -# @Author : Kaiyuan Li, Zhichao Feng -# @email : tsotfsk@outlook.com, fzcbupt@gmail.com +# @Time : 2020/10/21, 2021/7/18 +# @Author : Kaiyuan Li, Zhichao Feng +# @email : tsotfsk@outlook.com, fzcbupt@gmail.com """ recbole.evaluator.abstract_metric @@ -26,16 +25,13 @@ class TopkMetric(object): """ def __init__(self, config): self.topk = config['topk'] - self.indice = [max(self.topk), 1, 1] self.decimal_place = config['metric_decimal_place'] def used_info(self, dataobject): """get the bool matrix indicating whether the corresponding item is positive""" rec_mat = dataobject.get('rec.topk') - topk_idx, shapes, pos_len_list = torch.split(rec_mat, self.indice, dim=1) - pos_idx_matrix = (topk_idx >= (shapes - pos_len_list).reshape(-1, 1)) - - return pos_idx_matrix.numpy(), pos_len_list.squeeze().numpy() + topk_idx, pos_len_list = torch.split(rec_mat, [max(self.topk), 1], dim=1) + return rec_mat.to(torch.bool).numpy(), pos_len_list.squeeze().numpy() def topk_result(self, metric, value): """match the metric value to the `k` and put them in `dictionary` form""" diff --git a/recbole/evaluator/collector.py b/recbole/evaluator/collector.py index 075e7a6c8..44b7bf186 100644 --- a/recbole/evaluator/collector.py +++ b/recbole/evaluator/collector.py @@ -2,6 +2,11 @@ # @Author : Zihan Lin # @Email : zhlin@ruc.edu.cn +# UPDATE +# @Time : 2021/7/18 +# @Author : Zhichao Feng +# @email : fzcbupt@gmail.com + """ recbole.evaluator.collector ################################################ @@ -9,8 +14,6 @@ from recbole.evaluator.register import Register import torch -from torch.nn.utils.rnn import pad_sequence -import numpy as np import copy class DataStruct(object): @@ -67,7 +70,7 @@ def __init__(self, config): self.register = Register(config) self.full = ('full' in config['eval_args']['mode']) self.topk = self.config['topk'] - self.topk_idx = None + self.device = self.config['device'] def data_collect(self, train_data): """ Collect the evaluation resource from training data. @@ -86,24 +89,6 @@ def data_collect(self, train_data): if self.register.need('data.count_users'): self.data_struct.set('data.count_items', train_data.dataset.user_counter) - def _get_score_matrix(self, scores_tensor, user_len_list): - """get score matrix. - - Args: - scores_tensor (tensor): the tensor of model output with size of `(N, )` - user_len_list(list): number of all items - - """ - if self.full: - scores_matrix = scores_tensor.reshape(len(user_len_list), -1) - else: - scores_list = torch.split(scores_tensor, user_len_list, dim=0) - if scores_tensor.dtype is torch.FloatTensor: - scores_matrix = pad_sequence(scores_list, batch_first=True, padding_value=-np.inf) # n_users x items - else: # padding the id tensor - scores_matrix = pad_sequence(scores_list, batch_first=True, padding_value=-1) # n_users x items - return scores_matrix - def _average_rank(self, scores): """Get the ranking of an ordered tensor, and take the average of the ranking for positions with equal values. @@ -123,12 +108,11 @@ def _average_rank(self, scores): """ length, width = scores.shape - device = scores.device - true_tensor = torch.full((length, 1), True, dtype=torch.bool, device=device) + true_tensor = torch.full((length, 1), True, dtype=torch.bool, device=self.device) obs = torch.cat([true_tensor, scores[:, 1:] != scores[:, :-1]], dim=1) # bias added to dense - bias = torch.arange(0, length, device=device).repeat(width).reshape(width, -1). \ + bias = torch.arange(0, length, device=self.device).repeat(width).reshape(width, -1). \ transpose(1, 0).reshape(-1) dense = obs.view(-1).cumsum(0) + bias @@ -139,50 +123,45 @@ def _average_rank(self, scores): return avg_rank - def eval_batch_collect(self, scores_tensor: torch.Tensor, interaction): + def eval_batch_collect(self, scores_tensor: torch.Tensor, interaction, positive_u: torch.Tensor, positive_i: torch.Tensor): """ Collect the evaluation resource from batched eval data and batched model output. Args: scores_tensor (Torch.Tensor): the output tensor of model with the shape of `(N, )` interaction(Interaction): batched eval data. + positive_u(Torch.Tensor): the row index of positive items for each user. + positive_i(Torch.Tensor): the positive item id for each user. """ - if self.register.need('rec.topk'): - - user_len_list = interaction.user_len_list - pos_len_list = interaction.pos_len_list - - scores_matrix = self._get_score_matrix(scores_tensor, user_len_list) - scores_matrix = torch.flip(scores_matrix, dims=[-1]) - shape_matrix = torch.full((len(user_len_list), 1), scores_matrix.shape[1], device=scores_matrix.device) + if self.register.need('rec.items'): - pos_len_matrix = torch.from_numpy(np.array(pos_len_list)).view(-1, 1).to(scores_matrix.device) + # get topk + _, topk_idx = torch.topk(scores_tensor, max(self.topk), dim=-1) # n_users x k + self.data_struct.update_tensor('rec.items', topk_idx) - assert pos_len_matrix.shape[0] == shape_matrix.shape[0] + if self.register.need('rec.topk'): - # get topk - _, topk_idx = torch.topk(scores_matrix, max(self.topk), dim=-1) # n_users x k - self.topk_idx = topk_idx - # pack top_idx and shape_matrix - result = torch.cat((topk_idx, shape_matrix, pos_len_matrix), dim=1) + _, topk_idx = torch.topk(scores_tensor, max(self.topk), dim=-1) # n_users x k + pos_matrix = torch.zeros_like(scores_tensor, dtype=torch.int) + pos_matrix[positive_u, positive_i] = 1 + pos_len_list = pos_matrix.sum(dim=1, keepdim=True) + pos_idx = torch.gather(pos_matrix, dim=1, index=topk_idx) + result = torch.cat((pos_idx, pos_len_list), dim=1) self.data_struct.update_tensor('rec.topk', result) if self.register.need('rec.meanrank'): - user_len_list = interaction.user_len_list - pos_len_list = interaction.pos_len_list - pos_len_tensor = torch.Tensor(pos_len_list).to(scores_tensor.device) - scores_matrix = self._get_score_matrix(scores_tensor, user_len_list) - desc_scores, desc_index = torch.sort(scores_matrix, dim=-1, descending=True) + desc_scores, desc_index = torch.sort(scores_tensor, dim=-1, descending=True) # get the index of positive items in the ranking list - pos_index = (desc_index < pos_len_tensor.reshape(-1, 1)) + pos_matrix = torch.zeros_like(scores_tensor) + pos_matrix[positive_u, positive_i] = 1 + pos_index = torch.gather(pos_matrix, dim=1, index=desc_index) avg_rank = self._average_rank(desc_scores) - pos_rank_sum = torch.where(pos_index, avg_rank, torch.zeros_like(avg_rank)).sum(axis=-1).reshape(-1, 1) - - pos_len_matrix = torch.from_numpy(np.array(pos_len_list)).view(-1, 1).to(scores_matrix.device) - user_len_matrix = torch.from_numpy(np.array(user_len_list)).view(-1, 1).to(scores_matrix.device) + pos_rank_sum = torch.where(pos_index == 1, avg_rank, torch.zeros_like(avg_rank)).sum(dim=-1, keepdim=True) - result = torch.cat((pos_rank_sum, user_len_matrix, pos_len_matrix), dim=1) + pos_len_list = pos_matrix.sum(dim=1, keepdim=True) + user_len_list = desc_scores.argmin(dim=1, keepdim=True) + result = torch.cat((pos_rank_sum, user_len_list, pos_len_list), dim=1) self.data_struct.update_tensor('rec.meanrank', result) if self.register.need('rec.score'): @@ -191,21 +170,7 @@ def eval_batch_collect(self, scores_tensor: torch.Tensor, interaction): if self.register.need('data.label'): self.label_field = self.config['LABEL_FIELD'] - self.data_struct.update_tensor('data.label', interaction[self.label_field].to(scores_tensor.device)) - - if self.register.need('rec.items'): - if not self.register.need('rec.topk'): - raise ValueError("Recommended items is only prepared for top-k metrics!") - if self.full: - self.data_struct.update_tensor('rec.items', self.topk_idx) - else: - self.item_field = self.config['ITEM_ID_FIELD'] - user_len_list = interaction.user_len_list - item_tensor = interaction[self.item_field].to(scores_tensor.device) - item_matrix = self._get_score_matrix(item_tensor, user_len_list) # n_user * n_items - topk_item = torch.gather(item_matrix, dim=1, index=self.topk_idx) # n_user * k - - self.data_struct.update_tensor('rec.items', topk_item) + self.data_struct.update_tensor('data.label', interaction[self.label_field].to(self.device)) def model_collect(self, model: torch.nn.Module): @@ -228,7 +193,7 @@ def eval_collect(self, eval_pred: torch.Tensor, data_label: torch.Tensor): if self.register.need('data.label'): self.label_field = self.config['LABEL_FIELD'] - self.data_struct.update_tensor('data.label', data_label.to(eval_pred.device)) + self.data_struct.update_tensor('data.label', data_label.to(self.device)) def get_data_struct(self): """ Get all the evaluation resource that been collected. diff --git a/recbole/evaluator/register.py b/recbole/evaluator/register.py index 66a275d26..99dd1ab91 100644 --- a/recbole/evaluator/register.py +++ b/recbole/evaluator/register.py @@ -3,7 +3,7 @@ # @Email : zhlin@ruc.edu.cn # UPDATE -# @Time : 2021/7/5 +# @Time : 2021/7/18 # @Author : Zhichao Feng # @email : fzcbupt@gmail.com @@ -21,11 +21,11 @@ 'precision': ['rec.topk'], 'map': ['rec.topk'], - 'itemcoverage': ['rec.topk', 'rec.items', 'data.num_items'], # Sign in for topk non-accuracy metrics - 'averagepopularity': ['rec.topk', 'rec.items', 'data.count_items'], - 'giniindex': ['rec.topk', 'rec.items', 'data.num_items'], - 'shannonentropy': ['rec.topk', 'rec.items'], - 'tailpercentage': ['rec.topk', 'rec.items', 'data.count_items'], + 'itemcoverage': ['rec.items', 'data.num_items'], # Sign in for topk non-accuracy metrics + 'averagepopularity': ['rec.items', 'data.count_items'], + 'giniindex': ['rec.items', 'data.num_items'], + 'shannonentropy': ['rec.items'], + 'tailpercentage': ['rec.items', 'data.count_items'], 'gauc': ['rec.meanrank'], # Sign in for full ranking metrics diff --git a/recbole/trainer/trainer.py b/recbole/trainer/trainer.py index b69d2fdab..8fab9ebb9 100644 --- a/recbole/trainer/trainer.py +++ b/recbole/trainer/trainer.py @@ -8,9 +8,9 @@ # @Email : zhlin@ruc.edu.cn, houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, slmu@ruc.edu.cn, panxy@ruc.edu.cn # UPDATE: -# @Time : 2020/10/8, 2020/10/15, 2020/11/20, 2021/2/20, 2021/3/3, 2021/3/5 -# @Author : Hui Wang, Xinyan Fan, Chen Yang, Yibo Li, Lanling Xu, Haoran Cheng -# @Email : hui.wang@ruc.edu.cn, xinyan.fan@ruc.edu.cn, 254170321@qq.com, 2018202152@ruc.edu.cn, xulanling_sherry@163.com, chenghaoran29@foxmail.com +# @Time : 2020/10/8, 2020/10/15, 2020/11/20, 2021/2/20, 2021/3/3, 2021/3/5, 2021/7/18 +# @Author : Hui Wang, Xinyan Fan, Chen Yang, Yibo Li, Lanling Xu, Haoran Cheng, Zhichao Feng +# @Email : hui.wang@ruc.edu.cn, xinyan.fan@ruc.edu.cn, 254170321@qq.com, 2018202152@ruc.edu.cn, xulanling_sherry@163.com, chenghaoran29@foxmail.com, fzcbupt@gmail.com r""" recbole.trainer.trainer @@ -30,7 +30,7 @@ from recbole.data.interaction import Interaction from recbole.evaluator import Evaluator, Collector from recbole.utils import ensure_dir, get_local_time, early_stopping, calculate_valid_score, dict2str, \ - DataLoaderType, KGDataLoaderState, get_tensorboard, set_color + DataLoaderType, EvaluatorType, KGDataLoaderState, get_tensorboard, set_color class AbstractTrainer(object): @@ -339,7 +339,7 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progre return self.best_valid_score, self.best_valid_result def _full_sort_batch_eval(self, batched_data): - interaction, history_index, swap_row, swap_col_after, swap_col_before = batched_data + interaction, history_index, positive_u, positive_i = batched_data try: # Note: interaction without item ids scores = self.model.full_sort_predict(interaction.to(self.device)) @@ -357,12 +357,24 @@ def _full_sort_batch_eval(self, batched_data): if history_index is not None: scores[history_index] = -np.inf - swap_row = swap_row.to(self.device) - swap_col_after = swap_col_after.to(self.device) - swap_col_before = swap_col_before.to(self.device) - scores[swap_row, swap_col_after] = scores[swap_row, swap_col_before] + return interaction, scores, positive_u, positive_i - return interaction, scores + def _neg_sample_batch_eval(self, batched_data): + interaction, row_idx, positive_u, positive_i = batched_data + batch_size = interaction.length + if batch_size <= self.test_batch_size: + origin_scores = self.model.predict(interaction.to(self.device)) + else: + origin_scores = self._spilt_predict(interaction, batch_size) + + if self.config['eval_type'] == EvaluatorType.INDIVIDUAL: + return interaction, origin_scores, positive_u, positive_i + elif self.config['eval_type'] == EvaluatorType.RANKING: + col_idx = interaction[self.config['ITEM_ID_FIELD']] + batch_user_num = positive_u[-1] + 1 + scores = torch.full((batch_user_num, self.tot_item_num), -np.inf, device=self.device) + scores[row_idx, col_idx] = origin_scores + return interaction, scores, positive_u, positive_i @torch.no_grad() def evaluate(self, eval_data, load_best_model=True, model_file=None, show_progress=False): @@ -397,7 +409,7 @@ def evaluate(self, eval_data, load_best_model=True, model_file=None, show_progre if eval_data.dl_type == DataLoaderType.FULL: if self.item_tensor is None: self.item_tensor = eval_data.dataset.get_item_feature().to(self.device).repeat(eval_data.step) - self.tot_item_num = eval_data.dataset.item_num + self.tot_item_num = eval_data.dataset.item_num batch_matrix_list = [] iter_data = ( @@ -409,16 +421,10 @@ def evaluate(self, eval_data, load_best_model=True, model_file=None, show_progre ) for batch_idx, batched_data in iter_data: if eval_data.dl_type == DataLoaderType.FULL: - interaction, scores = self._full_sort_batch_eval(batched_data) + interaction, scores, positive_u, positive_i = self._full_sort_batch_eval(batched_data) else: - interaction = batched_data - batch_size = interaction.length - if batch_size <= self.test_batch_size: - scores = self.model.predict(interaction.to(self.device)) - else: - scores = self._spilt_predict(interaction, batch_size) - - self.eval_collector.eval_batch_collect(scores, interaction) + interaction, scores, positive_u, positive_i = self._neg_sample_batch_eval(batched_data) + self.eval_collector.eval_batch_collect(scores, interaction, positive_u, positive_i) self.eval_collector.model_collect(self.model) struct = self.eval_collector.get_data_struct() result = self.evaluator.evaluate(struct) diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py index 6f82b357f..0abce1382 100644 --- a/tests/data/test_dataloader.py +++ b/tests/data/test_dataloader.py @@ -4,9 +4,9 @@ # @Email : chenyushuo@ruc.edu.cn # UPDATE -# @Time : 2020/1/5, 2021/7/1 -# @Author : Yushuo Chen, Xingyu Pan -# @email : chenyushuo@ruc.edu.cn, xy_pan@foxmail.com +# @Time : 2020/1/5, 2021/7/1, 2021/7/19 +# @Author : Yushuo Chen, Xingyu Pan, Zhichao Feng +# @email : chenyushuo@ruc.edu.cn, xy_pan@foxmail.com, fzcbupt@gmail.com import logging import os @@ -44,15 +44,19 @@ def test_general_dataloader(self): } train_data, valid_data, test_data = new_dataloader(config_dict=config_dict) - def check_dataloader(data, item_list, batch_size): + def check_dataloader(data, item_list, batch_size, train=False): data.shuffle = False pr = 0 for batch_data in data: batch_item_list = item_list[pr: pr + batch_size] - assert (batch_data['item_id'].numpy() == batch_item_list).all() + if train: + user_df = batch_data + else: + user_df = batch_data[0] + assert (user_df['item_id'].numpy() == batch_item_list).all() pr += batch_size - check_dataloader(train_data, list(range(1, 41)), train_batch_size) + check_dataloader(train_data, list(range(1, 41)), train_batch_size, True) check_dataloader(valid_data, list(range(41, 46)), max(eval_batch_size, 5)) check_dataloader(test_data, list(range(46, 51)), max(eval_batch_size, 5)) @@ -128,59 +132,43 @@ def test_general_full_dataloader(self): def check_result(data, result): assert len(data) == len(result) for i, batch_data in enumerate(data): - user_df, history_index, swap_row, swap_col_after, swap_col_before = batch_data + user_df, history_index, positive_u, positive_i = batch_data history_row, history_col = history_index assert len(user_df) == result[i]['len_user_df'] assert (user_df['user_id'].numpy() == result[i]['user_df_user_id']).all() - assert (user_df.pos_len_list == result[i]['pos_len_list']).all() - assert (user_df.user_len_list == result[i]['user_len_list']).all() assert len(history_row) == len(history_col) == result[i]['history_len'] assert (history_row.numpy() == result[i]['history_row']).all() assert (history_col.numpy() == result[i]['history_col']).all() - assert len(swap_row) == len(swap_col_after) == len(swap_col_before) == result[i]['swap_len'] - assert (swap_row.numpy() == result[i]['swap_row']).all() - assert (swap_col_after.numpy() == result[i]['swap_col_after']).all() - assert (swap_col_before.numpy() == result[i]['swap_col_before']).all() + assert (positive_u.numpy() == result[i]['positive_u']).all() + assert (positive_i.numpy() == result[i]['positive_i']).all() valid_result = [ { 'len_user_df': 1, 'user_df_user_id': [1], - 'pos_len_list': [5], - 'user_len_list': [101], 'history_len': 40, 'history_row': 0, 'history_col': list(range(1, 41)), - 'swap_len': 10, - 'swap_row': 0, - 'swap_col_after': [0, 1, 2, 3, 4, 41, 42, 43, 44, 45], - 'swap_col_before': [45, 44, 43, 42, 41, 4, 3, 2, 1, 0], + 'positive_u': [0, 0, 0, 0, 0], + 'positive_i': [41, 42, 43, 44, 45] }, { 'len_user_df': 1, 'user_df_user_id': [2], - 'pos_len_list': [5], - 'user_len_list': [101], 'history_len': 37, 'history_row': 0, 'history_col': list(range(1, 38)), - 'swap_len': 10, - 'swap_row': 0, - 'swap_col_after': [0, 1, 2, 3, 4, 38, 39, 40, 41, 42], - 'swap_col_before': [42, 41, 40, 39, 38, 4, 3, 2, 1, 0], + 'positive_u': [0, 0, 0, 0, 0], + 'positive_i': [38, 39, 40, 41, 42] }, { 'len_user_df': 1, 'user_df_user_id': [3], - 'pos_len_list': [1], - 'user_len_list': [101], 'history_len': 0, 'history_row': [], 'history_col': [], - 'swap_len': 2, - 'swap_row': 0, - 'swap_col_after': [0, 1], - 'swap_col_before': [1, 0], + 'positive_u': [0], + 'positive_i': [1] }, ] check_result(valid_data, valid_result) @@ -189,41 +177,29 @@ def check_result(data, result): { 'len_user_df': 1, 'user_df_user_id': [1], - 'pos_len_list': [5], - 'user_len_list': [101], 'history_len': 45, 'history_row': 0, 'history_col': list(range(1, 46)), - 'swap_len': 10, - 'swap_row': 0, - 'swap_col_after': [0, 1, 2, 3, 4, 46, 47, 48, 49, 50], - 'swap_col_before': [50, 49, 48, 47, 46, 4, 3, 2, 1, 0], + 'positive_u': [0, 0, 0, 0, 0], + 'positive_i': [46, 47, 48, 49, 50] }, { 'len_user_df': 1, 'user_df_user_id': [2], - 'pos_len_list': [5], - 'user_len_list': [101], 'history_len': 37, 'history_row': 0, 'history_col': list(range(1, 36)) + [41, 42], - 'swap_len': 10, - 'swap_row': 0, - 'swap_col_after': [0, 1, 2, 3, 4, 36, 37, 38, 39, 40], - 'swap_col_before': [40, 39, 38, 37, 36, 4, 3, 2, 1, 0], + 'positive_u': [0, 0, 0, 0, 0], + 'positive_i': [36, 37, 38, 39, 40] }, { 'len_user_df': 1, 'user_df_user_id': [3], - 'pos_len_list': [1], - 'user_len_list': [101], 'history_len': 0, 'history_row': [], 'history_col': [], - 'swap_len': 2, - 'swap_row': 0, - 'swap_col_after': [0, 1], - 'swap_col_before': [1, 0], + 'positive_u': [0], + 'positive_i': [1] }, ] check_result(test_data, test_result) @@ -247,30 +223,35 @@ def check_result(data, result): assert data.batch_size == 202 assert len(data) == len(result) for i, batch_data in enumerate(data): - assert result[i]['item_id_check'](batch_data['item_id']) - assert batch_data.pos_len_list == result[i]['pos_len_list'] - assert batch_data.user_len_list == result[i]['user_len_list'] + user_df, row_idx, positive_u, positive_i = batch_data + assert result[i]['item_id_check'](user_df['item_id']) + assert (row_idx.numpy() == result[i]['row_idx']).all() + assert (positive_u.numpy() == result[i]['positive_u']).all() + assert (positive_i.numpy() == result[i]['positive_i']).all() valid_result = [ { 'item_id_check': lambda data: data[0] == 9 and (8 < data[1:]).all() and (data[1:] <= 100).all(), - 'pos_len_list': [1], - 'user_len_list': [101], + 'row_idx': [0] * 101, + 'positive_u': [0], + 'positive_i': [9], }, { 'item_id_check': lambda data: data[0] == 1 and (data[1:] != 1).all(), - 'pos_len_list': [1], - 'user_len_list': [101], + 'row_idx': [0] * 101, + 'positive_u': [0], + 'positive_i': [1], }, { 'item_id_check': lambda data: (data[0: 2].numpy() == [17, 18]).all() and (16 < data[2:]).all() and (data[2:] <= 100).all(), - 'pos_len_list': [2], - 'user_len_list': [202], + 'row_idx': [0] * 202, + 'positive_u': [0, 0], + 'positive_i': [17, 18], }, ] check_result(valid_data, valid_result) @@ -280,21 +261,24 @@ def check_result(data, result): 'item_id_check': lambda data: data[0] == 10 and (9 < data[1:]).all() and (data[1:] <= 100).all(), - 'pos_len_list': [1], - 'user_len_list': [101], + 'row_idx': [0] * 101, + 'positive_u': [0], + 'positive_i': [10], }, { 'item_id_check': lambda data: data[0] == 1 and (data[1:] != 1).all(), - 'pos_len_list': [1], - 'user_len_list': [101], + 'row_idx': [0] * 101, + 'positive_u': [0], + 'positive_i': [1], }, { 'item_id_check': lambda data: (data[0: 2].numpy() == [19, 20]).all() and (18 < data[2:]).all() and (data[2:] <= 100).all(), - 'pos_len_list': [2], - 'user_len_list': [202], + 'row_idx': [0] * 202, + 'positive_u': [0, 0], + 'positive_i': [19, 20], }, ] check_result(test_data, test_result) @@ -318,9 +302,11 @@ def check_result(data, result): assert data.batch_size == 303 assert len(data) == len(result) for i, batch_data in enumerate(data): - assert result[i]['item_id_check'](batch_data['item_id']) - assert batch_data.pos_len_list == result[i]['pos_len_list'] - assert batch_data.user_len_list == result[i]['user_len_list'] + user_df, row_idx, positive_u, positive_i = batch_data + assert result[i]['item_id_check'](user_df['item_id']) + assert (row_idx.numpy() == result[i]['row_idx']).all() + assert (positive_u.numpy() == result[i]['positive_u']).all() + assert (positive_i.numpy() == result[i]['positive_i']).all() valid_result = [ { @@ -329,15 +315,17 @@ def check_result(data, result): and (data[1: 101] <= 100).all() and data[101] == 1 and (data[102:202] != 1).all(), - 'pos_len_list': [1, 1], - 'user_len_list': [101, 101], + 'row_idx': [0] * 101 + [1] * 101, + 'positive_u': [0, 1], + 'positive_i': [9, 1], }, { 'item_id_check': lambda data: (data[0: 2].numpy() == [17, 18]).all() and (16 < data[2:]).all() and (data[2:] <= 100).all(), - 'pos_len_list': [2], - 'user_len_list': [202], + 'row_idx': [0] * 202, + 'positive_u': [0, 0], + 'positive_i': [17, 18], }, ] check_result(valid_data, valid_result) @@ -349,15 +337,17 @@ def check_result(data, result): and (data[1:101] <= 100).all() and data[101] == 1 and (data[102:202] != 1).all(), - 'pos_len_list': [1, 1], - 'user_len_list': [101, 101], + 'row_idx': [0] * 101 + [1] * 101, + 'positive_u': [0, 1], + 'positive_i': [10, 1], }, { 'item_id_check': lambda data: (data[0: 2].numpy() == [19, 20]).all() and (18 < data[2:]).all() and (data[2:] <= 100).all(), - 'pos_len_list': [2], - 'user_len_list': [202], + 'row_idx': [0] * 202, + 'positive_u': [0, 0], + 'positive_i': [19, 20], }, ] check_result(test_data, test_result)