Skip to content

Commit

Permalink
Merge pull request #885 from hyp1231/data
Browse files Browse the repository at this point in the history
FEA: example for running session-based rec benchmarks
  • Loading branch information
hyp1231 authored Jul 14, 2021
2 parents 99bb56d + b288640 commit ea6a175
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 6 deletions.
7 changes: 3 additions & 4 deletions recbole/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# @Email : houyupeng@ruc.edu.cn

# UPDATE:
# @Time : 2021/7/13 2021/7/1, 2020/11/10
# @Time : 2021/7/14 2021/7/1, 2020/11/10
# @Author : Yupeng Hou, Xingyu Pan, Yushuo Chen
# @Email : houyupeng@ruc.edu.cn, xy_pan@foxmail.com, chenyushuo@ruc.edu.cn

Expand Down Expand Up @@ -240,6 +240,8 @@ def _load_data(self, token, dataset_path):
token (str): dataset name.
dataset_path (str): path of dataset dir.
"""
if not os.path.exists(dataset_path):
self._download()
self._load_inter_feat(token, dataset_path)
self.user_feat = self._load_user_or_item_feat(token, dataset_path, FeatureSource.USER, 'uid_field')
self.item_feat = self._load_user_or_item_feat(token, dataset_path, FeatureSource.ITEM, 'iid_field')
Expand All @@ -259,9 +261,6 @@ def _load_inter_feat(self, token, dataset_path):
dataset_path (str): path of dataset dir.
"""
if self.benchmark_filename_list is None:
if not os.path.exists(dataset_path):
self._download()

inter_feat_path = os.path.join(dataset_path, f'{token}.inter')
if not os.path.isfile(inter_feat_path):
raise ValueError(f'File {inter_feat_path} not exist.')
Expand Down
5 changes: 5 additions & 0 deletions recbole/properties/dataset/url.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,8 @@ yoochoose-buys-merged: https://recbole.s3-accelerate.amazonaws.com/ProcessedData
yoochoose-clicks-merged: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/YOOCHOOSE/merged/yoochoose-clicks.zip
yoochoose-buys-not-merged: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/YOOCHOOSE/not_merged/yoochoose-buys.zip
yoochoose-clicks-not-merged: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/YOOCHOOSE/not_merged/yoochoose-clicks.zip

# session-based recommendation benchmarks
diginetica-session: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/DIGINETICA/session/diginetica_session.zip
tmall-session: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/Tmall/session/tmall_session.zip
nowplaying-session: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/Nowplaying/session/nowplaying_session.zip
6 changes: 4 additions & 2 deletions recbole/utils/url.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,12 @@ def rename_atomic_files(folder, old_name, new_name):
files = os.listdir(folder)
for f in files:
base, suf = os.path.splitext(f)
if base != old_name:
if not old_name in base:
continue
assert suf in {'.inter', '.user', '.item'}
os.rename(os.path.join(folder, f), os.path.join(folder, new_name + suf))
os.rename(
os.path.join(folder, f),
os.path.join(folder, base.replace(old_name, new_name) + suf))

if __name__ == '__main__':
pass
83 changes: 83 additions & 0 deletions run_example/session_based_rec_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# @Time : 2021/7/14
# @Author : Yupeng Hou
# @Email : houyupeng@ruc.edu.cn

"""
session-based recommendation example
========================
Here is the sample code for running session-based recommendation benchmarks using RecBole.
args.dataset can be one of diginetica-session/tmall-session/nowplaying-session
"""

import argparse
from logging import getLogger

from recbole.config import Config
from recbole.data import create_dataset
from recbole.data.utils import get_dataloader
from recbole.utils import init_logger, init_seed, get_model, get_trainer, set_color


def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model', '-m', type=str, default='GRU4Rec', help='Model for session-based rec.')
parser.add_argument('--dataset', '-d', type=str, default='diginetica-session', help='Benchmarks for session-based rec.')
parser.add_argument('--validation', action='store_true', help='Whether evaluating on validation set (split from train set), otherwise on test set.')
parser.add_argument('--valid_portion', type=float, default=0.1, help='ratio of validation set.')
return parser.parse_args()


if __name__ == '__main__':
args = get_args()

# configurations initialization
config_dict = {
'USER_ID_FIELD': 'session_id',
'load_col': None,
'training_neg_sample_num': 0,
'benchmark_filename': ['train', 'test'],
'alias_of_item_id': ['item_id_list'],
'topk': [20],
'metrics': ['Recall', 'MRR'],
'valid_metric': 'MRR@20'
}

config = Config(model=args.model, dataset=f'{args.dataset}', config_dict=config_dict)
init_seed(config['seed'], config['reproducibility'])

# logger initialization
init_logger(config)
logger = getLogger()

logger.info(args)
logger.info(config)

# dataset filtering
dataset = create_dataset(config)
logger.info(dataset)

# dataset splitting
train_dataset, test_dataset = dataset.build()
if args.validation:
train_dataset.shuffle()
new_train_dataset, new_test_dataset = train_dataset.split_by_ratio([1 - args.valid_portion, args.valid_portion])
train_data = get_dataloader(config, 'train')(config, new_train_dataset, None, shuffle=True)
test_data = get_dataloader(config, 'test')(config, new_test_dataset, None, shuffle=False)
else:
train_data = get_dataloader(config, 'train')(config, train_dataset, None, shuffle=True)
test_data = get_dataloader(config, 'test')(config, test_dataset, None, shuffle=False)

# model loading and initialization
model = get_model(config['model'])(config, train_data.dataset).to(config['device'])
logger.info(model)

# trainer loading and initialization
trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model)

# model training and evaluation
test_score, test_result = trainer.fit(
train_data, test_data, saved=True, show_progress=config['show_progress']
)

logger.info(set_color('test result', 'yellow') + f': {test_result}')

0 comments on commit ea6a175

Please sign in to comment.