Commit

Merge branch 'master' of https://github.com/tsotfsk/RecBox
tsotfsk committed Jul 18, 2020
2 parents 1107fb9 + a7429b2 commit 6af995d
Showing 28 changed files with 2,835 additions and 144 deletions.
10 changes: 7 additions & 3 deletions config/configurator.py
@@ -46,11 +46,13 @@ def __init__(self, config_file_name):
         self.run_args = RunningConfig(config_file_name, self.cmd_args)
 
         model_name = self.run_args['model']
-        model_arg_file_name = os.path.join(os.path.dirname(config_file_name), model_name + '.config')
+        model_dir = os.path.join(os.path.dirname(config_file_name), 'model')
+        model_arg_file_name = os.path.join(model_dir, model_name + '.config')
         self.model_args = ModelConfig(model_arg_file_name, self.cmd_args)
 
         dataset_name = self.run_args['dataset']
-        dataset_arg_file_name = os.path.join(os.path.dirname(config_file_name), dataset_name + '.config')
+        dataset_dir = os.path.join(os.path.dirname(config_file_name), 'dataset')
+        dataset_arg_file_name = os.path.join(dataset_dir, dataset_name + '.config')
         self.dataset_args = DataConfig(dataset_arg_file_name, self.cmd_args)
 
         self.device = None
@@ -61,8 +63,10 @@ def init(self):
         """
         init_seed = self.run_args['seed']
         gpu_id = self.run_args['gpu_id']
+        use_gpu = self.run_args['use_gpu']
         os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Get the device that run on.
+        # Get the device that run on.
+        self.device = torch.device("cuda" if torch.cuda.is_available() and use_gpu else "cpu")
 
         random.seed(init_seed)
         np.random.seed(init_seed)
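For orientation (commentary, not part of the diff): after this change the model and dataset .config files are looked up in 'model' and 'dataset' subdirectories next to the running config file, and the GPU is only used when use_gpu is set. A minimal sketch of the path resolution, using hypothetical names such as config/run.config, BPRMF and ml-100k purely for illustration:

import os

def resolve_config_paths(config_file_name, model_name, dataset_name):
    # The running config anchors the lookup; model and dataset configs
    # now live in sibling 'model' and 'dataset' subdirectories.
    base_dir = os.path.dirname(config_file_name)
    model_arg_file_name = os.path.join(base_dir, 'model', model_name + '.config')
    dataset_arg_file_name = os.path.join(base_dir, 'dataset', dataset_name + '.config')
    return model_arg_file_name, dataset_arg_file_name

# resolve_config_paths('config/run.config', 'BPRMF', 'ml-100k')
# -> ('config/model/BPRMF.config', 'config/dataset/ml-100k.config') on POSIX paths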
32 changes: 14 additions & 18 deletions data/dataloader.py
@@ -17,21 +17,26 @@ def __init__(self, config, dataset, batch_size):
         self.dataset = dataset
         self.batch_size = batch_size
         self.sampler = Sampler(config, dataset)
+        self.pr = 0
 
     def __iter__(self):
         return self
 
     def __next__(self):
         raise NotImplementedError('Method [next] should be implemented.')
 
+    def set_batch_size(self, batch_size):
+        if self.pr != 0:
+            raise PermissionError('Cannot change dataloader\'s batch_size while iteration')
+        self.batch_size = batch_size
 
 class GeneralDataLoader(AbstractDataLoader):
     def __init__(self, config, dataset, batch_size, pairwise=False, shuffle=False, real_time_neg_sampling=True, neg_sample_to=None, neg_sample_by=None):
         super(GeneralDataLoader, self).__init__(config, dataset, batch_size)
 
         self.pairwise = pairwise
         self.shuffle = shuffle
         self.real_time_neg_sampling = real_time_neg_sampling
-        self.pr = 0
 
         self.neg_sample_to = neg_sample_to
         self.neg_sample_by = neg_sample_by
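A note on the new set_batch_size guard (commentary, not repository code): pr is the loader's read pointer, so it is non-zero while an epoch is in flight and set_batch_size refuses to change the batch size mid-iteration. A self-contained toy loader illustrating the same guard, with a plain Python list standing in for the dataset and sampler:

class ToyLoader:
    # Mirrors the pr-based guard: pr is the read pointer into the data.
    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size
        self.pr = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.pr >= len(self.data):
            self.pr = 0  # reset so the loader can be iterated again
            raise StopIteration
        batch = self.data[self.pr:self.pr + self.batch_size]
        self.pr += self.batch_size
        return batch

    def set_batch_size(self, batch_size):
        if self.pr != 0:
            raise PermissionError('Cannot change dataloader\'s batch_size while iteration')
        self.batch_size = batch_size

loader = ToyLoader(list(range(10)), batch_size=4)
loader.set_batch_size(5)      # fine: iteration has not started, pr == 0
first = next(loader)          # pr is now 5
# loader.set_batch_size(2)    # would raise PermissionError here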
@@ -126,28 +131,19 @@ def _pre_neg_sampling(self):
             iid_field: [],
             label_field: []
         }
 
-        uids = self.dataset.inter_feat[uid_field].to_list()
-        iids = self.dataset.inter_feat[iid_field].to_list()
         uid2itemlist = {}
-        for i in range(len(uids)):
-            uid = uids[i]
-            iid = iids[i]
-            if uid not in uid2itemlist:
-                uid2itemlist[uid] = []
-            uid2itemlist[uid].append(iid)
+        grouped_uid_iid = self.dataset.inter_feat.groupby(uid_field)[iid_field]
+        for uid, iids in grouped_uid_iid:
+            uid2itemlist[uid] = iids.to_list()
         for uid in uid2itemlist:
             pos_num = len(uid2itemlist[uid])
             if pos_num >= self.neg_sample_to:
                 uid2itemlist[uid] = uid2itemlist[uid][:self.neg_sample_to-1]
                 pos_num = self.neg_sample_to - 1
+            neg_num = self.neg_sample_to - pos_num
             neg_item_id = self.sampler.sample_by_user_id(uid, self.neg_sample_to - pos_num)
-            for iid in uid2itemlist[uid]:
-                new_inter[uid_field].append(uid)
-                new_inter[iid_field].append(iid)
-                new_inter[label_field].append(1)
-            for iid in neg_item_id:
-                new_inter[uid_field].append(uid)
-                new_inter[iid_field].append(iid)
-                new_inter[label_field].append(0)
+            new_inter[uid_field].extend([uid] * self.neg_sample_to)
+            new_inter[iid_field].extend(uid2itemlist[uid] + neg_item_id)
+            new_inter[label_field].extend([1] * pos_num + [0] * neg_num)
         self.dataset.inter_feat = pd.DataFrame(new_inter)
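The reshaped _pre_neg_sampling groups interactions by user and pads each user's item list up to neg_sample_to entries with sampled negatives, labelling positives 1 and negatives 0. A self-contained sketch of the same idea on a toy DataFrame, with a simple random sampler standing in for the repository's Sampler.sample_by_user_id (the column names below are illustrative defaults, not taken from the config):

import random
import pandas as pd

def pad_with_negatives(inter_feat, n_items, neg_sample_to,
                       uid_field='user_id', iid_field='item_id', label_field='label'):
    # Group positive items per user, like the groupby in the diff.
    uid2itemlist = {uid: iids.to_list()
                    for uid, iids in inter_feat.groupby(uid_field)[iid_field]}
    new_inter = {uid_field: [], iid_field: [], label_field: []}
    for uid, pos_items in uid2itemlist.items():
        pos_num = len(pos_items)
        if pos_num >= neg_sample_to:
            pos_items = pos_items[:neg_sample_to - 1]
            pos_num = neg_sample_to - 1
        neg_num = neg_sample_to - pos_num
        # Dummy sampler: draw random item ids the user has not interacted with.
        candidates = [i for i in range(n_items) if i not in set(pos_items)]
        neg_items = random.sample(candidates, neg_num)
        new_inter[uid_field].extend([uid] * neg_sample_to)
        new_inter[iid_field].extend(pos_items + neg_items)
        new_inter[label_field].extend([1] * pos_num + [0] * neg_num)
    return pd.DataFrame(new_inter)

toy = pd.DataFrame({'user_id': [1, 1, 2], 'item_id': [10, 11, 12]})
print(pad_with_negatives(toy, n_items=20, neg_sample_to=4))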
3 changes: 3 additions & 0 deletions data/dataset.py
@@ -180,6 +180,9 @@ def split_by_ratio(self, ratio):
     def shuffle(self):
         self.inter_feat = self.inter_feat.sample(frac=1).reset_index(drop=True)
 
+    def sort(self, field_name):
+        self.inter_feat = self.inter_feat.sort_values(by=field_name)
+
     # TODO
     def build(self,
               inter_filter_lowest_val=None, inter_filter_highest_val=None,
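The new Dataset.sort helper simply delegates to pandas sort_values. A typical use, sketched here with a hypothetical 'timestamp' column (the real field name comes from the dataset config), is ordering interactions chronologically before a temporal split:

import pandas as pd

inter_feat = pd.DataFrame({
    'user_id': [1, 1, 2],
    'item_id': [10, 11, 12],
    'timestamp': [300, 100, 200],
})
# Same operation as Dataset.sort: order the interaction table by an arbitrary field.
inter_feat = inter_feat.sort_values(by='timestamp')
print(inter_feat)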
File renamed without changes.
