Skip to content

Commit

Permalink
Merge pull request #877 from hyp1231/data
Browse files Browse the repository at this point in the history
FEA: RS, benchmark for seq-rec & FIX: update argument_list
  • Loading branch information
chenyushuo authored Jul 12, 2021
2 parents 1090915 + a78d7bb commit 0ec2332
Show file tree
Hide file tree
Showing 13 changed files with 122 additions and 32 deletions.
4 changes: 2 additions & 2 deletions recbole/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# @Email : houyupeng@ruc.edu.cn

# UPDATE:
# @Time : 2020/10/28 2021/7/1, 2020/11/10
# @Time : 2021/7/11 2021/7/1, 2020/11/10
# @Author : Yupeng Hou, Xingyu Pan, Yushuo Chen
# @Email : houyupeng@ruc.edu.cn, xy_pan@foxmail.com, chenyushuo@ruc.edu.cn

Expand Down Expand Up @@ -942,7 +942,7 @@ def _concat_remaped_tokens(self, remap_list):
if ftype == FeatureType.TOKEN:
tokens.append(feat[field].values)
elif ftype == FeatureType.TOKEN_SEQ:
tokens.append(feat[field].agg(np.concatenate))
tokens.append(feat[field].reset_index(drop=True).agg(np.concatenate))
split_point = np.cumsum(list(map(len, tokens)))[:-1]
tokens = np.concatenate(tokens)
return tokens, split_point
Expand Down
13 changes: 12 additions & 1 deletion recbole/data/dataset/sequential_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# @Email : chenyushuo@ruc.edu.cn

# UPDATE:
# @Time : 2020/9/16, 2021/7/1, 2021/7/9
# @Time : 2020/9/16, 2021/7/1, 2021/7/11
# @Author : Yushuo Chen, Xingyu Pan, Yupeng Hou
# @Email : chenyushuo@ruc.edu.cn, xy_pan@foxmail.com, houyupeng@ruc.edu.cn

Expand Down Expand Up @@ -34,6 +34,8 @@ def __init__(self, config):
self.max_item_list_len = config['MAX_ITEM_LIST_LENGTH']
self.item_list_length_field = config['ITEM_LIST_LENGTH_FIELD']
super().__init__(config)
if config['benchmark_filename'] is not None:
self._benchmark_presets()

def _change_feat_format(self):
"""Change feat format from :class:`pandas.DataFrame` to :class:`Interaction`,
Expand Down Expand Up @@ -134,6 +136,15 @@ def data_augmentation(self):
new_data.update(Interaction(new_dict))
self.inter_feat = new_data

def _benchmark_presets(self):
list_suffix = self.config['LIST_SUFFIX']
for field in self.inter_feat:
if field + list_suffix in self.inter_feat:
list_field = field + list_suffix
setattr(self, f'{field}_list_field', list_field)
self.set_field_property(self.item_list_length_field, FeatureType.TOKEN, FeatureSource.INTERACTION, 1)
self.inter_feat[self.item_list_length_field] = self.inter_feat[self.item_id_list_field].agg(len)

def inter_matrix(self, form='coo', value_field=None):
"""Get sparse matrix that describe interactions between user_id and item_id.
Sparse matrix has shape (user_num, item_num).
Expand Down
6 changes: 1 addition & 5 deletions recbole/properties/dataset/sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ additional_feat_suffix: ~
rm_dup_inter: ~
val_interval: ~
filter_inter_by_user_or_item: True
item_inter_num_interval: "[0,inf)"
user_inter_num_interval: "[0,inf)"
item_inter_num_interval: "[0,inf)"

# Preprocessing
alias_of_user_id: ~
Expand All @@ -52,9 +52,5 @@ TAIL_ENTITY_ID_FIELD: tail_id
RELATION_ID_FIELD: relation_id
ENTITY_ID_FIELD: entity_id

# Social Model Needed
SOURCE_ID_FIELD: source_id
TARGET_ID_FIELD: target_id

# Benchmark .inter
benchmark_filename: ~
1 change: 0 additions & 1 deletion recbole/properties/overall.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ eval_args:
group_by: user
order: RO
mode: full
real_time_process: False
metrics: ["Recall","MRR","NDCG","Hit","Precision"]
topk: [10]
valid_metric: MRR@10
Expand Down
2 changes: 1 addition & 1 deletion recbole/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def get_used_ids(self):
raise ValueError(
'Some users have interacted with all items, '
'which we can not sample negative items for them. '
'Please set `max_user_inter_num` to filter those users.'
'Please set `user_inter_num_interval` to filter those users.'
)
return used_item_id

Expand Down
13 changes: 6 additions & 7 deletions recbole/utils/argument_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
'reproducibility',
'state',
'data_path',
'benchmark_filename',
'show_progress',
'config_file'
]

