Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: refactor the sampler #903

Merged
merged 1 commit into from
Jul 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions recbole/data/dataloader/general_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# @Email : houyupeng@ruc.edu.cn

# UPDATE
# @Time : 2020/9/9, 2020/9/29
# @Author : Yupeng Hou, Yushuo Chen
# @email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn
# @Time : 2020/9/9, 2020/9/29, 2021/7/15
# @Author : Yupeng Hou, Yushuo Chen, Xingyu Pan
# @email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, xy_pan@foxmail.com

"""
recbole.data.dataloader.general_dataloader
Expand Down
227 changes: 122 additions & 105 deletions recbole/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
# @File : sampler.py

# UPDATE
# @Time : 2020/8/17, 2020/8/31, 2020/10/6, 2020/9/18, 2021/3/19
# @Time : 2021/7/23, 2020/8/31, 2020/10/6, 2020/9/18, 2021/3/19
# @Author : Xingyu Pan, Kaiyuan Li, Yupeng Hou, Yushuo Chen, Zhichao Feng
# @email : panxy@ruc.edu.cn, tsotfsk@outlook.com, houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, fzcbupt@gmail.com
# @email : xy_pan@foxmail.com, tsotfsk@outlook.com, houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, fzcbupt@gmail.com

"""
recbole.sampler
Expand All @@ -16,8 +16,9 @@
import copy

import numpy as np
from numpy.random import sample
import torch

from collections import Counter

class AbstractSampler(object):
""":class:`AbstractSampler` is an abstract class; all samplers should inherit from it. This sampler supports returning
Expand All @@ -36,9 +37,6 @@ class AbstractSampler(object):

def __init__(self, distribution):
self.distribution = ''
self.random_list = []
self.random_pr = 0
self.random_list_length = 0
self.set_distribution(distribution)
self.used_ids = self.get_used_ids()

Expand All @@ -48,57 +46,100 @@ def set_distribution(self, distribution):
Args:
distribution (str): Distribution of the negative items.
"""
if self.distribution == distribution:
return
self.distribution = distribution
self.random_list = self.get_random_list()
np.random.shuffle(self.random_list)
self.random_pr = 0
self.random_list_length = len(self.random_list)

def _uni_sampling(self, sample_num):
"""Sample [sample_num] items in the uniform distribution.

def get_random_list(self):
"""
Args:
sample_num (int): the number of samples.

Returns:
numpy.ndarray or list: Random list of value_id.
sample_list (np.array): a list of samples.
"""
raise NotImplementedError('method [get_random_list] should be implemented')
raise NotImplementedError('Method [_uni_sampling] should be implemented')

def _get_candidates_list(self):
"""Get sample candidates list

def get_used_ids(self):
"""
Returns:
numpy.ndarray: Used ids. Index is key_id, and element is a set of value_ids.
candidates_list (list): a list of candidates id.
"""
raise NotImplementedError('method [get_used_ids] should be implemented')
raise NotImplementedError('Method [_get_candidates_list] should be implemented')

def random(self):
def _build_alias_table(self):
"""Build alias table for popularity_biased sampling.
"""
candidates_list = self._get_candidates_list()
self.prob = dict(Counter(candidates_list))
self.alias = self.prob.copy()
large_q = []
small_q = []

for i in self.prob:
self.alias[i] = -1
self.prob[i] = self.prob[i] / len(candidates_list) * len(self.prob)
if self.prob[i] > 1:
large_q.append(i)
elif self.prob[i] < 1:
small_q.append(i)

while(len(large_q)!=0 and len(small_q)!=0):
l = large_q.pop(0)
s = small_q.pop(0)
self.alias[s] = l
self.prob[l] = self.prob[l] - (1 - self.prob[s])
if self.prob[l] < 1:
small_q.append(l)
elif self.prob[l] > 1:
large_q.append(l)

def _pop_sampling(self, sample_num):
"""Sample [sample_num] items in the popularity-biased distribution.

Args:
sample_num (int): the number of samples.

Returns:
value_id (int): Random value_id. Generated by :attr:`random_list`.
sample_list (np.array): a list of samples.
"""
value_id = self.random_list[self.random_pr % self.random_list_length]
self.random_pr += 1
return value_id
self._build_alias_table()

def random_num(self, num):
"""
keys = list(self.prob.keys())
random_index_list = np.random.randint(0, len(keys), sample_num)
random_prob_list = np.random.random(sample_num)
final_random_list = []

for idx, prob in zip(random_index_list, random_prob_list):
if self.prob[keys[idx]] > prob:
final_random_list.append(keys[idx])
else:
final_random_list.append(self.alias[keys[idx]])

return np.array(final_random_list)

def sampling(self, sample_num):
"""Sampling [sample_num] item_ids.

