Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: interface between dataloader and evaluator #894

Merged
merged 13 commits into from
Jul 19, 2021
9 changes: 5 additions & 4 deletions recbole/data/dataloader/abstract_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,18 @@ def _set_neg_sample_args(self, config, dataset, dl_format, neg_sample_args):
self.iid_field = dataset.iid_field
self.dl_format = dl_format
self.neg_sample_args = neg_sample_args
self.times = 1
if self.neg_sample_args['strategy'] == 'by':
self.neg_sample_by = self.neg_sample_args['by']
self.neg_sample_num = self.neg_sample_args['by']

if self.dl_format == InputType.POINTWISE:
self.times = 1 + self.neg_sample_by
self.times = 1 + self.neg_sample_num
self.sampling_func = self._neg_sample_by_point_wise_sampling

self.label_field = config['LABEL_FIELD']
dataset.set_field_property(self.label_field, FeatureType.FLOAT, FeatureSource.INTERACTION, 1)
elif self.dl_format == InputType.PAIRWISE:
self.times = self.neg_sample_by
self.times = self.neg_sample_num
self.sampling_func = self._neg_sample_by_pair_wise_sampling

self.neg_prefix = config['NEG_PREFIX']
Expand All @@ -147,7 +148,7 @@ def _neg_sampling(self, inter_feat):
if self.neg_sample_args['strategy'] == 'by':
user_ids = inter_feat[self.uid_field]
item_ids = inter_feat[self.iid_field]
neg_item_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_by)
neg_item_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_num)
return self.sampling_func(inter_feat, neg_item_ids)
else:
return inter_feat
Expand Down
95 changes: 29 additions & 66 deletions recbole/data/dataloader/general_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,30 +120,24 @@ def _shuffle(self):
def _next_batch_data(self):
uid_list = self.uid_list[self.pr:self.pr + self.step]
data_list = []
for uid in uid_list:
idx_list = []
positive_u = []
positive_i = torch.tensor([], dtype=torch.int64)

for idx, uid in enumerate(uid_list):
index = self.uid2index[uid]
data_list.append(self._neg_sampling(self.dataset[index]))
cur_data = cat_interactions(data_list)
if self.neg_sample_args['strategy'] == 'by':
pos_len_list = self.uid2items_num[uid_list]
user_len_list = pos_len_list * self.times
cur_data.set_additional_info(list(pos_len_list), list(user_len_list))
self.pr += self.step
return cur_data
idx_list += [idx for i in range(self.uid2items_num[uid] * self.times)]
positive_u += [idx for i in range(self.uid2items_num[uid])]
positive_i = torch.cat((positive_i, self.dataset[index][self.iid_field]), 0)

def get_pos_len_list(self):
"""
Returns:
numpy.ndarray: Number of positive item for each user in a training/evaluating epoch.
"""
return self.uid2items_num[self.uid_list]
cur_data = cat_interactions(data_list)
idx_list = torch.from_numpy(np.array(idx_list))
positive_u = torch.from_numpy(np.array(positive_u))

def get_user_len_list(self):
"""
Returns:
numpy.ndarray: Number of all item for each user in a training/evaluating epoch.
"""
return self.uid2items_num[self.uid_list] * self.times
self.pr += self.step

return cur_data, idx_list, positive_u, positive_i


class FullSortEvalDataLoader(AbstractDataLoader):
Expand All @@ -167,8 +161,7 @@ def __init__(self, config, dataset, sampler, shuffle=False):
user_num = dataset.user_num
self.uid_list = []
self.uid2items_num = np.zeros(user_num, dtype=np.int64)
self.uid2swap_idx = np.array([None] * user_num)
self.uid2rev_swap_idx = np.array([None] * user_num)
self.uid2positive_item = np.array([None] * user_num)
self.uid2history_item = np.array([None] * user_num)

dataset.sort(by=self.uid_field, ascending=True)
Expand All @@ -192,11 +185,8 @@ def _set_user_property(self, uid, used_item, positive_item):
if uid is None:
return
history_item = used_item - positive_item
positive_item_num = len(positive_item)
self.uid2items_num[uid] = positive_item_num
swap_idx = torch.tensor(sorted(set(range(positive_item_num)) ^ positive_item))
self.uid2swap_idx[uid] = swap_idx
self.uid2rev_swap_idx[uid] = swap_idx.flip(0)
self.uid2positive_item[uid] = torch.tensor(list(positive_item), dtype=torch.int64)
self.uid2items_num[uid] = len(positive_item)
self.uid2history_item[uid] = torch.tensor(list(history_item), dtype=torch.int64)

