Skip to content

Commit

Permalink
Merge pull request #20 from RUCAIBox/master
Browse files Browse the repository at this point in the history
update
  • Loading branch information
2017pxy authored Sep 21, 2020
2 parents 263cc95 + c673fdc commit 7481f07
Show file tree
Hide file tree
Showing 15 changed files with 438 additions and 104 deletions.
41 changes: 20 additions & 21 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
from logging import getLogger
from recbox.config import Config
from recbox.data import create_dataset, data_preparation
from recbox.trainer import get_trainer
from recbox.utils import init_logger, get_model
import json
from argparse import ArgumentParser
from run_test import whole_process

config = Config('properties/overall.config')
config.init()
init_logger(config)
logger = getLogger()
parser = ArgumentParser()
parser.add_argument('--model', '-m', type=str, default='BPRMF', help='name of models')
parser.add_argument('--dataset', '-d', type=str, default='ml-100k', help='name of datasets')
parser.add_argument('--epochs', '-e', type=int, default=1, help='num of running epochs')

dataset = create_dataset(config)
logger.info(dataset)
args = parser.parse_args()

# If you want to customize the evaluation setting,
# please refer to `data_preparation()` in `data/utils.py`.
train_data, test_data, valid_data = data_preparation(config, dataset)
args_dict = {
'model': args.model,
'dataset': args.dataset,
'epochs': args.epochs
}

model = get_model(config['model'])(config, train_data).to(config['device'])
logger.info(model)
with open('presets.json', 'r', encoding='utf-8') as preset_file:
presets_dict = json.load(preset_file)

trainer = get_trainer(config['MODEL_TYPE'])(config, model)
token = '-'.join([args.model, args.dataset])
if token in presets_dict:
print('Hit preset: [{}]'.format(token))
args_dict.update(presets_dict)

# trainer.resume_checkpoint('saved/model_best.pth')
best_valid_score, _ = trainer.fit(train_data, valid_data)
result = trainer.evaluate(test_data)
logger.info(best_valid_score)
whole_process(config_dict=args_dict)
12 changes: 12 additions & 0 deletions presets.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"BPRMF-ml-1m": {
"train_batch_size": 4096,
"metrics": ["Recall", "MRR", "NDCG", "Hit", "Precision"]
},
"BPRMF-yelp": {
"train_batch_size": 4096,
"metrics": ["Recall", "MRR", "NDCG", "Hit", "Precision"],
"min_user_inter_num": 25,
"max_user_inter_num": 25
}
}
25 changes: 25 additions & 0 deletions properties/dataset/lfm-1b2013.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
[data]

########define the UIRT columns
USER_ID_FIELD='user_id'
ITEM_ID_FIELD='item_id'
NEG_PREFIX='neg_'
LABEL_FIELD='label'
RATING_FIELD='rating'
HEAD_ENTITY_ID_FIELD='head_id'
TAIL_ENTITY_ID_FIELD='tail_id'
RELATION_ID_FIELD='relation_id'
ENTITY_ID_FIELD='entity_id'

#########select load columns
load_col={'inter': ['user_id', 'item_id', 'rating'], 'kg': ['head_id', 'relation_id', 'tail_id'], 'link': ['item_id', 'entity_id']}

########data separator
field_separator='\t'
seq_separator=' '

########data filter
lowest_val={'rating':10}
drop_filter_field=True
min_user_inter_num=10
min_item_inter_num=10
8 changes: 8 additions & 0 deletions properties/model/KGAT.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[model]

embedding_size=64
kg_embedding_size=64
layers=[64]
mess_dropout=[0.1]
reg_weight=1e-5
kg_weight=0.5
11 changes: 8 additions & 3 deletions recbox/config/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import random
import sys
from logging import getLogger

import numpy as np
import torch
Expand Down Expand Up @@ -91,16 +92,20 @@ def init(self):
torch.backends.cudnn.deterministic = True

def _read_cmd_line(self):

unrecognized_args = []
if "ipykernel_launcher" not in sys.argv[0]:
for arg in sys.argv[1:]:
if not arg.startswith("--"):
raise SyntaxError("Commend arg must start with '--', but '%s' is not!" % arg)
if not arg.startswith("--") or len(arg[2:].split("=")) != 2:
unrecognized_args.append(arg)
continue
cmd_arg_name, cmd_arg_value = arg[2:].split("=")
if cmd_arg_name in self.cmd_args_dict and cmd_arg_value != self.cmd_args_dict[cmd_arg_name]:
raise SyntaxError("There are duplicate commend arg '%s' with different value!" % arg)
else:
self.cmd_args_dict[cmd_arg_name] = cmd_arg_value
if len(unrecognized_args) > 0:
logger = getLogger()
logger.warning('command line args [{}] will not be used in RecBox'.format(' '.join(unrecognized_args)))

def _read_config_dict(self):
for dict_arg_name in self.config_dict:
Expand Down
61 changes: 24 additions & 37 deletions recbox/data/dataloader/knowledge_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,42 +3,38 @@
# @Email : houyupeng@ruc.edu.cn

# UPDATE
# @Time : 2020/9/18, 2020/9/17, 2020/8/31
# @Time : 2020/9/18, 2020/9/21, 2020/8/31
# @Author : Yupeng Hou, Yushuo Chen, Kaiyuan Li
# @email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, tsotfsk@outlook.com


from recbox.data.dataloader.abstract_dataloader import AbstractDataLoader
from recbox.data.dataloader.general_dataloader import GeneralNegSampleDataLoader
from recbox.data.dataloader.neg_sample_mixin import NegSampleMixin
from recbox.data.dataloader import AbstractDataLoader, GeneralNegSampleDataLoader
from recbox.utils import InputType, KGDataLoaderState


