
Refactor: refactor the dataloader #1

Merged: 78 commits, Jul 17, 2021
1c3d9f4
Merge pull request #39 from RUCAIBox/master
2017pxy Jun 16, 2021
71767d1
Merge branch 'RUCAIBox:data' into data
2017pxy Jun 23, 2021
2b75dcc
FEA: data filter by value and number of interactions
ChangxinTian Jun 24, 2021
ccd4d32
FEA: data filter by value and number of interactions
ChangxinTian Jun 24, 2021
cbbef35
FIX: update the parameters of data filter
ChangxinTian Jun 25, 2021
ddcb704
FIX: doc for parameters of data filter.
ChangxinTian Jun 25, 2021
29cb60b
FIX: code format and funcname.
ChangxinTian Jun 26, 2021
19de37f
Merge branch 'RUCAIBox:data' into data
2017pxy Jun 27, 2021
caa9186
FIX: fix MacridVAE for issue #859
2017pxy Jun 30, 2021
f64f631
FIX: fix doc and optimize code of data filter.
ChangxinTian Jun 30, 2021
964f045
FIX: code format.
ChangxinTian Jun 30, 2021
17fe101
Merge pull request #854 from ChangxinTian/data
chenyushuo Jun 30, 2021
7f4bc57
FIX: remove eval_settings
2017pxy Jul 1, 2021
7de7653
Update recbole/config/configurator.py
2017pxy Jul 5, 2021
fcfa07b
Update recbole/data/utils.py
2017pxy Jul 5, 2021
f5707b5
Update recbole/data/dataset/dataset.py
2017pxy Jul 5, 2021
2d9ee7e
Merge pull request #862 from 2017pxy/data
chenyushuo Jul 5, 2021
9b30340
REFACTOR: refactor `remap_all`
chenyushuo Jul 5, 2021
289061a
REFACTOR: add doc and unittest to `remap_all`
chenyushuo Jul 5, 2021
ae01bf8
Merge branch 'data' of github.com:RUCAIBox/RecBole into data
chenyushuo Jul 5, 2021
3502e87
Merge pull request #867 from guijiql/evaluator
linzihan-backforward Jul 6, 2021
3753309
Merge pull request #17 from RUCAIBox/evaluator
linzihan-backforward Jul 6, 2021
5fb7dbb
REFACTOR: add `alias_of_relation_id` and update doc.
chenyushuo Jul 6, 2021
1e75c37
Merge pull request #868 from chenyushuo/data
2017pxy Jul 6, 2021
7579725
Merge pull request #861 from 2017pxy/master
2017pxy Jul 7, 2021
6a30e7d
FEA: tensorboard.
ChangxinTian Jul 7, 2021
ae4e71d
Merge branch 'tmp' into data
ChangxinTian Jul 7, 2021
c8fdea1
FIX: remove the ANSI escape sequences from the log file.
ChangxinTian Jul 7, 2021
1d3ca36
FEA: update long-tail metric
linzihan-backforward Jul 8, 2021
246dcfa
FIX: bugs of datasets auto-downloading
hyp1231 Jul 8, 2021
0167ac8
FEA: add config 'augmentation'
hyp1231 Jul 8, 2021
deea3a4
Merge pull request #870 from hyp1231/data
chenyushuo Jul 8, 2021
4fc613b
REFACTOR: real augmentation for seq rec
hyp1231 Jul 8, 2021
563043c
FIX: bugs in negative sample seq dataloader
hyp1231 Jul 9, 2021
09e967a
FIX: comments for seq dataset
hyp1231 Jul 9, 2021
0fd6fa1
FIX: Delete outdated parameters in the parameter config.
ChangxinTian Jul 9, 2021
d097d4e
REFACTOR: DIEN & DIN's dataset & dataloader
hyp1231 Jul 9, 2021
9c2cbae
FEA: inter_matrix for seq dataset
hyp1231 Jul 9, 2021
894fd33
FIX: test for seq loo
hyp1231 Jul 9, 2021
dc4fc14
FIX: update meta-data for test_dataset.py
hyp1231 Jul 9, 2021
57bd042
FIX: double regularization warning.
ChangxinTian Jul 9, 2021
7e22666
FIX: add tensorboard dependency.
ChangxinTian Jul 9, 2021
829319b
FIX: typo in formula
linzihan-backforward Jul 9, 2021
27d5bfd
FIX: remove arg 'augmentation'
hyp1231 Jul 9, 2021
d00852d
Update configurator.py
hyp1231 Jul 9, 2021
5cda01a
FIX: typo
linzihan-backforward Jul 10, 2021
1cfa3b2
Merge pull request #869 from linzihan-backforward/evaluator
linzihan-backforward Jul 10, 2021
497327d
FIX: remove arg augmentation
hyp1231 Jul 10, 2021
5a14b58
Merge branch 'data' of github.com:hyp1231/RecBole into data
hyp1231 Jul 10, 2021
99c8d59
Update argument_list.py
hyp1231 Jul 10, 2021
36bffcf
Merge pull request #873 from hyp1231/data
2017pxy Jul 10, 2021
006b013
REFACTOR: refactor dataloaders
chenyushuo Jul 10, 2021
ed9bad3
Merge branch 'data' of github.com:RUCAIBox/RecBole into data
chenyushuo Jul 10, 2021
8cb73fe
REFACTOR: remove other dataloaders and fix repeatable sampler
chenyushuo Jul 10, 2021
c02ab8a
Merge pull request #876 from chenyushuo/data
hyp1231 Jul 11, 2021
939eb88
Merge pull request #874 from ChangxinTian/master
hyp1231 Jul 11, 2021
0a36bb1
FEA: add test for seq split by ratio
hyp1231 Jul 11, 2021
32c4315
FEA: add benchmark loading for seq dataset
hyp1231 Jul 11, 2021
2f449e7
FIX: remove plot_train_loss.
ChangxinTian Jul 11, 2021
186f58c
FIX: remove arg 'real_time_process'
hyp1231 Jul 11, 2021
4452f29
FIX: remove useless args & update arguments_list
hyp1231 Jul 11, 2021
39d1960
FIX: update meta-data
hyp1231 Jul 11, 2021
1090915
Merge pull request #875 from ChangxinTian/data
2017pxy Jul 11, 2021
237e6db
Merge remote-tracking branch 'origin/evaluator' into data
2017pxy Jul 11, 2021
5dc22d0
Merge branch 'data' into evaluator
2017pxy Jul 11, 2021
a78d7bb
Merge branch 'data' into data
hyp1231 Jul 11, 2021
e80de05
Merge branch 'evaluator' into evaluator
2017pxy Jul 11, 2021
c463c64
Merge pull request #40 from RUCAIBox/evaluator
2017pxy Jul 11, 2021
4971123
fix the bug in eval
2017pxy Jul 11, 2021
5835470
Update configurator.py
2017pxy Jul 11, 2021
c01c51f
Merge pull request #879 from 2017pxy/evaluator
2017pxy Jul 11, 2021
0ec2332
Merge pull request #877 from hyp1231/data
chenyushuo Jul 12, 2021
976dfb2
Merge pull request #881 from RUCAIBox/data
hyp1231 Jul 12, 2021
a3596a2
Rebuild: rebuild the dataloader
2017pxy Jul 15, 2021
3681b2c
Merge branch 'evaluator' of https://github.com/2017pxy/RecBole into e…
2017pxy Jul 15, 2021
9753e4e
Refactor: refactor the dataloader
2017pxy Jul 15, 2021
c2702f5
FIX: improve the efficiency of full positive_u
2017pxy Jul 15, 2021
ab70a16
FIX: improve the efficiency of full positive_u
2017pxy Jul 15, 2021
1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@
.idea/
*.pyc
*.log
log_tensorboard/*
saved/
*.lprof
*.egg-info/
4 changes: 2 additions & 2 deletions conda/meta.yaml
@@ -15,23 +15,23 @@ requirements:
- pandas >=1.0.5
- tqdm >=4.48.2
- pyyaml >=5.1.0
- matplotlib >=3.1.3
- scikit-learn >=0.23.2
- pytorch
- colorlog==4.7.2
- colorama==0.4.4
- tensorboard >=2.5.0
run:
- python
- numpy >=1.17.2
- scipy ==1.6.0
- pandas >=1.0.5
- tqdm >=4.48.2
- pyyaml >=5.1.0
- matplotlib >=3.1.3
- scikit-learn >=0.23.2
- pytorch
- colorlog==4.7.2
- colorama==0.4.4
- tensorboard >=2.5.0
test:
imports:
- recbole
2 changes: 1 addition & 1 deletion docs/source/user_guide/config_settings.rst
@@ -56,7 +56,7 @@ model training and evaluation.
which will clips gradient norm of model. Defaults to ``None``.
- ``loss_decimal_place(int)``: The decimal place of training loss. Defaults to ``4``.
- ``weight_decay (float)`` : Weight decay (L2 penalty), used for `optimizer <https://pytorch.org/docs/stable/optim.html?highlight=weight_decay>`_. Default to ``0.0``.
- ``draw_loss_pic (bool)``: Draw the training loss line graph of model if it's ``True``, the pic is a PDF file and will be saved in your run directory after model training. Default to ``False``.



**Evaluation Setting**
2 changes: 1 addition & 1 deletion docs/source/user_guide/data/atomic_files.rst
@@ -141,4 +141,4 @@ For example, if you want to map the tokens of ``ent_id`` into the same space of
# inter/user/item/...: As usual
ent: [ent_id, ent_emb]

fields_in_same_space: [[ent_id, entity_id]]
alias_of_entity_id: [ent_id]
16 changes: 7 additions & 9 deletions docs/source/user_guide/data/data_args.rst
@@ -72,10 +72,7 @@ Remove duplicated user-item interactions
Filter by value
''''''''''''''''''

- ``lowest_val (dict)`` : Has the format ``{k (str): v (float)}, ...``. The rows whose ``feat[k] < v`` will be filtered. Defaults to ``None``.
- ``highest_val (dict)`` : Has the format ``{k (str): v (float)}, ...``. The rows whose ``feat[k] > v`` will be filtered. Defaults to ``None``.
- ``equal_val (dict)`` : Has the format ``{k (str): v (float)}, ...``. The rows whose ``feat[k] != v`` will be filtered. Defaults to ``None``.
- ``not_equal_val (dict)`` : Has the format ``{k (str): v (float)}, ...``. The rows whose ``feat[k] == v`` will be filtered. Defaults to ``None``.
- ``val_interval (dict)``: Has the format ``{k (str): interval (str), ...}``, where ``interval`` can be set as ``[A,B]`` / ``[A,B)`` / ``(A,B)`` / ``(A,B]``. The rows whose ``feat[k]`` falls in ``interval`` will be retained. To specify more than one interval, separate them with semicolons. For instance, ``{k: "[A,B);(C,D]"}`` can be adopted, and rows whose ``feat[k]`` falls in any specified interval will be retained. Defaults to ``None``, which means all rows will be retained.

Remove interactions by user or item
'''''''''''''''''''''''''''''''''''
@@ -85,15 +82,16 @@ Remove interactions by user or item
Filter by number of interactions
''''''''''''''''''''''''''''''''''''

- ``max_user_inter_num (int)`` : Users whose number of interactions is more than ``max_user_inter_num`` will be filtered. Defaults to ``None``.
- ``min_user_inter_num (int)`` : Users whose number of interactions is less than ``min_user_inter_num`` will be filtered. Defaults to ``0``.
- ``max_item_inter_num (int)`` : Items whose number of interactions is more than ``max_item_inter_num`` will be filtered. Defaults to ``None``.
- ``min_item_inter_num (int)`` : Items whose number of interactions is less than ``min_item_inter_num`` will be filtered. Defaults to ``0``.
- ``user_inter_num_interval (str)`` : Has the interval format, such as ``[A,B]`` / ``[A,B)`` / ``(A,B)`` / ``(A,B]``, where ``A`` and ``B`` are the endpoints of the interval and ``A <= B``. Users whose number of interactions is in the interval will be retained. Defaults to ``[0,inf)``.
- ``item_inter_num_interval (str)`` : Has the interval format, such as ``[A,B]`` / ``[A,B)`` / ``(A,B)`` / ``(A,B]``, where ``A`` and ``B`` are the endpoints of the interval and ``A <= B``. Items whose number of interactions is in the interval will be retained. Defaults to ``[0,inf)``.

Preprocessing
-----------------

- ``fields_in_same_space (list)`` : List of spaces. Space is a list of string similar to the fields' names. The fields in the same space will be remapped into the same index system. Note that if you want to make some fields remapped in the same space with entities, then just set ``fields_in_same_space = [entity_id, xxx, ...]``. (if ``ENTITY_ID_FIELD != 'entity_id'``, then change the ``'entity_id'`` in the above example.) Defaults to ``None``.
- ``alias_of_user_id (list)``: List of fields' names, which will be remapped into the same index system with ``USER_ID_FIELD``. Defaults to ``None``.
- ``alias_of_item_id (list)``: List of fields' names, which will be remapped into the same index system with ``ITEM_ID_FIELD``. Defaults to ``None``.
- ``alias_of_entity_id (list)``: List of fields' names, which will be remapped into the same index system with ``ENTITY_ID_FIELD``, ``HEAD_ENTITY_ID_FIELD`` and ``TAIL_ENTITY_ID_FIELD``. Defaults to ``None``.
- ``alias_of_relation_id (list)``: List of fields' names, which will be remapped into the same index system with ``RELATION_ID_FIELD``. Defaults to ``None``.
- ``preload_weight (dict)`` : Has the format ``{k (str): v (float)}, ...``. ``k`` is a token field, representing the IDs of each row of the preloaded weight matrix. ``v`` is a float-like field. Each pair of ``k`` and ``v`` should be from the same atomic file. This arg can be used to load pretrained vectors. Defaults to ``None``.
- ``normalize_field (list)`` : List of field names to be normalized. Note that only float-like fields can be normalized. Defaults to ``None``.
- ``normalize_all (bool)`` : Normalize all the float-like fields if ``True``. Defaults to ``True``.
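The interval syntax shared by ``val_interval``, ``user_inter_num_interval`` and ``item_inter_num_interval`` can be illustrated with a small standalone parser. This is a hypothetical sketch of the semantics described in the docs above, not RecBole's actual implementation, and the function name ``in_interval`` is made up for illustration:

```python
def in_interval(value, interval_str):
    """Return True if value lies in any of the ';'-separated intervals,
    e.g. '[4,inf)' or '[0,2);(3,6]'. '[' / ']' are closed endpoints,
    '(' / ')' are open endpoints."""
    for part in interval_str.split(';'):
        lo, hi = (float(x) for x in part[1:-1].split(','))
        lower_ok = value >= lo if part[0] == '[' else value > lo
        upper_ok = value <= hi if part[-1] == ']' else value < hi
        if lower_ok and upper_ok:
            return True
    return False

# A rating filter like val_interval: {rating: "[3,5]"} keeps rows with 3 <= rating <= 5,
# and user_inter_num_interval: "[4,inf)" keeps users with at least 4 interactions.
print(in_interval(4.0, '[3,5]'))   # True
print(in_interval(5.0, '[3,5)'))   # False
```

Note that ``inf`` parses naturally through ``float('inf')``, which is why the default ``[0,inf)`` retains everything.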
4 changes: 1 addition & 3 deletions docs/source/user_guide/model/sequential/gru4reckg.rst
@@ -46,9 +46,7 @@ And then:
kg: [head_id, relation_id, tail_id]
link: [item_id, entity_id]
ent_feature: [ent_id, ent_vec]
fields_in_same_space: [
[ent_id, entity_id]
]
alias_of_entity_id: [ent_id]
preload_weight:
ent_id: ent_vec
additional_feat_suffix: [ent_feature]
6 changes: 2 additions & 4 deletions docs/source/user_guide/model/sequential/ksr.rst
@@ -58,10 +58,8 @@ And then:
link: [item_id, entity_id]
ent_feature: [ent_id, ent_vec]
rel_feature: [rel_id, rel_vec]
fields_in_same_space: [
[ent_id, entity_id]
[rel_id, relation_id]
]
alias_of_entity_id: [ent_id]
alias_of_relation_id: [rel_id]
preload_weight:
ent_id: ent_vec
rel_id: rel_vec
8 changes: 4 additions & 4 deletions docs/source/user_guide/usage/load_pretrained_embedding.rst
@@ -22,9 +22,9 @@ Secondly, update the args as (suppose that ``USER_ID_FIELD: user_id``):
load_col:
# inter/user/item/...: As usual
useremb: [uid, user_emb]
fields_in_same_space: [[uid, user_id]]
alias_of_user_id: [uid]
preload_weight:
uid: user_emb
uid: user_emb

Then, this additional embedding feature file will be loaded into the :class:`Dataset` object. These new features can be accessed as following:

@@ -39,6 +39,6 @@ In your model, user embedding matrix can be initialized by your pre-trained embe

class YourModel(GeneralRecommender):
def __init__(self, config, dataset):
pretrained_user_emb = dataset.get_preload_weight('uid')
self.user_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_user_emb))
pretrained_user_emb = dataset.get_preload_weight('uid')
self.user_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_user_emb))
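The ``from_pretrained`` call above is, at its core, a row lookup into the preloaded weight matrix. A minimal numpy sketch of that behavior follows; the array contents are made-up stand-ins for what ``dataset.get_preload_weight('uid')`` would return, and ``embed`` is a hypothetical helper, not a RecBole API:

```python
import numpy as np

# Stand-in for dataset.get_preload_weight('uid'): one row per user ID.
pretrained_user_emb = np.array([[0.1, 0.2],
                                [0.3, 0.4],
                                [0.5, 0.6]], dtype=np.float32)

def embed(user_ids, weight):
    # A frozen embedding lookup is just row selection by integer index.
    return weight[user_ids]

print(embed(np.array([0, 2]), pretrained_user_emb))
```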

1 change: 0 additions & 1 deletion recbole/config/__init__.py
@@ -1,2 +1 @@
from recbole.config.configurator import Config
from recbole.config.eval_setting import EvalSetting
45 changes: 40 additions & 5 deletions recbole/config/configurator.py
@@ -3,9 +3,9 @@
# @Email : linzihan.super@foxmail.com

# UPDATE
# @Time : 2020/10/04, 2021/3/2, 2021/2/17
# @Author : Shanlei Mu, Yupeng Hou, Jiawei Guan
# @Email : slmu@ruc.edu.cn, houyupeng@ruc.edu.cn, Guanjw@ruc.edu.cn
# @Time : 2020/10/04, 2021/3/2, 2021/2/17, 2021/6/30
# @Author : Shanlei Mu, Yupeng Hou, Jiawei Guan, Xingyu Pan
# @Email : slmu@ruc.edu.cn, houyupeng@ruc.edu.cn, Guanjw@ruc.edu.cn, xy_pan@foxmail.com

"""
recbole.config.configurator
@@ -21,8 +21,7 @@

from recbole.evaluator import group_metrics, individual_metrics
from recbole.utils import get_model, Enum, EvaluatorType, ModelType, InputType, \
general_arguments, training_arguments, evaluation_arguments, dataset_arguments
from recbole.utils.utils import set_color
general_arguments, training_arguments, evaluation_arguments, dataset_arguments, set_color


class Config(object):
@@ -79,6 +78,7 @@ def __init__(self, model=None, dataset=None, config_file_list=None, config_dict=
self._set_default_parameters()
self._init_device()
self._set_train_neg_sample_args()
self._set_eval_neg_sample_args()

def _init_parameters_category(self):
self.parameters = dict()
@@ -302,11 +302,30 @@ def _set_default_parameters(self):
valid_metric = self.final_config_dict['valid_metric'].split('@')[0]
self.final_config_dict['valid_metric_bigger'] = False if valid_metric.lower() in smaller_metric else True

topk = self.final_config_dict['topk']
if isinstance(topk, int):
self.final_config_dict['topk'] = [topk]

metrics = self.final_config_dict['metrics']
if isinstance(metrics, str):
self.final_config_dict['metrics'] = [metrics]

if 'additional_feat_suffix' in self.final_config_dict:
ad_suf = self.final_config_dict['additional_feat_suffix']
if isinstance(ad_suf, str):
self.final_config_dict['additional_feat_suffix'] = [ad_suf]

# eval_args checking
default_eval_args = {
'split': {'RS': [0.8, 0.1, 0.1]},
'order': 'RO',
'group_by': 'user',
'mode': 'full'
}
for op_args in default_eval_args:
if op_args not in self.final_config_dict['eval_args']:
self.final_config_dict['eval_args'][op_args] = default_eval_args[op_args]

def _init_device(self):
use_gpu = self.final_config_dict['use_gpu']
if use_gpu:
@@ -323,6 +342,22 @@ def _set_train_neg_sample_args(self):
else:
self.final_config_dict['train_neg_sample_args'] = {'strategy': 'none'}

def _set_eval_neg_sample_args(self):
eval_mode = self.final_config_dict['eval_args']['mode']
if eval_mode == 'none':
eval_neg_sample_args = {'strategy': 'none', 'distribution': 'none'}
elif eval_mode == 'full':
eval_neg_sample_args = {'strategy': 'full', 'distribution': 'uniform'}
elif eval_mode[0:3] == 'uni':
sample_by = int(eval_mode[3:])
eval_neg_sample_args = {'strategy': 'by', 'by': sample_by, 'distribution': 'uniform'}
elif eval_mode[0:3] == 'pop':
sample_by = int(eval_mode[3:])
eval_neg_sample_args = {'strategy': 'by', 'by': sample_by, 'distribution': 'popularity'}
else:
raise ValueError(f'the mode [{eval_mode}] in eval_args is not supported.')
self.final_config_dict['eval_neg_sample_args'] = eval_neg_sample_args

def __setitem__(self, key, value):
if not isinstance(key, str):
raise TypeError("index must be a str.")
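The mode-string convention handled by ``_set_eval_neg_sample_args`` (``none``, ``full``, ``uni<k>``, ``pop<k>``) can be exercised in isolation. The sketch below replicates the mapping from the diff above as a standalone function; the name ``parse_eval_mode`` is ours, not RecBole's:

```python
def parse_eval_mode(eval_mode):
    """Map an eval_args 'mode' string to negative-sampling settings,
    mirroring Config._set_eval_neg_sample_args in the diff above."""
    if eval_mode == 'none':
        return {'strategy': 'none', 'distribution': 'none'}
    if eval_mode == 'full':
        return {'strategy': 'full', 'distribution': 'uniform'}
    if eval_mode[0:3] == 'uni':
        # e.g. 'uni100': sample 100 uniform negatives per positive
        return {'strategy': 'by', 'by': int(eval_mode[3:]), 'distribution': 'uniform'}
    if eval_mode[0:3] == 'pop':
        # e.g. 'pop50': sample 50 popularity-biased negatives per positive
        return {'strategy': 'by', 'by': int(eval_mode[3:]), 'distribution': 'popularity'}
    raise ValueError(f'the mode [{eval_mode}] in eval_args is not supported.')

print(parse_eval_mode('uni100'))  # {'strategy': 'by', 'by': 100, 'distribution': 'uniform'}
```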