From b288640f9473949fa0ccca833e996dbf0c33f7d5 Mon Sep 17 00:00:00 2001 From: Yupeng Hou Date: Wed, 14 Jul 2021 02:56:02 +0000 Subject: [PATCH] FEA: example for session-based rec benchmarks --- recbole/properties/dataset/url.yaml | 5 ++ run_example/session_based_rec_example.py | 83 ++++++++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 run_example/session_based_rec_example.py diff --git a/recbole/properties/dataset/url.yaml b/recbole/properties/dataset/url.yaml index bbc5ef938..5fc38333e 100644 --- a/recbole/properties/dataset/url.yaml +++ b/recbole/properties/dataset/url.yaml @@ -82,3 +82,8 @@ yoochoose-buys-merged: https://recbole.s3-accelerate.amazonaws.com/ProcessedData yoochoose-clicks-merged: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/YOOCHOOSE/merged/yoochoose-clicks.zip yoochoose-buys-not-merged: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/YOOCHOOSE/not_merged/yoochoose-buys.zip yoochoose-clicks-not-merged: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/YOOCHOOSE/not_merged/yoochoose-clicks.zip + +# session-based recommendation benchmarks +diginetica-session: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/DIGINETICA/session/diginetica_session.zip +tmall-session: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/Tmall/session/tmall_session.zip +nowplaying-session: https://recbole.s3-accelerate.amazonaws.com/ProcessedDatasets/Nowplaying/session/nowplaying_session.zip diff --git a/run_example/session_based_rec_example.py b/run_example/session_based_rec_example.py new file mode 100644 index 000000000..a0593e68b --- /dev/null +++ b/run_example/session_based_rec_example.py @@ -0,0 +1,83 @@ +# @Time : 2021/7/14 +# @Author : Yupeng Hou +# @Email : houyupeng@ruc.edu.cn + +""" +session-based recommendation example +======================== +Here is the sample code for running session-based recommendation benchmarks using RecBole. + +args.dataset can be one of diginetica-session/tmall-session/nowplaying-session +""" + +import argparse +from logging import getLogger + +from recbole.config import Config +from recbole.data import create_dataset +from recbole.data.utils import get_dataloader +from recbole.utils import init_logger, init_seed, get_model, get_trainer, set_color + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--model', '-m', type=str, default='GRU4Rec', help='Model for session-based rec.') + parser.add_argument('--dataset', '-d', type=str, default='diginetica-session', help='Benchmarks for session-based rec.') + parser.add_argument('--validation', action='store_true', help='Whether evaluating on validation set (split from train set), otherwise on test set.') + parser.add_argument('--valid_portion', type=float, default=0.1, help='ratio of validation set.') + return parser.parse_args() + + +if __name__ == '__main__': + args = get_args() + + # configurations initialization + config_dict = { + 'USER_ID_FIELD': 'session_id', + 'load_col': None, + 'training_neg_sample_num': 0, + 'benchmark_filename': ['train', 'test'], + 'alias_of_item_id': ['item_id_list'], + 'topk': [20], + 'metrics': ['Recall', 'MRR'], + 'valid_metric': 'MRR@20' + } + + config = Config(model=args.model, dataset=f'{args.dataset}', config_dict=config_dict) + init_seed(config['seed'], config['reproducibility']) + + # logger initialization + init_logger(config) + logger = getLogger() + + logger.info(args) + logger.info(config) + + # dataset filtering + dataset = create_dataset(config) + logger.info(dataset) + + # dataset splitting + train_dataset, test_dataset = dataset.build() + if args.validation: + train_dataset.shuffle() + new_train_dataset, new_test_dataset = train_dataset.split_by_ratio([1 - args.valid_portion, args.valid_portion]) + train_data = get_dataloader(config, 'train')(config, new_train_dataset, None, shuffle=True) + test_data = get_dataloader(config, 'test')(config, new_test_dataset, None, shuffle=False) + else: + train_data = get_dataloader(config, 'train')(config, train_dataset, None, shuffle=True) + test_data = get_dataloader(config, 'test')(config, test_dataset, None, shuffle=False) + + # model loading and initialization + model = get_model(config['model'])(config, train_data.dataset).to(config['device']) + logger.info(model) + + # trainer loading and initialization + trainer = get_trainer(config['MODEL_TYPE'], config['model'])(config, model) + + # model training and evaluation + test_score, test_result = trainer.fit( + train_data, test_data, saved=True, show_progress=config['show_progress'] + ) + + logger.info(set_color('test result', 'yellow') + f': {test_result}')