class KGDataLoader(NegSampleMixin, AbstractDataLoader):
class KGDataLoader(AbstractDataLoader):

def __init__(self, config, dataset, sampler, neg_sample_args,
def __init__(self, config, dataset, sampler,
batch_size=1, dl_format=InputType.PAIRWISE, shuffle=False):
if neg_sample_args['strategy'] != 'by':
raise ValueError('neg_sample strategy in KnowledgeBasedDataLoader() should be `by`')
if dl_format != InputType.PAIRWISE or neg_sample_args['by'] != 1:
raise ValueError('kg based dataloader must be pairwise and can only neg sample by 1')
if shuffle is False:
raise ValueError('kg based dataloader must shuffle the data')
self.sampler = sampler
self.neg_sample_num = 1

self.batch_size = batch_size
self.neg_sample_by = neg_sample_args['by']
self.times = 1

neg_prefix = config['NEG_PREFIX']
tid_field = config['TAIL_ENTITY_ID_FIELD']
self.neg_prefix = config['NEG_PREFIX']
self.hid_field = dataset.head_entity_field
self.tid_field = dataset.tail_entity_field

# kg negative cols
neg_kg_col = neg_prefix + tid_field
dataset.copy_field_property(neg_kg_col, tid_field)
self.neg_tid_field = self.neg_prefix + self.tid_field
dataset.copy_field_property(self.neg_tid_field, self.tid_field)

super().__init__(config, dataset, sampler, neg_sample_args,
super().__init__(config, dataset,
batch_size=batch_size, dl_format=dl_format, shuffle=shuffle)

def setup(self):
if self.shuffle is False:
self.shuffle = True
self.logger.warning('kg based dataloader must shuffle the data')

@property
def pr_end(self):
# TODO 这个地方应该是取kg_data的len
Expand All @@ -61,21 +57,11 @@ def data_preprocess(self):
self.dataset.kg_feat = self._neg_sampling(self.dataset.kg_feat)

def _neg_sampling(self, kg_feat):
hid_field = self.config['HEAD_ENTITY_ID_FIELD']
tid_field = self.config['TAIL_ENTITY_ID_FIELD']
hids = kg_feat[hid_field].to_list()
neg_tids = self.sampler.sample_by_entity_ids(hids, self.neg_sample_by)
return self._neg_sample_by_pair_wise_sampling(tid_field, neg_tids, kg_feat)

def _neg_sample_by_pair_wise_sampling(self, tid_field, neg_tids, kg_feat):
neg_prefix = self.config['NEG_PREFIX']
neg_tail_entity_id = neg_prefix + tid_field
kg_feat.insert(len(kg_feat.columns), neg_tail_entity_id, neg_tids)
hids = kg_feat[self.hid_field].to_list()
neg_tids = self.sampler.sample_by_entity_ids(hids, self.neg_sample_num)
kg_feat.insert(len(kg_feat.columns), self.neg_tid_field, neg_tids)
return kg_feat

def _batch_size_adaptation(self):
self.step = self.batch_size


class KnowledgeBasedDataLoader(AbstractDataLoader):

Expand All @@ -89,7 +75,7 @@ def __init__(self, config, dataset, sampler, kg_sampler, neg_sample_args,
shuffle=shuffle)

# using kg_sampler and dl_format is pairwise
self.kg_dataloader = KGDataLoader(config, dataset, kg_sampler, neg_sample_args,
self.kg_dataloader = KGDataLoader(config, dataset, kg_sampler,
batch_size=batch_size, dl_format=InputType.PAIRWISE, shuffle=shuffle)

self.main_dataloader = self.general_dataloader
Expand Down Expand Up @@ -120,10 +106,11 @@ def _shuffle(self):

def __next__(self):
if self.pr >= self.pr_end:
self.pr = 0
# After the rec data ends, the kg data pointer needs to be cleared to zero
if self.state == KGDataLoaderState.RSKG:
self.general_dataloader.pr = 0
self.kg_dataloader.pr = 0
else:
self.pr = 0
raise StopIteration()
return self._next_batch_data()

Expand Down
4 changes: 2 additions & 2 deletions recbox/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# @Email : houyupeng@ruc.edu.cn

# UPDATE:
# @Time : 2020/9/15, 2020/9/15, 2020/9/17
# @Time : 2020/9/17, 2020/9/15, 2020/9/17
# @Author : Yupeng Hou, Xingyu Pan, Yushuo Chen
# @Email : houyupeng@ruc.edu.cn, panxy@ruc.edu.cn, chenyushuo@ruc.edu.cn

Expand Down Expand Up @@ -673,7 +673,7 @@ def __repr__(self):
return self.__str__()

def __str__(self):
info = []
info = [self.dataset_name]
if self.uid_field:
info.extend(['The number of users: {}'.format(self.user_num),
'Average actions of users: {}'.format(self.avg_actions_of_users)])
Expand Down
2 changes: 1 addition & 1 deletion recbox/model/general_recommender/ngcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def get_norm_adj_mat(self):
col = L.col
i = torch.LongTensor([row, col])
data = torch.FloatTensor(L.data)
SparseL = torch.sparse.FloatTensor(i, data)
SparseL = torch.sparse.FloatTensor(i, data, torch.Size(L.shape))
return SparseL

def get_eye_mat(self):
Expand Down
Loading

0 comments on commit 7481f07

Please sign in to comment.