training_arguments = [
Expand All @@ -27,11 +29,8 @@
]

evaluation_arguments = [
'eval_setting',
'group_by_user',
'split_ratio', 'leave_one_num',
'real_time_process',
'metrics', 'topk', 'valid_metric',
'eval_args',
'metrics', 'topk', 'valid_metric', 'valid_metric_bigger',
'eval_batch_size',
'metric_decimal_place'
]
Expand All @@ -45,8 +44,8 @@
'ITEM_LIST_LENGTH_FIELD', 'LIST_SUFFIX', 'MAX_ITEM_LIST_LENGTH', 'POSITION_FIELD',
'HEAD_ENTITY_ID_FIELD', 'TAIL_ENTITY_ID_FIELD', 'RELATION_ID_FIELD', 'ENTITY_ID_FIELD',
'load_col', 'unload_col', 'unused_col', 'additional_feat_suffix',
'user_inter_num_interval', 'item_inter_num_interval ',
'val_interval',
'filter_inter_by_user_or_item', 'rm_dup_inter',
'val_interval', 'user_inter_num_interval', 'item_inter_num_interval',
'alias_of_user_id', 'alias_of_item_id', 'alias_of_entity_id', 'alias_of_relation_id',
'preload_weight',
'normalize_field', 'normalize_all'
Expand Down
1 change: 0 additions & 1 deletion tests/config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def test_default_settings(self):
self.assertIsInstance(config['checkpoint_dir'], str)

self.assertIsInstance(config['eval_args'], dict)
self.assertIsInstance(config['real_time_process'], bool)
self.assertIsInstance(config['metrics'], list)
self.assertIsInstance(config['topk'], list)
self.assertIsInstance(config['valid_metric'], str)
Expand Down
3 changes: 0 additions & 3 deletions tests/config/test_overall.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,6 @@ def test_checkpoint_dir(self):
def test_eval_batch_size(self):
self.assertTrue(run_parms({'eval_batch_size': [1, 100]}))

def test_real_time_process(self):
self.assertTrue(run_parms({'real_time_process':[False, True]}))

def test_topk(self):
settings = {
'metrics': ["Recall", "MRR", "NDCG", "Hit", "Precision"],
Expand Down
5 changes: 5 additions & 0 deletions tests/data/seq_benchmark/seq_benchmark.test.inter
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
user_id:token item_id:token timestamp:float item_id_list:token_seq timestamp_list:float_seq
1 7 7 1 2 3 4 5 6 1 2 3 4 5 6
1 8 8 1 2 3 4 5 6 7 1 2 3 4 5 6 7
2 8 8 4 5 6 7 4 5 6 7
3 6 6 4 5 4 5
8 changes: 8 additions & 0 deletions tests/data/seq_benchmark/seq_benchmark.train.inter
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
user_id:token item_id:token timestamp:float item_id_list:token_seq timestamp_list:float_seq
1 2 2 1 1
1 3 3 1 2 1 2
1 4 4 1 2 3 1 2 3
4 4 4 3 3
2 5 5 4 4
2 6 6 4 5 4 5
3 5 5 4 4
4 changes: 4 additions & 0 deletions tests/data/seq_benchmark/seq_benchmark.valid.inter
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
user_id:token item_id:token timestamp:float item_id_list:token_seq timestamp_list:float_seq
1 5 5 1 2 3 4 1 2 3 4
1 6 6 1 2 3 4 5 1 2 3 4 5
2 7 7 4 5 6 4 5 6
84 changes: 83 additions & 1 deletion tests/data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# @Email : chenyushuo@ruc.edu.cn

# UPDATE
# @Time : 2020/1/3, 2021/7/1, 2021/7/9
# @Time : 2020/1/3, 2021/7/1, 2021/7/11
# @Author : Yushuo Chen, Xingyu Pan, Yupeng Hou
# @email : chenyushuo@ruc.edu.cn, xy_pan@foxmail.com, houyupeng@ruc.edu.cn

Expand Down Expand Up @@ -612,6 +612,88 @@ def test_seq_leave_one_out(self):
[0., 0., 0., 0., 0., 0., 0., 0., 0.]
]).all()

def test_seq_split_by_ratio(self):
config_dict = {
'model': 'GRU4Rec',
'dataset': 'seq_dataset',
'data_path': current_path,
'load_col': None,
'training_neg_sample_num': 0,
'eval_args': {
'split': {'RS': [0.3, 0.3, 0.4]},
'order': 'TO'
}
}
train_dataset, valid_dataset, test_dataset = split_dataset(config_dict=config_dict)
assert (train_dataset.inter_feat[train_dataset.uid_field].numpy() == [1, 1, 1, 4, 2, 2, 3]).all()
assert (train_dataset.inter_feat[train_dataset.item_id_list_field][:,:3].numpy() == [
[1, 0, 0],
[1, 2, 0],
[1, 2, 3],
[3, 0, 0],
[4, 0, 0],
[4, 5, 0],
[4, 0, 0]]).all()
assert (train_dataset.inter_feat[train_dataset.iid_field].numpy() == [2, 3, 4, 4, 5, 6, 5]).all()
assert (train_dataset.inter_feat[train_dataset.item_list_length_field].numpy() == [1, 2, 3, 1, 1, 2, 1]).all()