Args:
num (int): Number of random value_ids.
sample_num (int): the number of samples.

Returns:
sample_list (np.array): a list of samples and the len is [sample_num].
"""
if self.distribution =='uniform':
return self._uni_sampling(sample_num)
elif self.distribution == 'popularity':
return self._pop_sampling(sample_num)
else:
raise NotImplementedError(f'The sampling distribution [{self.distribution}] is not implemented.')

def get_used_ids(self):
"""
Returns:
value_ids (numpy.ndarray): Random value_ids. Generated by :attr:`random_list`.
numpy.ndarray: Used ids. Index is key_id, and element is a set of value_ids.
"""
value_id = []
self.random_pr %= self.random_list_length
while True:
if self.random_pr + num <= self.random_list_length:
value_id.append(self.random_list[self.random_pr:self.random_pr + num])
self.random_pr += num
break
else:
value_id.append(self.random_list[self.random_pr:])
num -= self.random_list_length - self.random_pr
self.random_pr = 0
return np.concatenate(value_id)
raise NotImplementedError('Method [get_used_ids] should be implemented')

def sample_by_key_ids(self, key_ids, num):
"""Sampling by key_ids.
Expand All @@ -120,33 +161,18 @@ def sample_by_key_ids(self, key_ids, num):
if (key_ids == key_ids[0]).all():
key_id = key_ids[0]
used = np.array(list(self.used_ids[key_id]))
value_ids = self.random_num(total_num)
value_ids = self.sampling(total_num)
check_list = np.arange(total_num)[np.isin(value_ids, used)]
while len(check_list) > 0:
value_ids[check_list] = value = self.random_num(len(check_list))
perm = value.argsort(kind='quicksort')
aux = value[perm]
mask = np.empty(aux.shape, dtype=np.bool_)
mask[:1] = True
mask[1:] = aux[1:] != aux[:-1]
value = aux[mask]
rev_idx = np.empty(mask.shape, dtype=np.intp)
rev_idx[perm] = np.cumsum(mask) - 1
ar = np.concatenate((value, used))
order = ar.argsort(kind='mergesort')
sar = ar[order]
bool_ar = (sar[1:] == sar[:-1])
flag = np.concatenate((bool_ar, [False]))
ret = np.empty(ar.shape, dtype=bool)
ret[order] = flag
mask = ret[rev_idx]
value_ids[check_list] = value = self.sampling(len(check_list))
mask = np.isin(value, used)
check_list = check_list[mask]
else:
value_ids = np.zeros(total_num, dtype=np.int64)
check_list = np.arange(total_num)
key_ids = np.tile(key_ids, num)
while len(check_list) > 0:
value_ids[check_list] = self.random_num(len(check_list))
value_ids[check_list] = self.sampling(len(check_list))
check_list = np.array([
i for i, used, v in zip(check_list, self.used_ids[key_ids[check_list]], value_ids[check_list])
if v in used
Expand Down Expand Up @@ -183,25 +209,20 @@ def __init__(self, phases, datasets, distribution='uniform'):
self.uid_field = datasets[0].uid_field
self.iid_field = datasets[0].iid_field

self.n_users = datasets[0].user_num
self.n_items = datasets[0].item_num
self.user_num = datasets[0].user_num
self.item_num = datasets[0].item_num

super().__init__(distribution=distribution)

def get_random_list(self):
"""
Returns:
numpy.ndarray or list: Random list of item_id.
"""
if self.distribution == 'uniform':
return np.arange(1, self.n_items)
elif self.distribution == 'popularity':
random_item_list = []
for dataset in self.datasets:
random_item_list.extend(dataset.inter_feat[self.iid_field].numpy())
return random_item_list
else:
raise NotImplementedError(f'Distribution [{self.distribution}] has not been implemented.')

def _get_candidates_list(self):
candidates_list = []
for dataset in self.datasets:
candidates_list.extend(dataset.inter_feat[self.iid_field].numpy())
return candidates_list

def _uni_sampling(self, sample_num):
return np.random.randint(1, self.item_num, sample_num)

def get_used_ids(self):
"""
Expand All @@ -210,15 +231,15 @@ def get_used_ids(self):
Key is phase, and value is a numpy.ndarray which index is user_id, and element is a set of item_ids.
"""
used_item_id = dict()
last = [set() for _ in range(self.n_users)]
last = [set() for _ in range(self.user_num)]
for phase, dataset in zip(self.phases, self.datasets):
cur = np.array([set(s) for s in last])
for uid, iid in zip(dataset.inter_feat[self.uid_field].numpy(), dataset.inter_feat[self.iid_field].numpy()):
cur[uid].add(iid)
last = used_item_id[phase] = cur

for used_item_set in used_item_id[self.phases[-1]]:
if len(used_item_set) + 1 == self.n_items: # [pad] is a item.
if len(used_item_set) + 1 == self.item_num: # [pad] is a item.
raise ValueError(
'Some users have interacted with all items, '
'which we can not sample negative items for them. '
Expand Down Expand Up @@ -262,7 +283,7 @@ def sample_by_user_ids(self, user_ids, item_ids, num):
return self.sample_by_key_ids(user_ids, num)
except IndexError:
for user_id in user_ids:
if user_id < 0 or user_id >= self.n_users:
if user_id < 0 or user_id >= self.user_num:
raise ValueError(f'user_id [{user_id}] not exist.')


Expand All @@ -287,17 +308,11 @@ def __init__(self, dataset, distribution='uniform'):

super().__init__(distribution=distribution)

def get_random_list(self):
"""
Returns:
numpy.ndarray or list: Random list of entity_id.
"""
if self.distribution == 'uniform':
return np.arange(1, self.entity_num)
elif self.distribution == 'popularity':
return list(self.hid_list) + list(self.tid_list)
else:
raise NotImplementedError(f'Distribution [{self.distribution}] has not been implemented.')
def _uni_sampling(self, sample_num):
return np.random.randint(1, self.entity_num, sample_num)

def _get_candidates_list(self):
return list(self.hid_list) + list(self.tid_list)

def get_used_ids(self):
"""
Expand Down Expand Up @@ -359,18 +374,24 @@ def __init__(self, phases, dataset, distribution='uniform'):
self.dataset = dataset

self.iid_field = dataset.iid_field
self.n_users = dataset.user_num
self.n_items = dataset.item_num
self.user_num = dataset.user_num
self.item_num = dataset.item_num

super().__init__(distribution=distribution)

def _uni_sampling(self, sample_num):
return np.random.randint(1, self.item_num, sample_num)

def _get_candidates_list(self):
return list(self.dataset.inter_feat[self.iid_field].numpy())

def get_random_list(self):
"""
Returns:
numpy.ndarray or list: Random list of item_id.
"""
if self.distribution == 'uniform':
return np.arange(1, self.n_items)
return np.arange(1, self.item_num)
elif self.distribution == 'popularity':
return self.dataset.inter_feat[self.iid_field].numpy()
else:
Expand All @@ -382,7 +403,7 @@ def get_used_ids(self):
numpy.ndarray: Used item_ids is the same as positive item_ids.
Index is user_id, and element is a set of item_ids.
"""
return np.array([set() for _ in range(self.n_users)])
return np.array([set() for _ in range(self.user_num)])

def sample_by_user_ids(self, user_ids, item_ids, num):
"""Sampling by user_ids.
Expand All @@ -404,7 +425,7 @@ def sample_by_user_ids(self, user_ids, item_ids, num):
return self.sample_by_key_ids(np.arange(len(user_ids)), num)
except IndexError:
for user_id in user_ids:
if user_id < 0 or user_id >= self.n_users:
if user_id < 0 or user_id >= self.user_num:
raise ValueError(f'user_id [{user_id}] not exist.')

def set_phase(self, phase):
Expand Down Expand Up @@ -435,21 +456,17 @@ def __init__(self, dataset, distribution='uniform'):
self.dataset = dataset

self.iid_field = dataset.iid_field
self.n_users = dataset.user_num
self.n_items = dataset.item_num
self.user_num = dataset.user_num
self.item_num = dataset.item_num

super().__init__(distribution=distribution)

def _uni_sampling(self, sample_num):
return np.random.randint(1, self.item_num, sample_num)

def get_used_ids(self):
    """Return nothing: this sampler builds no used-id table.

    Returns:
        None
    """
    return None

def get_random_list(self):
"""
Returns:
numpy.ndarray or list: Random list of item_id.
"""
return np.arange(1, self.n_items)

def sample_neg_sequence(self, pos_sequence):
"""For each moment, sampling one item from all the items except the one the user clicked on at that moment.

Expand All @@ -464,7 +481,7 @@ def sample_neg_sequence(self, pos_sequence):
value_ids = np.zeros(total_num, dtype=np.int64)
check_list = np.arange(total_num)
while len(check_list) > 0:
value_ids[check_list] = self.random_num(len(check_list))
value_ids[check_list] = self.sampling(len(check_list))
check_index = np.where(value_ids[check_list] == pos_sequence[check_list])
check_list = check_list[check_index]

Expand Down