Skip to content

Commit

Permalink
Merge pull request #20 from RUCAIBox/data
Browse files Browse the repository at this point in the history
get latest code from branch: Data
  • Loading branch information
linzihan-backforward authored Jul 27, 2021
2 parents ccf7cfa + 77f6a3d commit 6c875d1
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 108 deletions.
6 changes: 3 additions & 3 deletions recbole/data/dataloader/general_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# @Email : houyupeng@ruc.edu.cn

# UPDATE
# @Time : 2020/9/9, 2020/9/29
# @Author : Yupeng Hou, Yushuo Chen
# @email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn
# @Time : 2020/9/9, 2020/9/29, 2021/7/15
# @Author : Yupeng Hou, Yushuo Chen, Xingyu Pan
# @email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, xy_pan@foxmail.com

"""
recbole.data.dataloader.general_dataloader
Expand Down
227 changes: 122 additions & 105 deletions recbole/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
# @File : sampler.py

# UPDATE
# @Time : 2020/8/17, 2020/8/31, 2020/10/6, 2020/9/18, 2021/3/19
# @Time : 2021/7/23, 2020/8/31, 2020/10/6, 2020/9/18, 2021/3/19
# @Author : Xingyu Pan, Kaiyuan Li, Yupeng Hou, Yushuo Chen, Zhichao Feng
# @email : panxy@ruc.edu.cn, tsotfsk@outlook.com, houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, fzcbupt@gmail.com
# @email : xy_pan@foxmail.com, tsotfsk@outlook.com, houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, fzcbupt@gmail.com

"""
recbole.sampler
Expand All @@ -16,8 +16,9 @@
import copy

import numpy as np
from numpy.random import sample
import torch

from collections import Counter

class AbstractSampler(object):
""":class:`AbstractSampler` is a abstract class, all sampler should inherit from it. This sampler supports returning
Expand All @@ -36,9 +37,6 @@ class AbstractSampler(object):

def __init__(self, distribution):
self.distribution = ''
self.random_list = []
self.random_pr = 0
self.random_list_length = 0
self.set_distribution(distribution)
self.used_ids = self.get_used_ids()

Expand All @@ -48,57 +46,100 @@ def set_distribution(self, distribution):
Args:
distribution (str): Distribution of the negative items.
"""
if self.distribution == distribution:
return
self.distribution = distribution
self.random_list = self.get_random_list()
np.random.shuffle(self.random_list)
self.random_pr = 0
self.random_list_length = len(self.random_list)

def _uni_sampling(self, sample_num):
"""Sample [sample_num] items in the uniform distribution.
def get_random_list(self):
"""
Args:
sample_num (int): the number of samples.
Returns:
numpy.ndarray or list: Random list of value_id.
sample_list (np.array): a list of samples.
"""
raise NotImplementedError('method [get_random_list] should be implemented')
raise NotImplementedError('Method [_uni_sampling] should be implemented')

def _get_candidates_list(self):
"""Get sample candidates list
def get_used_ids(self):
"""
Returns:
numpy.ndarray: Used ids. Index is key_id, and element is a set of value_ids.
candidates_list (list): a list of candidates id.
"""
raise NotImplementedError('method [get_used_ids] should be implemented')
raise NotImplementedError('Method [_get_candidates_list] should be implemented')

def random(self):
def _build_alias_table(self):
"""Build alias table for popularity_biased sampling.
"""
candidates_list = self._get_candidates_list()
self.prob = dict(Counter(candidates_list))
self.alias = self.prob.copy()
large_q = []
small_q = []

for i in self.prob:
self.alias[i] = -1
self.prob[i] = self.prob[i] / len(candidates_list) * len(self.prob)
if self.prob[i] > 1:
large_q.append(i)
elif self.prob[i] < 1:
small_q.append(i)

while(len(large_q)!=0 and len(small_q)!=0):
l = large_q.pop(0)
s = small_q.pop(0)
self.alias[s] = l
self.prob[l] = self.prob[l] - (1 - self.prob[s])
if self.prob[l] < 1:
small_q.append(l)
elif self.prob[l] > 1:
large_q.append(l)

def _pop_sampling(self, sample_num):
"""Sample [sample_num] items in the popularity-biased distribution.
Args:
sample_num (int): the number of samples.
Returns:
value_id (int): Random value_id. Generated by :attr:`random_list`.
sample_list (np.array): a list of samples.
"""
value_id = self.random_list[self.random_pr % self.random_list_length]
self.random_pr += 1
return value_id
self._build_alias_table()