assert (valid_dataset.inter_feat[valid_dataset.uid_field].numpy() == [1, 1, 2]).all()
assert (valid_dataset.inter_feat[valid_dataset.item_id_list_field][:,:5].numpy() == [
[1, 2, 3, 4, 0],
[1, 2, 3, 4, 5],
[4, 5, 6, 0, 0]]).all()
assert (valid_dataset.inter_feat[valid_dataset.iid_field].numpy() == [5, 6, 7]).all()
assert (valid_dataset.inter_feat[valid_dataset.item_list_length_field].numpy() == [4, 5, 3]).all()

assert (test_dataset.inter_feat[test_dataset.uid_field].numpy() == [1, 1, 2, 3]).all()
assert (test_dataset.inter_feat[test_dataset.item_id_list_field][:,:7].numpy() == [
[1, 2, 3, 4, 5, 6, 0],
[1, 2, 3, 4, 5, 6, 7],
[4, 5, 6, 7, 0, 0, 0],
[4, 5, 0, 0, 0, 0, 0]]).all()
assert (test_dataset.inter_feat[test_dataset.iid_field].numpy() == [7, 8, 8, 6]).all()
assert (test_dataset.inter_feat[test_dataset.item_list_length_field].numpy() == [6, 7, 4, 2]).all()

def test_seq_benchmark(self):
config_dict = {
'model': 'GRU4Rec',
'dataset': 'seq_benchmark',
'data_path': current_path,
'load_col': None,
'training_neg_sample_num': 0,
'benchmark_filename': ['train', 'valid', 'test'],
'alias_of_item_id': ['item_id_list']
}
train_dataset, valid_dataset, test_dataset = split_dataset(config_dict=config_dict)
assert (train_dataset.inter_feat[train_dataset.uid_field].numpy() == [1, 1, 1, 2, 3, 3, 4]).all()
assert (train_dataset.inter_feat[train_dataset.item_id_list_field][:,:3].numpy() == [
[8, 0, 0],
[8, 1, 0],
[8, 1, 2],
[2, 0, 0],
[3, 0, 0],
[3, 4, 0],
[3, 0, 0]]).all()
assert (train_dataset.inter_feat[train_dataset.iid_field].numpy() == [1, 2, 3, 3, 4, 5, 4]).all()
assert (train_dataset.inter_feat[train_dataset.item_list_length_field].numpy() == [1, 2, 3, 1, 1, 2, 1]).all()

assert (valid_dataset.inter_feat[valid_dataset.uid_field].numpy() == [1, 1, 3]).all()
assert (valid_dataset.inter_feat[valid_dataset.item_id_list_field][:,:5].numpy() == [
[8, 1, 2, 3, 0],
[8, 1, 2, 3, 4],
[3, 4, 5, 0, 0]]).all()
assert (valid_dataset.inter_feat[valid_dataset.iid_field].numpy() == [4, 5, 6]).all()
assert (valid_dataset.inter_feat[valid_dataset.item_list_length_field].numpy() == [4, 5, 3]).all()

assert (test_dataset.inter_feat[test_dataset.uid_field].numpy() == [1, 1, 3, 4]).all()
assert (test_dataset.inter_feat[test_dataset.item_id_list_field][:,:7].numpy() == [
[8, 1, 2, 3, 4, 5, 0],
[8, 1, 2, 3, 4, 5, 6],
[3, 4, 5, 6, 0, 0, 0],
[3, 4, 0, 0, 0, 0, 0]]).all()
assert (test_dataset.inter_feat[test_dataset.iid_field].numpy() == [6, 7, 7, 5]).all()
assert (test_dataset.inter_feat[test_dataset.item_list_length_field].numpy() == [6, 7, 4, 2]).all()


class TestKGDataset:
def test_kg_remap_id(self):
Expand Down
10 changes: 0 additions & 10 deletions tests/model/test_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,6 @@ load_col:

unload_col: ~

# Filtering
max_user_inter_num: ~
min_user_inter_num: ~
max_item_inter_num: ~
min_item_inter_num: ~
lowest_val: ~
highest_val: ~
equal_val: ~
not_equal_val: ~

# Preprocessing
alias_of_user_id: ~
alias_of_item_id: ~
Expand Down

0 comments on commit 0ec2332

Please sign in to comment.