Skip to content

Commit

Permalink
FEA: example for session-based rec benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
hyp1231 committed Jul 14, 2021
1 parent e9be77e commit b288640
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
5 changes: 5 additions & 0 deletions recbole/properties/dataset/url.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,8 @@ yoochoose-buys-merged: https://recbole.s3-accelerate.amazonaws.com/ProcessedData
yoochoose-clicks-merged: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/YOOCHOOSE/merged/yoochoose-clicks.zip
yoochoose-buys-not-merged: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/YOOCHOOSE/not_merged/yoochoose-buys.zip
yoochoose-clicks-not-merged: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/YOOCHOOSE/not_merged/yoochoose-clicks.zip

# session-based recommendation benchmarks
diginetica-session: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/DIGINETICA/session/diginetica_session.zip
tmall-session: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/Tmall/session/tmall_session.zip
nowplaying-session: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/Nowplaying/session/nowplaying_session.zip
83 changes: 83 additions & 0 deletions run_example/session_based_rec_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# @Time : 2021/7/14
# @Author : Yupeng Hou
# @Email : houyupeng@ruc.edu.cn

"""
Session-based recommendation example
====================================
Sample code for running session-based recommendation benchmarks with RecBole.
``args.dataset`` can be one of: diginetica-session, tmall-session, nowplaying-session.
"""

import argparse
from logging import getLogger

from recbole.config import Config
from recbole.data import create_dataset
from recbole.data.utils import get_dataloader
from recbole.utils import init_logger, init_seed, get_model, get_trainer, set_color


def _valid_portion(value):
    """Convert the ``--valid_portion`` argument and check that it is usable.

    Args:
        value: raw command-line string.

    Returns:
        The parsed float, guaranteed to lie strictly between 0 and 1.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a float in (0, 1).
    """
    try:
        portion = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f'invalid float value: {value!r}') from None
    # The script splits the train set by [1 - portion, portion]; anything outside
    # the open interval would yield an empty or negative-sized split.
    if not 0.0 < portion < 1.0:
        raise argparse.ArgumentTypeError(
            f'valid_portion must be strictly between 0 and 1, got {value!r}'
        )
    return portion


def get_args(argv=None):
    """Parse the command-line options of the session-based rec example.

    Args:
        argv: optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with the attributes ``model``, ``dataset``,
        ``validation`` and ``valid_portion``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', '-m', type=str, default='GRU4Rec',
        help='Model for session-based rec.',
    )
    parser.add_argument(
        '--dataset', '-d', type=str, default='diginetica-session',
        help='Benchmarks for session-based rec.',
    )
    parser.add_argument(
        '--validation', action='store_true',
        help='Whether evaluating on validation set (split from train set), otherwise on test set.',
    )
    parser.add_argument(
        '--valid_portion', type=_valid_portion, default=0.1,
        help='ratio of validation set.',
    )
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = get_args()

    # Settings passed to RecBole's Config on top of the model/dataset defaults.
    benchmark_settings = {
        'USER_ID_FIELD': 'session_id',
        'load_col': None,
        'training_neg_sample_num': 0,
        'benchmark_filename': ['train', 'test'],
        'alias_of_item_id': ['item_id_list'],
        'topk': [20],
        'metrics': ['Recall', 'MRR'],
        'valid_metric': 'MRR@20'
    }
    config = Config(
        model=args.model,
        dataset=args.dataset,
        config_dict=benchmark_settings,
    )
    init_seed(config['seed'], config['reproducibility'])

    # Set up logging, then record both the CLI arguments and the final config.
    init_logger(config)
    logger = getLogger()
    for run_info in (args, config):
        logger.info(run_info)

    # Load the benchmark dataset and log it.
    dataset = create_dataset(config)
    logger.info(dataset)

    # build() yields the train and test parts of the benchmark.
    train_dataset, test_dataset = dataset.build()
    if args.validation:
        # Evaluate on a portion held out from the shuffled train set
        # instead of the benchmark's test set.
        train_dataset.shuffle()
        split_ratios = [1 - args.valid_portion, args.valid_portion]
        fit_dataset, eval_dataset = train_dataset.split_by_ratio(split_ratios)
    else:
        fit_dataset, eval_dataset = train_dataset, test_dataset

    # Train loader first, then eval loader, whichever split was selected.
    train_loader_cls = get_dataloader(config, 'train')
    train_data = train_loader_cls(config, fit_dataset, None, shuffle=True)
    test_loader_cls = get_dataloader(config, 'test')
    test_data = test_loader_cls(config, eval_dataset, None, shuffle=False)

    # Instantiate the model on the configured device.
    model_cls = get_model(config['model'])
    model = model_cls(config, train_data.dataset).to(config['device'])
    logger.info(model)

    # Instantiate the trainer matching the model type.
    trainer_cls = get_trainer(config['MODEL_TYPE'], config['model'])
    trainer = trainer_cls(config, model)

    # Train, using the evaluation loader selected above.
    test_score, test_result = trainer.fit(
        train_data,
        test_data,
        saved=True,
        show_progress=config['show_progress'],
    )

    result_label = set_color('test result', 'yellow')
    logger.info(result_label + f': {test_result}')

0 comments on commit b288640

Please sign in to comment.