def random_num(self, num):
"""
keys = list(self.prob.keys())
random_index_list = np.random.randint(0, len(keys), sample_num)
random_prob_list = np.random.random(sample_num)
final_random_list = []

for idx, prob in zip(random_index_list, random_prob_list):
if self.prob[keys[idx]] > prob:
final_random_list.append(keys[idx])
else:
final_random_list.append(self.alias[keys[idx]])

return np.array(final_random_list)

def sampling(self, sample_num):
    """Draw [sample_num] ids according to the configured distribution.

    Args:
        sample_num (int): the number of samples.

    Returns:
        sample_list (np.array): the samples produced by :meth:`_uni_sampling`
        when the distribution is ``'uniform'`` or by :meth:`_pop_sampling`
        when it is ``'popularity'``.

    Raises:
        NotImplementedError: if the distribution is neither of the above.
    """
    mode = self.distribution
    if mode == 'uniform':
        return self._uni_sampling(sample_num)
    if mode == 'popularity':
        return self._pop_sampling(sample_num)
    raise NotImplementedError(f'The sampling distribution [{self.distribution}] is not implemented.')

def get_used_ids(self):
"""
Returns:
value_ids (numpy.ndarray): Random value_ids. Generated by :attr:`random_list`.
numpy.ndarray: Used ids. Index is key_id, and element is a set of value_ids.
"""
value_id = []
self.random_pr %= self.random_list_length
while True:
if self.random_pr + num <= self.random_list_length:
value_id.append(self.random_list[self.random_pr:self.random_pr + num])
self.random_pr += num
break
else:
value_id.append(self.random_list[self.random_pr:])
num -= self.random_list_length - self.random_pr
self.random_pr = 0
return np.concatenate(value_id)
raise NotImplementedError('Method [get_used_ids] should be implemented')

def sample_by_key_ids(self, key_ids, num):
"""Sampling by key_ids.
Expand All @@ -120,33 +161,18 @@ def sample_by_key_ids(self, key_ids, num):
if (key_ids == key_ids[0]).all():
key_id = key_ids[0]
used = np.array(list(self.used_ids[key_id]))
value_ids = self.random_num(total_num)
value_ids = self.sampling(total_num)
check_list = np.arange(total_num)[np.isin(value_ids, used)]
while len(check_list) > 0:
value_ids[check_list] = value = self.random_num(len(check_list))
perm = value.argsort(kind='quicksort')
aux = value[perm]
mask = np.empty(aux.shape, dtype=np.bool_)
mask[:1] = True
mask[1:] = aux[1:] != aux[:-1]
value = aux[mask]
rev_idx = np.empty(mask.shape, dtype=np.intp)
rev_idx[perm] = np.cumsum(mask) - 1
ar = np.concatenate((value, used))
order = ar.argsort(kind='mergesort')
sar = ar[order]
bool_ar = (sar[1:] == sar[:-1])
flag = np.concatenate((bool_ar, [False]))
ret = np.empty(ar.shape, dtype=bool)
ret[order] = flag
mask = ret[rev_idx]
value_ids[check_list] = value = self.sampling(len(check_list))
mask = np.isin(value, used)
check_list = check_list[mask]
else:
value_ids = np.zeros(total_num, dtype=np.int64)
check_list = np.arange(total_num)
key_ids = np.tile(key_ids, num)
while len(check_list) > 0:
value_ids[check_list] = self.random_num(len(check_list))
value_ids[check_list] = self.sampling(len(check_list))
check_list = np.array([
i for i, used, v in zip(check_list, self.used_ids[key_ids[check_list]], value_ids[check_list])
if v in used
Expand Down Expand Up @@ -183,25 +209,20 @@ def __init__(self, phases, datasets, distribution='uniform'):
self.uid_field = datasets[0].uid_field
self.iid_field = datasets[0].iid_field

self.n_users = datasets[0].user_num
self.n_items = datasets[0].item_num
self.user_num = datasets[0].user_num
self.item_num = datasets[0].item_num

super().__init__(distribution=distribution)

def get_random_list(self):
"""
Returns:
numpy.ndarray or list: Random list of item_id.
"""
if self.distribution == 'uniform':
return np.arange(1, self.n_items)
elif self.distribution == 'popularity':
random_item_list = []
for dataset in self.datasets:
random_item_list.extend(dataset.inter_feat[self.iid_field].numpy())
return random_item_list
else:
raise NotImplementedError(f'Distribution [{self.distribution}] has not been implemented.')

def _get_candidates_list(self):
candidates_list = []
for dataset in self.datasets:
candidates_list.extend(dataset.inter_feat[self.iid_field].numpy())
return candidates_list

def _uni_sampling(self, sample_num):
return np.random.randint(1, self.item_num, sample_num)

def get_used_ids(self):
"""
Expand All @@ -210,15 +231,15 @@ def get_used_ids(self):
Key is phase, and value is a numpy.ndarray which index is user_id, and element is a set of item_ids.
"""
used_item_id = dict()
last = [set() for _ in range(self.n_users)]
last = [set() for _ in range(self.user_num)]
for phase, dataset in zip(self.phases, self.datasets):
cur = np.array([set(s) for s in last])
for uid, iid in zip(dataset.inter_feat[self.uid_field].numpy(), dataset.inter_feat[self.iid_field].numpy()):
cur[uid].add(iid)
last = used_item_id[phase] = cur

for used_item_set in used_item_id[self.phases[-1]]:
if len(used_item_set) + 1 == self.n_items: # [pad] is a item.
if len(used_item_set) + 1 == self.item_num: # [pad] is a item.
raise ValueError(
'Some users have interacted with all items, '
'which we can not sample negative items for them. '
Expand Down Expand Up @@ -262,7 +283,7 @@ def sample_by_user_ids(self, user_ids, item_ids, num):
return self.sample_by_key_ids(user_ids, num)
except IndexError:
for user_id in user_ids:
if user_id < 0 or user_id >= self.n_users:
if user_id < 0 or user_id >= self.user_num:
raise ValueError(f'user_id [{user_id}] not exist.')


Expand All @@ -287,17 +308,11 @@ def __init__(self, dataset, distribution='uniform'):

super().__init__(distribution=distribution)

def get_random_list(self):
"""
Returns:
numpy.ndarray or list: Random list of entity_id.
"""
if self.distribution == 'uniform':
return np.arange(1, self.entity_num)
elif self.distribution == 'popularity':
return list(self.hid_list) + list(self.tid_list)
else:
raise NotImplementedError(f'Distribution [{self.distribution}] has not been implemented.')
def _uni_sampling(self, sample_num):
return np.random.randint(1, self.entity_num, sample_num)

def _get_candidates_list(self):
return list(self.hid_list) + list(self.tid_list)

def get_used_ids(self):
"""
Expand Down Expand Up @@ -359,18 +374,24 @@ def __init__(self, phases, dataset, distribution='uniform'):
self.dataset = dataset

self.iid_field = dataset.iid_field
self.n_users = dataset.user_num
self.n_items = dataset.item_num
self.user_num = dataset.user_num
self.item_num = dataset.item_num

super().__init__(distribution=distribution)

def _uni_sampling(self, sample_num):
return np.random.randint(1, self.item_num, sample_num)

def _get_candidates_list(self):
return list(self.dataset.inter_feat[self.iid_field].numpy())

def get_random_list(self):
"""
Returns:
numpy.ndarray or list: Random list of item_id.
"""
if self.distribution == 'uniform':
return np.arange(1, self.n_items)
return np.arange(1, self.item_num)
elif self.distribution == 'popularity':
return self.dataset.inter_feat[self.iid_field].numpy()
else:
Expand All @@ -382,7 +403,7 @@ def get_used_ids(self):
numpy.ndarray: Used item_ids is the same as positive item_ids.
Index is user_id, and element is a set of item_ids.
"""
return np.array([set() for _ in range(self.n_users)])
return np.array([set() for _ in range(self.user_num)])

def sample_by_user_ids(self, user_ids, item_ids, num):
"""Sampling by user_ids.
Expand All @@ -404,7 +425,7 @@ def sample_by_user_ids(self, user_ids, item_ids, num):
return self.sample_by_key_ids(np.arange(len(user_ids)), num)
except IndexError:
for user_id in user_ids:
if user_id < 0 or user_id >= self.n_users:
if user_id < 0 or user_id >= self.user_num:
raise ValueError(f'user_id [{user_id}] not exist.')

def set_phase(self, phase):
Expand Down Expand Up @@ -435,21 +456,17 @@ def __init__(self, dataset, distribution='uniform'):
self.dataset = dataset

self.iid_field = dataset.iid_field
self.n_users = dataset.user_num
self.n_items = dataset.item_num
self.user_num = dataset.user_num
self.item_num = dataset.item_num

super().__init__(distribution=distribution)

def _uni_sampling(self, sample_num):
return np.random.randint(1, self.item_num, sample_num)

def get_used_ids(self):
    """Return nothing: this sampler keeps no record of used ids.

    Returns:
        None
    """
    return None

def get_random_list(self):
"""
Returns:
numpy.ndarray or list: Random list of item_id.
"""
return np.arange(1, self.n_items)

def sample_neg_sequence(self, pos_sequence):
"""For each moment, sampling one item from all the items except the one the user clicked on at that moment.
Expand All @@ -464,7 +481,7 @@ def sample_neg_sequence(self, pos_sequence):
value_ids = np.zeros(total_num, dtype=np.int64)
check_list = np.arange(total_num)
while len(check_list) > 0:
value_ids[check_list] = self.random_num(len(check_list))
value_ids[check_list] = self.sampling(len(check_list))
check_index = np.where(value_ids[check_list] == pos_sequence[check_list])
check_list = check_list[check_index]

Expand Down

0 comments on commit 6c875d1

Please sign in to comment.