Skip to content

Commit

Permalink
Merge pull request #1 from 2017pxy/evaluator
Browse files Browse the repository at this point in the history
Refactor: refactor the dataloader
  • Loading branch information
guijiql authored Jul 17, 2021
2 parents b2663fe + ab70a16 commit 3331fbe
Show file tree
Hide file tree
Showing 67 changed files with 1,568 additions and 2,438 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
.idea/
*.pyc
*.log
log_tensorboard/*
saved/
*.lprof
*.egg-info/
Expand Down
4 changes: 2 additions & 2 deletions conda/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,23 @@ requirements:
- pandas >=1.0.5
- tqdm >=4.48.2
- pyyaml >=5.1.0
- matplotlib >=3.1.3
- scikit-learn >=0.23.2
- pytorch
- colorlog==4.7.2
- colorama==0.4.4
- tensorboard >=2.5.0
run:
- python
- numpy >=1.17.2
- scipy ==1.6.0
- pandas >=1.0.5
- tqdm >=4.48.2
- pyyaml >=5.1.0
- matplotlib >=3.1.3
- scikit-learn >=0.23.2
- pytorch
- colorlog==4.7.2
- colorama==0.4.4
- tensorboard >=2.5.0
test:
imports:
- recbole
Expand Down
2 changes: 1 addition & 1 deletion docs/source/user_guide/config_settings.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ model training and evaluation.
which will clips gradient norm of model. Defaults to ``None``.
- ``loss_decimal_place(int)``: The decimal place of training loss. Defaults to ``4``.
- ``weight_decay (float)`` : Weight decay (L2 penalty), used for `optimizer <https://pytorch.org/docs/stable/optim.html?highlight=weight_decay>`_. Default to ``0.0``.
- ``draw_loss_pic (bool)``: Draw the training loss line graph of model if it's ``True``, the pic is a PDF file and will be saved in your run directory after model training. Default to ``False``.



**Evaluation Setting**
Expand Down
2 changes: 1 addition & 1 deletion docs/source/user_guide/data/atomic_files.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,4 @@ For example, if you want to map the tokens of ``ent_id`` into the same space of
# inter/user/item/...: As usual
ent: [ent_id, ent_emb]
fields_in_same_space: [[ent_id, entity_id]]
alias_of_entity_id: [ent_id]
16 changes: 7 additions & 9 deletions docs/source/user_guide/data/data_args.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,7 @@ Remove duplicated user-item interactions
Filter by value
''''''''''''''''''

- ``lowest_val (dict)`` : Has the format ``{k (str): v (float)}, ...``. The rows whose ``feat[k] < v`` will be filtered. Defaults to ``None``.
- ``highest_val (dict)`` : Has the format ``{k (str): v (float)}, ...``. The rows whose ``feat[k] > v`` will be filtered. Defaults to ``None``.
- ``equal_val (dict)`` : Has the format ``{k (str): v (float)}, ...``. The rows whose ``feat[k] != v`` will be filtered. Defaults to ``None``.
- ``not_equal_val (dict)`` : Has the format ``{k (str): v (float)}, ...``. The rows whose ``feat[k] == v`` will be filtered. Defaults to ``None``.
- ``val_interval (dict)``: Has the format ``{k (str): interval (str), ...}``, where ``interval `` can be set as ``[A,B]`` / ``[A,B)`` / ``(A,B)`` / ``(A,B]``. The rows whose ``feat[k]`` is in the interval ``interval`` will be retained. If you want to specify more than one interval, separate them with semicolon(s). For instance, ``{k: "[A,B);(C,D]"}`` can be adopted and rows whose ``feat[k]`` is in any specified interval will be retained. Defaults to ``None``, which means all rows will be retained.

Remove interation by user or item
'''''''''''''''''''''''''''''''''''
Expand All @@ -85,15 +82,16 @@ Remove interation by user or item
Filter by number of interactions
''''''''''''''''''''''''''''''''''''

- ``max_user_inter_num (int)`` : Users whose number of interactions is more than ``max_user_inter_num`` will be filtered. Defaults to ``None``.
- ``min_user_inter_num (int)`` : Users whose number of interactions is less than ``min_user_inter_num`` will be filtered. Defaults to ``0``.
- ``max_item_inter_num (int)`` : Items whose number of interactions is more than ``max_item_inter_num`` will be filtered. Defaults to ``None``.
- ``min_item_inter_num (int)`` : Items whose number of interactions is less than ``min_item_inter_num`` will be filtered. Defaults to ``0``.
- ``user_inter_num_interval (str)`` : Has the interval format, such as ``[A,B]`` / ``[A,B)`` / ``(A,B)`` / ``(A,B]``, where ``A`` and ``B`` are the endpoints of the interval and ``A <= B``. Users whose number of interactions is in the interval will be retained. Defaults to ``[0,inf)``.
- ``item_inter_num_interval (str)`` : Has the interval format, such as ``[A,B]`` / ``[A,B)`` / ``(A,B)`` / ``(A,B]``, where ``A`` and ``B`` are the endpoints of the interval and ``A <= B``. Items whose number of interactions is in the interval will be retained. Defaults to ``[0,inf)``.

Preprocessing
-----------------

- ``fields_in_same_space (list)`` : List of spaces. Space is a list of string similar to the fields' names. The fields in the same space will be remapped into the same index system. Note that if you want to make some fields remapped in the same space with entities, then just set ``fields_in_same_space = [entity_id, xxx, ...]``. (if ``ENTITY_ID_FIELD != 'entity_id'``, then change the ``'entity_id'`` in the above example.) Defaults to ``None``.
- ``alias_of_user_id (list)``: List of fields' names, which will be remapped into the same index system with ``USER_ID_FIELD``. Defaults to ``None``.
- ``alias_of_item_id (list)``: List of fields' names, which will be remapped into the same index system with ``ITEM_ID_FIELD``. Defaults to ``None``.
- ``alias_of_entity_id (list)``: List of fields' names, which will be remapped into the same index system with ``ENTITY_ID_FIELD``, ``HEAD_ENTITY_ID_FIELD`` and ``TAIL_ENTITY_ID_FIELD``. Defaults to ``None``.
- ``alias_of_relation_id (list)``: List of fields' names, which will be remapped into the same index system with ``RELATION_ID_FIELD``. Defaults to ``None``.
- ``preload_weight (dict)`` : Has the format ``{k (str): v (float)}, ...``. ``k`` if a token field, representing the IDs of each row of preloaded weight matrix. ``v`` is a float like fields. Each pair of ``u`` and ``v`` should be from the same atomic file. This arg can be used to load pretrained vectors. Defaults to ``None``.
- ``normalize_field (list)`` : List of filed names to be normalized. Note that only float like fields can be normalized. Defaults to ``None``.
- ``normalize_all (bool)`` : Normalize all the float like fields if ``True``. Defaults to ``True``.
Expand Down
4 changes: 1 addition & 3 deletions docs/source/user_guide/model/sequential/gru4reckg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ And then:
kg: [head_id, relation_id, tail_id]
link: [item_id, entity_id]
ent_feature: [ent_id, ent_vec]
fields_in_same_space: [
[ent_id, entity_id]
]
alias_of_entity_id: [ent_id]
preload_weight:
ent_id: ent_vec
additional_feat_suffix: [ent_feature]
Expand Down
6 changes: 2 additions & 4 deletions docs/source/user_guide/model/sequential/ksr.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,8 @@ And then:
link: [item_id, entity_id]
ent_feature: [ent_id, ent_vec]
rel_feature: [rel_id, rel_vec]
fields_in_same_space: [
[ent_id, entity_id]
[rel_id, relation_id]
]
alias_of_entity_id: [ent_id]
alias_of_relation_id: [rel_id]
preload_weight:
ent_id: ent_vec
rel_id: rel_vec
Expand Down
8 changes: 4 additions & 4 deletions docs/source/user_guide/usage/load_pretrained_embedding.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ Secondly, update the args as (suppose that ``USER_ID_FIELD: user_id``):
load_col:
# inter/user/item/...: As usual
useremb: [uid, user_emb]
fields_in_same_space: [[uid, user_id]]
alias_of_user_id: [uid]
preload_weight:
uid: user_emb
uid: user_emb
Then, this additional embedding feature file will be loaded into the :class:`Dataset` object. These new features can be accessed as following:

Expand All @@ -39,6 +39,6 @@ In your model, user embedding matrix can be initialized by your pre-trained embe
class YourModel(GeneralRecommender):
def __init__(self, config, dataset):
pretrained_user_emb = dataset.get_preload_weight('uid')
self.user_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_user_emb))
pretrained_user_emb = dataset.get_preload_weight('uid')
self.user_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_user_emb))
1 change: 0 additions & 1 deletion recbole/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from recbole.config.configurator import Config
from recbole.config.eval_setting import EvalSetting
45 changes: 40 additions & 5 deletions recbole/config/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
# @Email : linzihan.super@foxmail.com

# UPDATE
# @Time : 2020/10/04, 2021/3/2, 2021/2/17
# @Author : Shanlei Mu, Yupeng Hou, Jiawei Guan
# @Email : slmu@ruc.edu.cn, houyupeng@ruc.edu.cn, Guanjw@ruc.edu.cn
# @Time : 2020/10/04, 2021/3/2, 2021/2/17, 2021/6/30
# @Author : Shanlei Mu, Yupeng Hou, Jiawei Guan, Xingyu Pan
# @Email : slmu@ruc.edu.cn, houyupeng@ruc.edu.cn, Guanjw@ruc.edu.cn, xy_pan@foxmail.com

"""
recbole.config.configurator
Expand All @@ -21,8 +21,7 @@

from recbole.evaluator import group_metrics, individual_metrics
from recbole.utils import get_model, Enum, EvaluatorType, ModelType, InputType, \
general_arguments, training_arguments, evaluation_arguments, dataset_arguments
from recbole.utils.utils import set_color
general_arguments, training_arguments, evaluation_arguments, dataset_arguments, set_color


class Config(object):
Expand Down Expand Up @@ -79,6 +78,7 @@ def __init__(self, model=None, dataset=None, config_file_list=None, config_dict=
self._set_default_parameters()
self._init_device()
self._set_train_neg_sample_args()
self._set_eval_neg_sample_args()

def _init_parameters_category(self):
self.parameters = dict()
Expand Down Expand Up @@ -302,11 +302,30 @@ def _set_default_parameters(self):
valid_metric = self.final_config_dict['valid_metric'].split('@')[0]
self.final_config_dict['valid_metric_bigger'] = False if valid_metric.lower() in smaller_metric else True

topk = self.final_config_dict['topk']
if isinstance(topk,int):
self.final_config_dict['topk'] = [topk]

metrics = self.final_config_dict['metrics']
if isinstance(metrics, str):
self.final_config_dict['metrics'] = [metrics]

if 'additional_feat_suffix' in self.final_config_dict:
ad_suf = self.final_config_dict['additional_feat_suffix']
if isinstance(ad_suf, str):
self.final_config_dict['additional_feat_suffix'] = [ad_suf]

# eval_args checking
default_eval_args = {
'split': {'RS': [0.8, 0.1, 0.1]},
'order': 'RO',
'group_by': 'user',
'mode': 'full'
}
for op_args in default_eval_args:
if op_args not in self.final_config_dict['eval_args']:
self.final_config_dict['eval_args'][op_args] = default_eval_args[op_args]

def _init_device(self):
use_gpu = self.final_config_dict['use_gpu']
if use_gpu:
Expand All @@ -323,6 +342,22 @@ def _set_train_neg_sample_args(self):
else:
self.final_config_dict['train_neg_sample_args'] = {'strategy': 'none'}

def _set_eval_neg_sample_args(self):
eval_mode = self.final_config_dict['eval_args']['mode']
if eval_mode == 'none':
eval_neg_sample_args = {'strategy': 'none', 'distribution': 'none'}
elif eval_mode == 'full':
eval_neg_sample_args = {'strategy': 'full', 'distribution': 'uniform'}
elif eval_mode[0:3] == 'uni':
sample_by = int(eval_mode[3:])
eval_neg_sample_args = {'strategy': 'by', 'by': sample_by, 'distribution': 'uniform'}
elif eval_mode[0:3] == 'pop':
sample_by = int(eval_mode[3:])
eval_neg_sample_args = {'strategy': 'by', 'by': sample_by, 'distribution': 'popularity'}
else:
raise ValueError(f'the mode [{eval_mode}] in eval_args is not supported.')
self.final_config_dict['eval_neg_sample_args'] = eval_neg_sample_args

def __setitem__(self, key, value):
if not isinstance(key, str):
raise TypeError("index must be a str.")
Expand Down
Loading

0 comments on commit 3331fbe

Please sign in to comment.