def _init_batch_size_and_step(self):
Expand All @@ -223,51 +213,24 @@ def _shuffle(self):
def _next_batch_data(self):
if not self.is_sequential:
user_df = self.user_df[self.pr:self.pr + self.step]
uid_list = list(user_df[self.dataset.uid_field])
pos_len_list = self.uid2items_num[uid_list]
user_len_list = np.full(len(uid_list), self.dataset.item_num)
user_df.set_additional_info(pos_len_list, user_len_list)

uid_list = list(user_df[self.uid_field])

history_item = self.uid2history_item[uid_list]
history_row = torch.cat([torch.full_like(hist_iid, i) for i, hist_iid in enumerate(history_item)])
history_col = torch.cat(list(history_item))
positive_item = self.uid2positive_item[uid_list]

history_u = torch.cat([torch.full_like(hist_iid, i) for i, hist_iid in enumerate(history_item)])
history_i = torch.cat(list(history_item))

swap_idx = self.uid2swap_idx[uid_list]
rev_swap_idx = self.uid2rev_swap_idx[uid_list]
swap_row = torch.cat([torch.full_like(swap, i) for i, swap in enumerate(swap_idx)])
swap_col_after = torch.cat(list(swap_idx))
swap_col_before = torch.cat(list(rev_swap_idx))
positive_u = torch.cat([torch.full_like(pos_iid, i) for i, pos_iid in enumerate(positive_item)])
positive_i = torch.cat(list(positive_item))

self.pr += self.step
return user_df, (history_row, history_col), swap_row, swap_col_after, swap_col_before
return user_df, (history_u, history_i), positive_u, positive_i
else:
interaction = self.dataset[self.pr:self.pr + self.step]
inter_num = len(interaction)
pos_len_list = np.ones(inter_num, dtype=np.int64)
user_len_list = np.full(inter_num, self.dataset.item_num)
interaction.set_additional_info(pos_len_list, user_len_list)
scores_row = torch.arange(inter_num).repeat(2)
padding_idx = torch.zeros(inter_num, dtype=torch.int64)
positive_idx = interaction[self.iid_field]
scores_col_after = torch.cat((padding_idx, positive_idx))
scores_col_before = torch.cat((positive_idx, padding_idx))
positive_u = torch.arange(inter_num)
positive_i = interaction[self.iid_field]

self.pr += self.step
return interaction, None, scores_row, scores_col_after, scores_col_before

def get_pos_len_list(self):
"""
Returns:
numpy.ndarray: Number of positive item for each user in a training/evaluating epoch.
"""
if not self.is_sequential:
return self.uid2items_num[self.uid_list]
else:
return np.ones(self.pr_end, dtype=np.int64)

def get_user_len_list(self):
"""
Returns:
numpy.ndarray: Number of all item for each user in a training/evaluating epoch.
"""
return np.full(self.pr_end, self.dataset.item_num)
return interaction, None, positive_u, positive_i
136 changes: 0 additions & 136 deletions recbole/evaluator/abstract_evaluator.py

This file was deleted.

20 changes: 8 additions & 12 deletions recbole/evaluator/base_metric.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# -*- encoding: utf-8 -*-
# @Time : 2020/10/21
# @Author : Kaiyuan Li
# @email : tsotfsk@outlook.com
# @Time : 2020/10/21
# @Author : Kaiyuan Li
# @email : tsotfsk@outlook.com

# UPDATE
# @Time : 2020/10/21, 2021/6/25
# @Author : Kaiyuan Li, Zhichao Feng
# @email : tsotfsk@outlook.com, fzcbupt@gmail.com
# @Time : 2020/10/21, 2021/7/18
# @Author : Kaiyuan Li, Zhichao Feng
# @email : tsotfsk@outlook.com, fzcbupt@gmail.com

"""
recbole.evaluator.abstract_metric
Expand All @@ -26,16 +25,13 @@ class TopkMetric(object):
"""
def __init__(self, config):
self.topk = config['topk']
self.indice = [max(self.topk), 1, 1]
self.decimal_place = config['metric_decimal_place']

def used_info(self, dataobject):
"""get the bool matrix indicating whether the corresponding item is positive"""
rec_mat = dataobject.get('rec.topk')
topk_idx, shapes, pos_len_list = torch.split(rec_mat, self.indice, dim=1)
pos_idx_matrix = (topk_idx >= (shapes - pos_len_list).reshape(-1, 1))

return pos_idx_matrix.numpy(), pos_len_list.squeeze().numpy()
topk_idx, pos_len_list = torch.split(rec_mat, [max(self.topk), 1], dim=1)
return rec_mat.to(torch.bool).numpy(), pos_len_list.squeeze().numpy()

def topk_result(self, metric, value):
"""match the metric value to the `k` and put them in `dictionary` form"""
Expand Down
Loading