diff --git a/.gitignore b/.gitignore index 660977853..f6f7e3498 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ .idea/ *.pyc *.log +log_tensorboard/* saved/ *.lprof *.egg-info/ diff --git a/conda/meta.yaml b/conda/meta.yaml index de10469f1..7963dba0b 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -15,11 +15,11 @@ requirements: - pandas >=1.0.5 - tqdm >=4.48.2 - pyyaml >=5.1.0 - - matplotlib >=3.1.3 - scikit-learn >=0.23.2 - pytorch - colorlog==4.7.2 - colorama==0.4.4 + - tensorboard >=2.5.0 run: - python - numpy >=1.17.2 @@ -27,11 +27,11 @@ requirements: - pandas >=1.0.5 - tqdm >=4.48.2 - pyyaml >=5.1.0 - - matplotlib >=3.1.3 - scikit-learn >=0.23.2 - pytorch - colorlog==4.7.2 - colorama==0.4.4 + - tensorboard >=2.5.0 test: imports: - recbole diff --git a/docs/source/user_guide/config_settings.rst b/docs/source/user_guide/config_settings.rst index 41bf7c5f7..52d8a135e 100644 --- a/docs/source/user_guide/config_settings.rst +++ b/docs/source/user_guide/config_settings.rst @@ -56,7 +56,7 @@ model training and evaluation. which will clip the gradient norm of the model. Defaults to ``None``. - ``loss_decimal_place(int)``: The decimal place of training loss. Defaults to ``4``. - ``weight_decay (float)`` : Weight decay (L2 penalty), used for `optimizer `_. Defaults to ``0.0``. -- ``draw_loss_pic (bool)``: Draw the training loss line graph of model if it's ``True``, the pic is a PDF file and will be saved in your run directory after model training. Default to ``False``. + **Evaluation Setting** diff --git a/docs/source/user_guide/data/atomic_files.rst b/docs/source/user_guide/data/atomic_files.rst index 9e74dd4d7..ebac34ce3 100644 --- a/docs/source/user_guide/data/atomic_files.rst +++ b/docs/source/user_guide/data/atomic_files.rst @@ -141,4 +141,4 @@ For example, if you want to map the tokens of ``ent_id`` into the same space of # inter/user/item/...: As usual ent: [ent_id, ent_emb] - fields_in_same_space: [[ent_id, entity_id]] + alias_of_entity_id: [ent_id] diff --git a/docs/source/user_guide/data/data_args.rst b/docs/source/user_guide/data/data_args.rst index 497eb65e6..97ce58835 100644 --- a/docs/source/user_guide/data/data_args.rst +++ b/docs/source/user_guide/data/data_args.rst @@ -72,10 +72,7 @@ Remove duplicated user-item interactions Filter by value '''''''''''''''''' -- ``lowest_val (dict)`` : Has the format ``{k (str): v (float)}, ...``. The rows whose ``feat[k] < v`` will be filtered. Defaults to ``None``. -- ``highest_val (dict)`` : Has the format ``{k (str): v (float)}, ...``. The rows whose ``feat[k] > v`` will be filtered. Defaults to ``None``. -- ``equal_val (dict)`` : Has the format ``{k (str): v (float)}, ...``. The rows whose ``feat[k] != v`` will be filtered. Defaults to ``None``. -- ``not_equal_val (dict)`` : Has the format ``{k (str): v (float)}, ...``. The rows whose ``feat[k] == v`` will be filtered. Defaults to ``None``. +- ``val_interval (dict)``: Has the format ``{k (str): interval (str), ...}``, where ``interval`` can be set as ``[A,B]`` / ``[A,B)`` / ``(A,B)`` / ``(A,B]``. The rows whose ``feat[k]`` is in the interval ``interval`` will be retained. If you want to specify more than one interval, separate them with semicolon(s). For instance, ``{k: "[A,B);(C,D]"}`` can be adopted and rows whose ``feat[k]`` is in any specified interval will be retained. Defaults to ``None``, which means all rows will be retained.
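To make the new interval syntax concrete, here is a minimal, hedged sketch; the ``rating`` and ``timestamp`` field names and the threshold values are illustrative assumptions, not part of this change.

# A minimal sketch of `val_interval`, assuming the dataset has `rating` and
# `timestamp` fields (hypothetical values).
from recbole.config import Config

config = Config(
    model='BPR',
    dataset='ml-100k',
    config_dict={
        'val_interval': {
            'rating': '[3,5]',               # keep rows with 3 <= rating <= 5
            'timestamp': '(880000000,inf)',  # open/half-open bounds are also legal
        }
    },
)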
Remove interaction by user or item ''''''''''''''''''''''''''''''''''' @@ -85,15 +82,16 @@ Remove interaction by user or item Filter by number of interactions '''''''''''''''''''''''''''''''''''' -- ``max_user_inter_num (int)`` : Users whose number of interactions is more than ``max_user_inter_num`` will be filtered. Defaults to ``None``. -- ``min_user_inter_num (int)`` : Users whose number of interactions is less than ``min_user_inter_num`` will be filtered. Defaults to ``0``. -- ``max_item_inter_num (int)`` : Items whose number of interactions is more than ``max_item_inter_num`` will be filtered. Defaults to ``None``. -- ``min_item_inter_num (int)`` : Items whose number of interactions is less than ``min_item_inter_num`` will be filtered. Defaults to ``0``. +- ``user_inter_num_interval (str)`` : Has the interval format, such as ``[A,B]`` / ``[A,B)`` / ``(A,B)`` / ``(A,B]``, where ``A`` and ``B`` are the endpoints of the interval and ``A <= B``. Users whose number of interactions is in the interval will be retained. Defaults to ``[0,inf)``. +- ``item_inter_num_interval (str)`` : Has the interval format, such as ``[A,B]`` / ``[A,B)`` / ``(A,B)`` / ``(A,B]``, where ``A`` and ``B`` are the endpoints of the interval and ``A <= B``. Items whose number of interactions is in the interval will be retained. Defaults to ``[0,inf)``. Preprocessing ----------------- -- ``fields_in_same_space (list)`` : List of spaces. Space is a list of string similar to the fields' names. The fields in the same space will be remapped into the same index system. Note that if you want to make some fields remapped in the same space with entities, then just set ``fields_in_same_space = [entity_id, xxx, ...]``. (if ``ENTITY_ID_FIELD != 'entity_id'``, then change the ``'entity_id'`` in the above example.) Defaults to ``None``. +- ``alias_of_user_id (list)``: List of fields' names, which will be remapped into the same index system as ``USER_ID_FIELD``. Defaults to ``None``. +- ``alias_of_item_id (list)``: List of fields' names, which will be remapped into the same index system as ``ITEM_ID_FIELD``. Defaults to ``None``. +- ``alias_of_entity_id (list)``: List of fields' names, which will be remapped into the same index system as ``ENTITY_ID_FIELD``, ``HEAD_ENTITY_ID_FIELD`` and ``TAIL_ENTITY_ID_FIELD``. Defaults to ``None``. +- ``alias_of_relation_id (list)``: List of fields' names, which will be remapped into the same index system as ``RELATION_ID_FIELD``. Defaults to ``None``. - ``preload_weight (dict)`` : Has the format ``{k (str): v (float)}, ...``. ``k`` is a token field, representing the IDs of each row of the preloaded weight matrix. ``v`` is a float-like field. Each pair of ``k`` and ``v`` should be from the same atomic file. This arg can be used to load pretrained vectors. Defaults to ``None``. - ``normalize_field (list)`` : List of field names to be normalized. Note that only float-like fields can be normalized. Defaults to ``None``. - ``normalize_all (bool)`` : Normalize all the float-like fields if ``True``. Defaults to ``True``.
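A combined usage sketch of the new filtering and remapping arguments; the model and dataset names and the ``ent_id`` field are assumptions for illustration, not prescribed by this change.

# A hedged sketch combining interval-based filtering with alias remapping.
from recbole.config import Config

config = Config(
    model='KGAT',      # any knowledge-aware model (assumption)
    dataset='ml-100k',
    config_dict={
        'user_inter_num_interval': '[10,inf)',  # keep users with at least 10 interactions
        'item_inter_num_interval': '(5,100]',   # keep items with 6 to 100 interactions
        'alias_of_entity_id': ['ent_id'],       # remap ent_id into the entity-id space
    },
)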
diff --git a/docs/source/user_guide/model/sequential/gru4reckg.rst b/docs/source/user_guide/model/sequential/gru4reckg.rst index f1653cabf..69dacdaa4 100644 --- a/docs/source/user_guide/model/sequential/gru4reckg.rst +++ b/docs/source/user_guide/model/sequential/gru4reckg.rst @@ -46,9 +46,7 @@ And then: kg: [head_id, relation_id, tail_id] link: [item_id, entity_id] ent_feature: [ent_id, ent_vec] - fields_in_same_space: [ - [ent_id, entity_id] - ] + alias_of_entity_id: [ent_id] preload_weight: ent_id: ent_vec additional_feat_suffix: [ent_feature] diff --git a/docs/source/user_guide/model/sequential/ksr.rst b/docs/source/user_guide/model/sequential/ksr.rst index 8085e71b4..ff24e2184 100644 --- a/docs/source/user_guide/model/sequential/ksr.rst +++ b/docs/source/user_guide/model/sequential/ksr.rst @@ -58,10 +58,8 @@ And then: link: [item_id, entity_id] ent_feature: [ent_id, ent_vec] rel_feature: [rel_id, rel_vec] - fields_in_same_space: [ - [ent_id, entity_id] - [rel_id, relation_id] - ] + alias_of_entity_id: [ent_id] + alias_of_relation_id: [rel_id] preload_weight: ent_id: ent_vec rel_id: rel_vec diff --git a/docs/source/user_guide/usage/load_pretrained_embedding.rst b/docs/source/user_guide/usage/load_pretrained_embedding.rst index c8f3c010a..2e50a31ed 100644 --- a/docs/source/user_guide/usage/load_pretrained_embedding.rst +++ b/docs/source/user_guide/usage/load_pretrained_embedding.rst @@ -22,9 +22,9 @@ Secondly, update the args as (suppose that ``USER_ID_FIELD: user_id``): load_col: # inter/user/item/...: As usual useremb: [uid, user_emb] - fields_in_same_space: [[uid, user_id]] + alias_of_user_id: [uid] preload_weight: - uid: user_emb + uid: user_emb Then, this additional embedding feature file will be loaded into the :class:`Dataset` object. These new features can be accessed as following: @@ -39,6 +39,6 @@ In your model, user embedding matrix can be initialized by your pre-trained embe class YourModel(GeneralRecommender): def __init__(self, config, dataset): - pretrained_user_emb = dataset.get_preload_weight('uid') - self.user_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_user_emb)) + pretrained_user_emb = dataset.get_preload_weight('uid') + self.user_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_user_emb)) diff --git a/recbole/config/__init__.py b/recbole/config/__init__.py index e4ee33c03..71e55a8ea 100644 --- a/recbole/config/__init__.py +++ b/recbole/config/__init__.py @@ -1,2 +1 @@ from recbole.config.configurator import Config -from recbole.config.eval_setting import EvalSetting diff --git a/recbole/config/configurator.py b/recbole/config/configurator.py index c71435c2a..4e1ac3182 100644 --- a/recbole/config/configurator.py +++ b/recbole/config/configurator.py @@ -3,9 +3,9 @@ # @Email : linzihan.super@foxmail.com # UPDATE -# @Time : 2020/10/04, 2021/3/2, 2021/2/17 -# @Author : Shanlei Mu, Yupeng Hou, Jiawei Guan -# @Email : slmu@ruc.edu.cn, houyupeng@ruc.edu.cn, Guanjw@ruc.edu.cn +# @Time : 2020/10/04, 2021/3/2, 2021/2/17, 2021/6/30 +# @Author : Shanlei Mu, Yupeng Hou, Jiawei Guan, Xingyu Pan +# @Email : slmu@ruc.edu.cn, houyupeng@ruc.edu.cn, Guanjw@ruc.edu.cn, xy_pan@foxmail.com """ recbole.config.configurator @@ -21,8 +21,7 @@ from recbole.evaluator import group_metrics, individual_metrics from recbole.utils import get_model, Enum, EvaluatorType, ModelType, InputType, \ - general_arguments, training_arguments, evaluation_arguments, dataset_arguments -from recbole.utils.utils import set_color + general_arguments, training_arguments, 
evaluation_arguments, dataset_arguments, set_color class Config(object): @@ -79,6 +78,7 @@ def __init__(self, model=None, dataset=None, config_file_list=None, config_dict= self._set_default_parameters() self._init_device() self._set_train_neg_sample_args() + self._set_eval_neg_sample_args() def _init_parameters_category(self): self.parameters = dict() @@ -302,11 +302,30 @@ def _set_default_parameters(self): valid_metric = self.final_config_dict['valid_metric'].split('@')[0] self.final_config_dict['valid_metric_bigger'] = False if valid_metric.lower() in smaller_metric else True + topk = self.final_config_dict['topk'] + if isinstance(topk,int): + self.final_config_dict['topk'] = [topk] + + metrics = self.final_config_dict['metrics'] + if isinstance(metrics, str): + self.final_config_dict['metrics'] = [metrics] + if 'additional_feat_suffix' in self.final_config_dict: ad_suf = self.final_config_dict['additional_feat_suffix'] if isinstance(ad_suf, str): self.final_config_dict['additional_feat_suffix'] = [ad_suf] + # eval_args checking + default_eval_args = { + 'split': {'RS': [0.8, 0.1, 0.1]}, + 'order': 'RO', + 'group_by': 'user', + 'mode': 'full' + } + for op_args in default_eval_args: + if op_args not in self.final_config_dict['eval_args']: + self.final_config_dict['eval_args'][op_args] = default_eval_args[op_args] + def _init_device(self): use_gpu = self.final_config_dict['use_gpu'] if use_gpu: @@ -323,6 +342,22 @@ def _set_train_neg_sample_args(self): else: self.final_config_dict['train_neg_sample_args'] = {'strategy': 'none'} + def _set_eval_neg_sample_args(self): + eval_mode = self.final_config_dict['eval_args']['mode'] + if eval_mode == 'none': + eval_neg_sample_args = {'strategy': 'none', 'distribution': 'none'} + elif eval_mode == 'full': + eval_neg_sample_args = {'strategy': 'full', 'distribution': 'uniform'} + elif eval_mode[0:3] == 'uni': + sample_by = int(eval_mode[3:]) + eval_neg_sample_args = {'strategy': 'by', 'by': sample_by, 'distribution': 'uniform'} + elif eval_mode[0:3] == 'pop': + sample_by = int(eval_mode[3:]) + eval_neg_sample_args = {'strategy': 'by', 'by': sample_by, 'distribution': 'popularity'} + else: + raise ValueError(f'the mode [{eval_mode}] in eval_args is not supported.') + self.final_config_dict['eval_neg_sample_args'] = eval_neg_sample_args + def __setitem__(self, key, value): if not isinstance(key, str): raise TypeError("index must be a str.") diff --git a/recbole/config/eval_setting.py b/recbole/config/eval_setting.py deleted file mode 100644 index c67386ded..000000000 --- a/recbole/config/eval_setting.py +++ /dev/null @@ -1,391 +0,0 @@ -# @Time : 2020/7/20 -# @Author : Yupeng Hou -# @Email : houyupeng@ruc.edu.cn - -# UPDATE: -# @Time : 2020/10/22, 2020/8/31, 2021/3/1 -# @Author : Yupeng Hou, Yushuo Chen, Jiawei Guan -# @Email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, guanjw@ruc.edu.cn - -""" -recbole.config.eval_setting -################################ -""" - -from recbole.utils.utils import set_color - - -class EvalSetting(object): - """Class containing settings about model evaluation. - - Evaluation setting contains four parts: - * Group - * Sort - * Split - * Negative Sample - - APIs are provided for users to set up or modify their evaluation setting easily and clearly. - - Besides, some presets are provided, which is more recommended. 
- - For example: - RO: Random Ordering - TO: Temporal Ordering - - RS: Ratio-based Splitting - LS: Leave-one-out Splitting - - full: adopt the entire item set (excluding ground-truth items) for ranking - uniXX: uniform sampling XX items while negative sampling - popXX: popularity-based sampling XX items while negative sampling - - Note that records are grouped by user_id by default if you use these presets. - - Thus you can use `RO_RS, full` to represent Shuffle, Grouped by user, Ratio-based Splitting - and Evaluate all non-ground-truth items. - - Check out *Revisiting Alternative Experimental Settings for Evaluating Top-N Item Recommendation Algorithms* - Wayne Xin Zhao et.al. CIKM 2020 to figure out the details about presets of evaluation settings. - - Args: - config (Config): Global configuration object. - - Attributes: - group_field (str or None): Don't group if ``None``, else group by field before splitting. - Usually records are grouped by user id. - - ordering_args (dict): Args about ordering. - Usually records are sorted by timestamp, or shuffled. - - split_args (dict): Args about splitting. - usually records are split by ratio (eg. 8:1:1), - or by 'leave one out' strategy, which means the last purchase record - of one user is used for evaluation. - - neg_sample_args (dict): Args about negative sampling. - Negative sample is used wildly in training and evaluating. - - We provide two strategies: - - - ``neg_sample_by``: sample several negative records for each positive records. - - ``full_sort``: don't negative sample, while all unused items are used for evaluation. - - """ - - def __init__(self, config): - self.config = config - - self.group_field = None - self.ordering_args = None - self.split_args = None - self.neg_sample_args = {'strategy': 'none'} - - self.es_str = [_.strip() for _ in config['eval_setting'].split(',')] - self.set_ordering_and_splitting(self.es_str[0]) - if len(self.es_str) > 1: - if getattr(self, self.es_str[1], None) == None: - raise ValueError('Incorrect setting of negative sampling.') - getattr(self, self.es_str[1])() - presetting_args = ['group_field', 'ordering_args', 'split_args', 'neg_sample_args'] - for args in presetting_args: - if config[args] is not None: - setattr(self, args, config[args]) - - def __str__(self): - info = [set_color('Evaluation Setting:', 'pink')] - - if self.group_field: - info.append(set_color('Group by', 'blue') + f' {self.group_field}') - else: - info.append(set_color('No Grouping', 'yellow')) - - if self.ordering_args is not None and self.ordering_args['strategy'] != 'none': - info.append(set_color('Ordering', 'blue') + f': {self.ordering_args}') - else: - info.append(set_color('No Ordering', 'yellow')) - - if self.split_args is not None and self.split_args['strategy'] != 'none': - info.append(set_color('Splitting', 'blue') + f': {self.split_args}') - else: - info.append(set_color('No Splitting', 'yellow')) - - if self.neg_sample_args is not None and self.neg_sample_args['strategy'] != 'none': - info.append(set_color('Negative Sampling', 'blue') + f': {self.neg_sample_args}') - else: - info.append(set_color('No Negative Sampling', 'yellow')) - - return '\n\t'.join(info) - - def __repr__(self): - return self.__str__() - - def group_by(self, field=None): - """Setting about group - - Args: - field (str): The field of dataset grouped by, default None (Not Grouping) - - Example: - >>> es.group_by('month') - >>> es.group_by_user() - """ - self.group_field = field - - def group_by_user(self): - """Group by user - - Note: - Requires 
``USER_ID_FIELD`` in config - """ - self.group_field = self.config['USER_ID_FIELD'] - - def set_ordering(self, strategy='none', **kwargs): - """Setting about ordering - - Args: - strategy (str): Either ``none``, ``shuffle`` or ``by`` - field (str or list of str): Name or list of names - ascending (bool or list of bool): Sort ascending vs. descending. Specify list for multiple sort orders. - If this is a list of bools, must match the length of the field - - Example: - >>> es.set_ordering('shuffle') - >>> es.set_ordering('by', field='timestamp') - >>> es.set_ordering('by', field=['timestamp', 'price'], ascending=[True, False]) - - or - - >>> es.random_ordering() - >>> es.sort_by('timestamp') # ascending default - >>> es.sort_by(field=['timestamp', 'price'], ascending=[True, False]) - """ - legal_strategy = {'none', 'shuffle', 'by'} - if strategy not in legal_strategy: - raise ValueError('Ordering Strategy [{}] should in {}'.format(strategy, list(legal_strategy))) - self.ordering_args = {'strategy': strategy} - self.ordering_args.update(kwargs) - - def random_ordering(self): - """Shuffle Setting - """ - self.set_ordering('shuffle') - - def sort_by(self, field, ascending=True): - """Setting about Sorting. - - Similar with pandas' sort_values_ - - .. _sort_values: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.sort_values.html?highlight=sort_values#pandas.DataFrame.sort_values - - Args: - field (str or list of str): Name or list of names - ascending (bool or list of bool): Sort ascending vs. descending. Specify list for multiple sort orders. - If this is a list of bool, must match the length of the field - """ - self.set_ordering('by', field=field, ascending=ascending) - - def temporal_ordering(self): - """Setting about Sorting by timestamp. - - Note: - Requires `TIME_FIELD` in config - """ - self.sort_by(field=self.config['TIME_FIELD']) - - def set_splitting(self, strategy='none', **kwargs): - """Setting about split method - - Args: - strategy (str): Either ``none``, ``by_ratio``, ``by_value`` or ``loo``. - ratios (list of float): Dataset will be splited into `len(ratios)` parts. - field (str): Split by values of field. - values (list of float or float): Dataset will be splited into `len(values) + 1` parts. - The first part will be interactions whose field value in (\\*, values[0]]. - ascending (bool): Order of values after splitting. - - Example: - >>> es.leave_one_out() - >>> es.split_by_ratio(ratios=[0.8, 0.1, 0.1]) - >>> es.split_by_value(field='month', values=[6, 7], ascending=False) # (*, 7], (7, 6], (6, *) - """ - legal_strategy = {'none', 'by_ratio', 'by_value', 'loo'} - if strategy not in legal_strategy: - raise ValueError('Split Strategy [{}] should in {}'.format(strategy, list(legal_strategy))) - if strategy == 'loo' and self.group_field is None: - raise ValueError('Leave-One-Out request group firstly') - self.split_args = {'strategy': strategy} - self.split_args.update(kwargs) - - def leave_one_out(self, leave_one_num=1): - """ Setting about Splitting by 'leave-one-out' strategy. - - Note: - Requires setting group by. - - Args: - leave_one_num (int): number of sub datasets for evaluation. - E.g. ``leave_one_num = 2`` if you have one validation dataset and one test dataset. - """ - if self.group_field is None: - raise ValueError('Leave one out request grouped dataset, please set group field.') - self.set_splitting(strategy='loo', leave_one_num=leave_one_num) - - def split_by_ratio(self, ratios): - """ Setting about Ratio-based Splitting. 
- - Args: - ratios (list of float): ratio of each part. - No need to normalize. It's ok with either `[0.8, 0.1, 0.1]`, `[8, 1, 1]` or `[56, 7, 7]` - """ - if not isinstance(ratios, list): - raise ValueError('ratios [{}] should be list'.format(ratios)) - self.set_splitting(strategy='by_ratio', ratios=ratios) - - def _split_by_value(self, field, values, ascending=True): - raise NotImplementedError('Split by value has not been implemented.') - if not isinstance(field, str): - raise ValueError('field [{}] should be str'.format(field)) - if not isinstance(values, list): - values = [values] - values.sort(reverse=(not ascending)) - self.set_splitting(strategy='by_value', values=values, ascending=ascending) - - def set_neg_sampling(self, strategy='none', distribution='uniform', **kwargs): - """Setting about negative sampling - - Args: - strategy (str): Either ``none``, ``full`` or ``by``. - by (int): Negative Sampling `by` neg cases for one pos case. - distribution (str): distribution of sampler, either 'uniform' or 'popularity'. - - Example: - >>> es.full() - >>> es.neg_sample_by(1) - """ - legal_strategy = {'none', 'full', 'by'} - if strategy not in legal_strategy: - raise ValueError('Negative Sampling Strategy [{}] should in {}'.format(strategy, list(legal_strategy))) - if strategy == 'full' and distribution != 'uniform': - raise ValueError('Full Sort can not be sampled by distribution [{}]'.format(distribution)) - self.neg_sample_args = {'strategy': strategy, 'distribution': distribution} - self.neg_sample_args.update(kwargs) - - def neg_sample_by(self, by, distribution='uniform'): - """Setting about negative sampling by, which means sample several negative records for each positive records. - - Args: - by (int): The number of neg cases for one pos case. - distribution (str): distribution of sampler, either ``uniform`` or ``popularity``. - """ - self.set_neg_sampling(strategy='by', by=by, distribution=distribution) - - def set_ordering_and_splitting(self, es_str): - """Setting about ordering and split method. - - Args: - es_str (str): Ordering and splitting method string. Either ``RO_RS``, ``RO_LS``, ``TO_RS`` or ``TO_LS``. - """ - args = es_str.split('_') - if len(args) != 2: - raise ValueError(f'`{es_str}` is invalid eval_setting.') - ordering_args, split_args = args - - if self.config['group_by_user']: - self.group_by_user() - - if ordering_args == 'RO': - self.random_ordering() - elif ordering_args == 'TO': - self.temporal_ordering() - else: - raise NotImplementedError(f'Ordering args `{ordering_args}` is not implemented.') - - if split_args == 'RS': - ratios = self.config['split_ratio'] - if ratios is None: - raise ValueError('`ratios` should be set if `RS` is set.') - self.split_by_ratio(ratios) - elif split_args == 'LS': - leave_one_num = self.config['leave_one_num'] - if leave_one_num is None: - raise ValueError('`leave_one_num` should be set if `LS` is set.') - self.leave_one_out(leave_one_num=leave_one_num) - else: - raise NotImplementedError(f'Split args `{split_args}` is not implemented.') - - def RO_RS(self, ratios=(0.8, 0.1, 0.1), group_by_user=True): - """Preset about Random Ordering and Ratio-based Splitting. - - Args: - ratios (list of float): ratio of each part. - No need to normalize. 
It's ok with either ``[0.8, 0.1, 0.1]``, ``[8, 1, 1]`` or ``[56, 7, 7]`` - group_by_user (bool): set group field to user_id if True - """ - if group_by_user: - self.group_by_user() - self.random_ordering() - self.split_by_ratio(ratios) - - def TO_RS(self, ratios=(0.8, 0.1, 0.1), group_by_user=True): - """Preset about Temporal Ordering and Ratio-based Splitting. - - Args: - ratios (list of float): ratio of each part. - No need to normalize. It's ok with either ``[0.8, 0.1, 0.1]``, ``[8, 1, 1]`` or ``[56, 7, 7]`` - group_by_user (bool): set group field to user_id if True - """ - if group_by_user: - self.group_by_user() - self.temporal_ordering() - self.split_by_ratio(ratios) - - def RO_LS(self, leave_one_num=1, group_by_user=True): - """Preset about Random Ordering and Leave-one-out Splitting. - - Args: - leave_one_num (int): number of sub datasets for evaluation. - E.g. ``leave_one_num=2`` if you have one validation dataset and one test dataset. - group_by_user (bool): set group field to user_id if True - """ - if group_by_user: - self.group_by_user() - self.random_ordering() - self.leave_one_out(leave_one_num=leave_one_num) - - def TO_LS(self, leave_one_num=1, group_by_user=True): - """Preset about Temporal Ordering and Leave-one-out Splitting. - - Args: - leave_one_num (int): number of sub datasets for evaluation. - E.g. ``leave_one_num=2`` if you have one validation dataset and one test dataset. - group_by_user (bool): set group field to user_id if True - """ - if group_by_user: - self.group_by_user() - self.temporal_ordering() - self.leave_one_out(leave_one_num=leave_one_num) - - def uni100(self): - """Preset about uniform sampling 100 items for each positive records while negative sampling. - """ - self.neg_sample_by(100) - - def pop100(self): - """Preset about popularity-based sampling 100 items for each positive records while negative sampling. - """ - self.neg_sample_by(100, distribution='popularity') - - def uni1000(self): - """Preset about uniform sampling 1000 items for each positive records while negative sampling. - """ - self.neg_sample_by(1000) - - def pop1000(self): - """Preset about popularity-based sampling 1000 items for each positive records while negative sampling. - """ - self.neg_sample_by(1000, distribution='popularity') - - def full(self): - """Preset about adopt the entire item set (excluding ground-truth items) for ranking. 
- """ - self.set_neg_sampling(strategy='full') diff --git a/recbole/data/dataloader/__init__.py b/recbole/data/dataloader/__init__.py index dce69e125..d0a3a0a31 100644 --- a/recbole/data/dataloader/__init__.py +++ b/recbole/data/dataloader/__init__.py @@ -1,9 +1,4 @@ from recbole.data.dataloader.abstract_dataloader import * -from recbole.data.dataloader.neg_sample_mixin import * from recbole.data.dataloader.general_dataloader import * -from recbole.data.dataloader.context_dataloader import * -from recbole.data.dataloader.sequential_dataloader import * -from recbole.data.dataloader.dien_dataloader import * from recbole.data.dataloader.knowledge_dataloader import * -from recbole.data.dataloader.decisiontree_dataloader import * from recbole.data.dataloader.user_dataloader import * diff --git a/recbole/data/dataloader/abstract_dataloader.py b/recbole/data/dataloader/abstract_dataloader.py index 900fc9a02..58a662088 100644 --- a/recbole/data/dataloader/abstract_dataloader.py +++ b/recbole/data/dataloader/abstract_dataloader.py @@ -15,10 +15,13 @@ import math from logging import getLogger -from recbole.utils import InputType +import torch +from recbole.data.interaction import Interaction +from recbole.utils import InputType, FeatureType, FeatureSource -class AbstractDataLoader(object): + +class AbstractDataLoader: """:class:`AbstractDataLoader` is an abstract object which would return a batch of data which is loaded by :class:`~recbole.data.interaction.Interaction` when it is iterated. And it is also the ancestor of all other dataloader. @@ -26,50 +29,30 @@ class AbstractDataLoader(object): Args: config (Config): The config of dataloader. dataset (Dataset): The dataset of dataloader. - batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``. - dl_format (InputType, optional): The input type of dataloader. Defaults to - :obj:`~recbole.utils.enum_type.InputType.POINTWISE`. shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. Attributes: dataset (Dataset): The dataset of this dataloader. shuffle (bool): If ``True``, dataloader will shuffle before every epoch. - real_time (bool): If ``True``, dataloader will do data pre-processing, - such as neg-sampling and data-augmentation. pr (int): Pointer of dataloader. step (int): The increment of :attr:`pr` for each batch. batch_size (int): The max interaction number for all batch. """ dl_type = None - def __init__(self, config, dataset, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False): + def __init__(self, config, dataset, sampler, shuffle=False): self.config = config self.logger = getLogger() self.dataset = dataset - self.batch_size = batch_size - self.step = batch_size - self.dl_format = dl_format + self.sampler = sampler + self.batch_size = self.step = None self.shuffle = shuffle self.pr = 0 - self.real_time = config['real_time_process'] - if self.real_time is None: - self.real_time = True + self._init_batch_size_and_step() - self.setup() - if not self.real_time: - self.data_preprocess() - - def setup(self): - """This function can be used to deal with some problems after essential args are initialized, - such as the batch-size-adaptation when neg-sampling is needed, and so on. By default, it will do nothing. - """ - pass - - def data_preprocess(self): - """This function is used to do some data preprocess, such as pre-data-augmentation. - By default, it will do nothing. 
- """ - pass + def _init_batch_size_and_step(self): + """Initializing :attr:`step` and :attr:`batch_size`.""" + raise NotImplementedError('Method [init_batch_size_and_step] should be implemented') def __len__(self): return math.ceil(self.pr_end / self.step) @@ -113,13 +96,76 @@ def set_batch_size(self, batch_size): raise PermissionError('Cannot change dataloader\'s batch_size while iteration') if self.batch_size != batch_size: self.batch_size = batch_size - self.logger.warning(f'Batch size is changed to {batch_size}.') - def upgrade_batch_size(self, batch_size): - """Upgrade the batch_size of the dataloader, if input batch_size is bigger than current batch_size. - Args: - batch_size (int): the new batch_size of dataloader. - """ - if self.batch_size < batch_size: - self.set_batch_size(batch_size) +class NegSampleDataLoader(AbstractDataLoader): + """:class:`NegSampleDataLoader` is an abstract class which can sample negative examples by ratio. + It has two neg-sampling method, the one is 1-by-1 neg-sampling (pair wise), + and the other is 1-by-multi neg-sampling (point wise). + + Args: + config (Config): The config of dataloader. + dataset (Dataset): The dataset of dataloader. + sampler (Sampler): The sampler of dataloader. + shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. + """ + def __init__(self, config, dataset, sampler, shuffle=True): + super().__init__(config, dataset, sampler, shuffle=shuffle) + + def _set_neg_sample_args(self, config, dataset, dl_format, neg_sample_args): + self.uid_field = dataset.uid_field + self.iid_field = dataset.iid_field + self.dl_format = dl_format + self.neg_sample_args = neg_sample_args + if self.neg_sample_args['strategy'] == 'by': + self.neg_sample_num = self.neg_sample_args['by'] + + if self.dl_format == InputType.POINTWISE: + self.times = 1 + self.neg_sample_num + self.sampling_func = self._neg_sample_by_point_wise_sampling + + self.label_field = config['LABEL_FIELD'] + dataset.set_field_property(self.label_field, FeatureType.FLOAT, FeatureSource.INTERACTION, 1) + elif self.dl_format == InputType.PAIRWISE: + self.times = self.neg_sample_num + self.sampling_func = self._neg_sample_by_pair_wise_sampling + + self.neg_prefix = config['NEG_PREFIX'] + self.neg_item_id = self.neg_prefix + self.iid_field + + columns = [self.iid_field] if dataset.item_feat is None else dataset.item_feat.columns + for item_feat_col in columns: + neg_item_feat_col = self.neg_prefix + item_feat_col + dataset.copy_field_property(neg_item_feat_col, item_feat_col) + else: + raise ValueError(f'`neg sampling by` with dl_format [{self.dl_format}] not been implemented.') + + elif self.neg_sample_args['strategy'] != 'none': + raise ValueError(f'`neg_sample_args` [{self.neg_sample_args["strategy"]}] is not supported!') + + def _neg_sampling(self, inter_feat): + if self.neg_sample_args['strategy'] == 'by': + user_ids = inter_feat[self.uid_field] + item_ids = inter_feat[self.iid_field] + neg_item_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_num) + return self.sampling_func(inter_feat, neg_item_ids) + else: + return inter_feat + + def _neg_sample_by_pair_wise_sampling(self, inter_feat, neg_item_ids): + inter_feat = inter_feat.repeat(self.times) + neg_item_feat = Interaction({self.iid_field: neg_item_ids}) + neg_item_feat = self.dataset.join(neg_item_feat) + neg_item_feat.add_prefix(self.neg_prefix) + inter_feat.update(neg_item_feat) + return inter_feat + + def _neg_sample_by_point_wise_sampling(self, 
inter_feat, neg_item_ids): + pos_inter_num = len(inter_feat) + new_data = inter_feat.repeat(self.times) + new_data[self.iid_field][pos_inter_num:] = neg_item_ids + new_data = self.dataset.join(new_data) + labels = torch.zeros(pos_inter_num * self.times) + labels[:pos_inter_num] = 1.0 + new_data.update(Interaction({self.label_field: labels})) + return new_data diff --git a/recbole/data/dataloader/context_dataloader.py b/recbole/data/dataloader/context_dataloader.py deleted file mode 100644 index 9ca4fc4df..000000000 --- a/recbole/data/dataloader/context_dataloader.py +++ /dev/null @@ -1,40 +0,0 @@ -# @Time : 2020/7/7 -# @Author : Yupeng Hou -# @Email : houyupeng@ruc.edu.cn - -# UPDATE -# @Time : 2020/9/9, 2020/9/16 -# @Author : Yupeng Hou, Yushuo Chen -# @email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn - -""" -recbole.data.dataloader.context_dataloader -################################################ -""" - -from recbole.data.dataloader.general_dataloader import GeneralDataLoader, GeneralNegSampleDataLoader, \ - GeneralFullDataLoader - - -class ContextDataLoader(GeneralDataLoader): - """:class:`ContextDataLoader` is inherit from - :class:`~recbole.data.dataloader.general_dataloader.GeneralDataLoader`, - and didn't add/change anything at all. - """ - pass - - -class ContextNegSampleDataLoader(GeneralNegSampleDataLoader): - """:class:`ContextNegSampleDataLoader` is inherit from - :class:`~recbole.data.dataloader.general_dataloader.GeneralNegSampleDataLoader`, - and didn't add/change anything at all. - """ - pass - - -class ContextFullDataLoader(GeneralFullDataLoader): - """:class:`ContextFullDataLoader` is inherit from - :class:`~recbole.data.dataloader.general_dataloader.GeneralFullDataLoader`, - and didn't add/change anything at all. - """ - pass diff --git a/recbole/data/dataloader/decisiontree_dataloader.py b/recbole/data/dataloader/decisiontree_dataloader.py deleted file mode 100644 index 996b720a8..000000000 --- a/recbole/data/dataloader/decisiontree_dataloader.py +++ /dev/null @@ -1,40 +0,0 @@ -# @Time : 2020/11/19 -# @Author : Chen Yang -# @Email : 254170321@qq.com - -# UPDATE: -# @Time : 2020/11/19 -# @Author : Chen Yang -# @Email : 254170321@qq.com - -""" -recbole.data.dataloader.decisiontree_dataloader -################################################ -""" - -from recbole.data.dataloader.general_dataloader import GeneralDataLoader, GeneralNegSampleDataLoader, \ - GeneralFullDataLoader - - -class DecisionTreeDataLoader(GeneralDataLoader): - """:class:`DecisionTreeDataLoader` is inherit from - :class:`~recbole.data.dataloader.general_dataloader.GeneralDataLoader`, - and didn't add/change anything at all. - """ - pass - - -class DecisionTreeNegSampleDataLoader(GeneralNegSampleDataLoader): - """:class:`DecisionTreeNegSampleDataLoader` is inherit from - :class:`~recbole.data.dataloader.general_dataloader.GeneralNegSampleDataLoader`, - and didn't add/change anything at all. - """ - pass - - -class DecisionTreeFullDataLoader(GeneralFullDataLoader): - """:class:`DecisionTreeFullDataLoader` is inherit from - :class:`~recbole.data.dataloader.general_dataloader.GeneralFullDataLoader`, - and didn't add/change anything at all. 
- """ - pass diff --git a/recbole/data/dataloader/dien_dataloader.py b/recbole/data/dataloader/dien_dataloader.py deleted file mode 100644 index 6fb06b321..000000000 --- a/recbole/data/dataloader/dien_dataloader.py +++ /dev/null @@ -1,146 +0,0 @@ -# @Time : 2021/2/25 -# @Author : Zhichao Feng -# @Email : fzcbupt@gmail.com - -# UPDATE -# @Time : 2021/3/19 -# @Author : Zhichao Feng -# @email : fzcbupt@gmail.com - -""" -recbole.data.dataloader.dien_dataloader -################################################ -""" - -import torch - -from recbole.data.dataloader.sequential_dataloader import SequentialDataLoader, SequentialNegSampleDataLoader, SequentialFullDataLoader -from recbole.data.interaction import Interaction, cat_interactions -from recbole.utils import DataLoaderType, FeatureSource, FeatureType, InputType -from recbole.sampler import SeqSampler - - -class DIENDataLoader(SequentialDataLoader): - """:class:`DIENDataLoader` is used for DIEN model. It is different from :class:`SequentialDataLoader` in - `augmentation`. It add users' negative item list to interaction. - It will do data augmentation for the origin data. And its returned data contains the following: - - - user id - - history items list - - history negative item list - - history items' interaction time list - - item to be predicted - - the interaction time of item to be predicted - - history list length - - other interaction information of item to be predicted - - Args: - config (Config): The config of dataloader. - dataset (Dataset): The dataset of dataloader. - batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``. - dl_format (InputType, optional): The input type of dataloader. Defaults to - :obj:`~recbole.utils.enum_type.InputType.POINTWISE`. - shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. - """ - dl_type = DataLoaderType.ORIGIN - - def __init__(self, config, dataset, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False): - - list_suffix = config['LIST_SUFFIX'] - neg_prefix = config['NEG_PREFIX'] - - self.seq_sampler = SeqSampler(dataset) - self.iid_field = dataset.iid_field - self.neg_item_list_field = neg_prefix + self.iid_field + list_suffix - self.neg_item_list = self.seq_sampler.sample_neg_sequence(dataset.inter_feat[self.iid_field]) - - super().__init__(config, dataset, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle) - - def augmentation(self, item_list_index, target_index, item_list_length): - """Data augmentation. - - Args: - item_list_index (numpy.ndarray): the index of history items list in interaction. - target_index (numpy.ndarray): the index of items to be predicted in interaction. - item_list_length (numpy.ndarray): history list length. - - Returns: - dict: the augmented data. 
- """ - new_length = len(item_list_index) - new_data = self.dataset.inter_feat[target_index] - new_dict = { - self.item_list_length_field: torch.tensor(item_list_length), - } - - for field in self.dataset.inter_feat: - if field != self.uid_field: - list_field = getattr(self, f'{field}_list_field') - list_len = self.dataset.field2seqlen[list_field] - shape = (new_length, list_len) if isinstance(list_len, int) else (new_length,) + list_len - list_ftype = self.dataset.field2type[list_field] - dtype = torch.int64 if list_ftype in [FeatureType.TOKEN, FeatureType.TOKEN_SEQ] else torch.float64 - new_dict[list_field] = torch.zeros(shape, dtype=dtype) - - value = self.dataset.inter_feat[field] - for i, (index, length) in enumerate(zip(item_list_index, item_list_length)): - new_dict[list_field][i][:length] = value[index] - - if field == self.iid_field: - new_dict[self.neg_item_list_field] = torch.zeros(shape, dtype=dtype) - for i, (index, length) in enumerate(zip(item_list_index, item_list_length)): - new_dict[self.neg_item_list_field][i][:length] = self.neg_item_list[index] - - new_data.update(Interaction(new_dict)) - return new_data - - -class DIENNegSampleDataLoader(SequentialNegSampleDataLoader, DIENDataLoader): - """:class:`DIENNegSampleDataLoader` is sequential-dataloader with negative sampling for DIEN. - Like :class:`~recbole.data.dataloader.general_dataloader.GeneralNegSampleDataLoader`, for the result of every batch, - we permit that every positive interaction and its negative interaction must be in the same batch. Beside this, - when it is in the evaluation stage, and evaluator is topk-like function, we also permit that all the interactions - corresponding to each user are in the same batch and positive interactions are before negative interactions. - - Args: - config (Config): The config of dataloader. - dataset (Dataset): The dataset of dataloader. - sampler (Sampler): The sampler of dataloader. - neg_sample_args (dict): The neg_sample_args of dataloader. - batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``. - dl_format (InputType, optional): The input type of dataloader. Defaults to - :obj:`~recbole.utils.enum_type.InputType.POINTWISE`. - shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. - """ - - def __init__( - self, config, dataset, sampler, neg_sample_args, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False - ): - super().__init__( - config, dataset, sampler, neg_sample_args, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle - ) - - -class DIENFullDataLoader(SequentialFullDataLoader, DIENDataLoader): - """:class:`DIENFullDataLoader` is a sequential-dataloader with full sort for DIEN. In order to speed up calculation, - this dataloader would only return then user part of interactions, positive items and used items. - It would not return negative items. - - Args: - config (Config): The config of dataloader. - dataset (Dataset): The dataset of dataloader. - sampler (Sampler): The sampler of dataloader. - neg_sample_args (dict): The neg_sample_args of dataloader. - batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``. - dl_format (InputType, optional): The input type of dataloader. Defaults to - :obj:`~recbole.utils.enum_type.InputType.POINTWISE`. - shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. 
- """ - dl_type = DataLoaderType.FULL - - def __init__( - self, config, dataset, sampler, neg_sample_args, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False - ): - super().__init__( - config, dataset, sampler, neg_sample_args, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle - ) diff --git a/recbole/data/dataloader/general_dataloader.py b/recbole/data/dataloader/general_dataloader.py index 69aac242d..1d0b0d2a0 100644 --- a/recbole/data/dataloader/general_dataloader.py +++ b/recbole/data/dataloader/general_dataloader.py @@ -15,27 +15,38 @@ import numpy as np import torch -from recbole.data.dataloader.abstract_dataloader import AbstractDataLoader -from recbole.data.dataloader.neg_sample_mixin import NegSampleMixin, NegSampleByMixin +from recbole.data.dataloader.abstract_dataloader import AbstractDataLoader, NegSampleDataLoader from recbole.data.interaction import Interaction, cat_interactions -from recbole.utils import DataLoaderType, InputType +from recbole.utils import DataLoaderType, InputType, ModelType -class GeneralDataLoader(AbstractDataLoader): - """:class:`GeneralDataLoader` is used for general model and it just return the origin data. +class TrainDataLoader(NegSampleDataLoader): + """:class:`TrainDataLoader` is a dataloader for training. + It can generate negative interaction when :attr:`training_neg_sample_num` is not zero. + For the result of every batch, we permit that every positive interaction and its negative interaction + must be in the same batch. Args: config (Config): The config of dataloader. dataset (Dataset): The dataset of dataloader. - batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``. - dl_format (InputType, optional): The input type of dataloader. Defaults to - :obj:`~recbole.utils.enum_type.InputType.POINTWISE`. + sampler (Sampler): The sampler of dataloader. shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. """ - dl_type = DataLoaderType.ORIGIN - def __init__(self, config, dataset, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False): - super().__init__(config, dataset, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle) + def __init__(self, config, dataset, sampler, shuffle=False): + self._set_neg_sample_args(config, dataset, config['MODEL_INPUT_TYPE'], config['train_neg_sample_args']) + super().__init__(config, dataset, sampler, shuffle=shuffle) + + def _init_batch_size_and_step(self): + batch_size = self.config['train_batch_size'] + if self.neg_sample_args['strategy'] == 'by': + batch_num = max(batch_size // self.times, 1) + new_batch_size = batch_num * self.times + self.step = batch_num + self.set_batch_size(new_batch_size) + else: + self.step = batch_size + self.set_batch_size(batch_size) @property def pr_end(self): @@ -45,131 +56,88 @@ def _shuffle(self): self.dataset.shuffle() def _next_batch_data(self): - cur_data = self.dataset[self.pr:self.pr + self.step] + cur_data = self._neg_sampling(self.dataset[self.pr:self.pr + self.step]) self.pr += self.step return cur_data -class GeneralNegSampleDataLoader(NegSampleByMixin, AbstractDataLoader): - """:class:`GeneralNegSampleDataLoader` is a general-dataloader with negative sampling. - For the result of every batch, we permit that every positive interaction and its negative interaction - must be in the same batch. 
Beside this, when it is in the evaluation stage, and evaluator is topk-like function, - we also permit that all the interactions corresponding to each user are in the same batch +class NegSampleEvalDataLoader(NegSampleDataLoader): + """:class:`NegSampleEvalDataLoader` is a dataloader for neg-sampling evaluation. + It is similar to :class:`TrainDataLoader` which can generate negative items, + and this dataloader also permits that all the interactions corresponding to each user are in the same batch and positive interactions are before negative interactions. Args: config (Config): The config of dataloader. dataset (Dataset): The dataset of dataloader. sampler (Sampler): The sampler of dataloader. - neg_sample_args (dict): The neg_sample_args of dataloader. - batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``. - dl_format (InputType, optional): The input type of dataloader. Defaults to - :obj:`~recbole.utils.enum_type.InputType.POINTWISE`. shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. """ + def __init__(self, config, dataset, sampler, shuffle=False): + user_num = dataset.user_num + dataset.sort(by=dataset.uid_field, ascending=True) + self.uid_list = [] + start, end = dict(), dict() + for i, uid in enumerate(dataset.inter_feat[dataset.uid_field].numpy()): + if uid not in start: + self.uid_list.append(uid) + start[uid] = i + end[uid] = i + self.uid2index = np.array([None] * user_num) + self.uid2items_num = np.zeros(user_num, dtype=np.int64) + for uid in self.uid_list: + self.uid2index[uid] = slice(start[uid], end[uid] + 1) + self.uid2items_num[uid] = end[uid] - start[uid] + 1 + self.uid_list = np.array(self.uid_list) - def __init__( - self, config, dataset, sampler, neg_sample_args, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False - ): - self.uid_field = dataset.uid_field - self.iid_field = dataset.iid_field - self.uid_list, self.uid2index, self.uid2items_num = None, None, None - - super().__init__( - config, dataset, sampler, neg_sample_args, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle - ) + self._set_neg_sample_args(config, dataset, InputType.POINTWISE, config['eval_neg_sample_args']) + super().__init__(config, dataset, sampler, shuffle=shuffle) - def setup(self): - if self.user_inter_in_one_batch: - uid_field = self.dataset.uid_field - user_num = self.dataset.user_num - self.dataset.sort(by=uid_field, ascending=True) - self.uid_list = [] - start, end = dict(), dict() - for i, uid in enumerate(self.dataset.inter_feat[uid_field].numpy()): - if uid not in start: - self.uid_list.append(uid) - start[uid] = i - end[uid] = i - self.uid2index = np.array([None] * user_num) - self.uid2items_num = np.zeros(user_num, dtype=np.int64) - for uid in self.uid_list: - self.uid2index[uid] = slice(start[uid], end[uid] + 1) - self.uid2items_num[uid] = end[uid] - start[uid] + 1 - self.uid_list = np.array(self.uid_list) - self._batch_size_adaptation() - - def _batch_size_adaptation(self): - if self.user_inter_in_one_batch: + def _init_batch_size_and_step(self): + batch_size = self.config['eval_batch_size'] + if self.neg_sample_args['strategy'] == 'by': inters_num = sorted(self.uid2items_num * self.times, reverse=True) batch_num = 1 new_batch_size = inters_num[0] for i in range(1, len(inters_num)): - if new_batch_size + inters_num[i] > self.batch_size: + if new_batch_size + inters_num[i] > batch_size: break batch_num = i + 1 new_batch_size += inters_num[i] self.step = batch_num - 
self.upgrade_batch_size(new_batch_size) + self.set_batch_size(new_batch_size) else: - batch_num = max(self.batch_size // self.times, 1) - new_batch_size = batch_num * self.times - self.step = batch_num - self.upgrade_batch_size(new_batch_size) + self.step = batch_size + self.set_batch_size(batch_size) @property def pr_end(self): - if self.user_inter_in_one_batch: - return len(self.uid_list) - else: - return len(self.dataset) + return len(self.uid_list) def _shuffle(self): - if self.user_inter_in_one_batch: - np.random.shuffle(self.uid_list) - else: - self.dataset.shuffle() + self.logger.warning('NegSampleEvalDataLoader can\'t shuffle') def _next_batch_data(self): - if self.user_inter_in_one_batch: - uid_list = self.uid_list[self.pr:self.pr + self.step] - data_list = [] - for uid in uid_list: - index = self.uid2index[uid] - data_list.append(self._neg_sampling(self.dataset[index])) - cur_data = cat_interactions(data_list) - pos_len_list = self.uid2items_num[uid_list] - user_len_list = pos_len_list * self.times - cur_data.set_additional_info(list(pos_len_list), list(user_len_list)) - self.pr += self.step - return cur_data - else: - cur_data = self._neg_sampling(self.dataset[self.pr:self.pr + self.step]) - self.pr += self.step - return cur_data - def _neg_sampling(self, inter_feat): - uids = inter_feat[self.uid_field] - neg_iids = self.sampler.sample_by_user_ids(uids, self.neg_sample_by) - return self.sampling_func(inter_feat, neg_iids) - def _neg_sample_by_pair_wise_sampling(self, inter_feat, neg_iids): - inter_feat = inter_feat.repeat(self.times) - neg_item_feat = Interaction({self.iid_field: neg_iids}) - neg_item_feat = self.dataset.join(neg_item_feat) - neg_item_feat.add_prefix(self.neg_prefix) - inter_feat.update(neg_item_feat) - return inter_feat - def _neg_sample_by_point_wise_sampling(self, inter_feat, neg_iids): - pos_inter_num = len(inter_feat) - new_data = inter_feat.repeat(self.times) - new_data[self.iid_field][pos_inter_num:] = neg_iids - new_data = self.dataset.join(new_data) - labels = torch.zeros(pos_inter_num * self.times) - labels[:pos_inter_num] = 1.0 - new_data.update(Interaction({self.label_field: labels})) - return new_data + uid_list = self.uid_list[self.pr:self.pr + self.step] + data_list = [] + idx_list = [] + positive_u = [] + positive_i = torch.tensor([], dtype=torch.int64) + + for idx, uid in enumerate(uid_list): + index = self.uid2index[uid] + data_list.append(self._neg_sampling(self.dataset[index])) + idx_list += [idx for i in range(self.uid2items_num[uid] * self.times)] + positive_u += [idx for i in range(self.uid2items_num[uid])] + positive_i = torch.cat((positive_i, self.dataset[index][self.iid_field]), 0) + + cur_data = cat_interactions(data_list) + idx_list = torch.from_numpy(np.array(idx_list)) + positive_u = torch.from_numpy(np.array(positive_u)) + + self.pr += self.step + + return cur_data, idx_list, positive_u, positive_i def get_pos_len_list(self): """ Returns: numpy.ndarray: Number of positive items for each user in a training/evaluating epoch. @@ -186,8 +154,8 @@ def get_user_len_list(self): return self.uid2items_num[self.uid_list] * self.times -class GeneralFullDataLoader(NegSampleMixin, AbstractDataLoader): - """:class:`GeneralFullDataLoader` is a general-dataloader with full sort. In order to speed up calculation, +class FullSortEvalDataLoader(AbstractDataLoader): + """:class:`FullSortEvalDataLoader` is a dataloader for full-sort evaluation. In order to speed up calculation, this dataloader would only return the user part of interactions, positive items and used items. It would not return negative items. 
@@ -195,101 +163,101 @@ class GeneralFullDataLoader(NegSampleMixin, AbstractDataLoader): config (Config): The config of dataloader. dataset (Dataset): The dataset of dataloader. sampler (Sampler): The sampler of dataloader. - neg_sample_args (dict): The neg_sample_args of dataloader. - batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``. - dl_format (InputType, optional): The input type of dataloader. Defaults to - :obj:`~recbole.utils.enum_type.InputType.POINTWISE`. shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. """ dl_type = DataLoaderType.FULL - def __init__( - self, config, dataset, sampler, neg_sample_args, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False - ): - if neg_sample_args['strategy'] != 'full': - raise ValueError('neg_sample strategy in GeneralFullDataLoader() should be `full`') - - uid_field = dataset.uid_field - iid_field = dataset.iid_field - user_num = dataset.user_num - self.uid_list = [] - self.uid2items_num = np.zeros(user_num, dtype=np.int64) - self.uid2swap_idx = np.array([None] * user_num) - self.uid2rev_swap_idx = np.array([None] * user_num) - self.uid2history_item = np.array([None] * user_num) - - dataset.sort(by=uid_field, ascending=True) - last_uid = None - positive_item = set() - uid2used_item = sampler.used_ids - for uid, iid in zip(dataset.inter_feat[uid_field].numpy(), dataset.inter_feat[iid_field].numpy()): - if uid != last_uid: - self._set_user_property(last_uid, uid2used_item[last_uid], positive_item) - last_uid = uid - self.uid_list.append(uid) - positive_item = set() - positive_item.add(iid) - self._set_user_property(last_uid, uid2used_item[last_uid], positive_item) - self.uid_list = torch.tensor(self.uid_list, dtype=torch.int64) - self.user_df = dataset.join(Interaction({uid_field: self.uid_list})) + def __init__(self, config, dataset, sampler, shuffle=False): + self.uid_field = dataset.uid_field + self.iid_field = dataset.iid_field + self.is_sequential = config['MODEL_TYPE'] == ModelType.SEQUENTIAL + if not self.is_sequential: + user_num = dataset.user_num + self.uid_list = [] + self.uid2items_num = np.zeros(user_num, dtype=np.int64) + self.uid2positive_item = np.array([None] * user_num) + self.uid2history_item = np.array([None] * user_num) + + dataset.sort(by=self.uid_field, ascending=True) + last_uid = None + positive_item = set() + uid2used_item = sampler.used_ids + for uid, iid in zip(dataset.inter_feat[self.uid_field].numpy(), dataset.inter_feat[self.iid_field].numpy()): + if uid != last_uid: + self._set_user_property(last_uid, uid2used_item[last_uid], positive_item) + last_uid = uid + self.uid_list.append(uid) + positive_item = set() + positive_item.add(iid) + self._set_user_property(last_uid, uid2used_item[last_uid], positive_item) + self.uid_list = torch.tensor(self.uid_list, dtype=torch.int64) + self.user_df = dataset.join(Interaction({self.uid_field: self.uid_list})) - super().__init__( - config, dataset, sampler, neg_sample_args, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle - ) + super().__init__(config, dataset, sampler, shuffle=shuffle) def _set_user_property(self, uid, used_item, positive_item): if uid is None: return history_item = used_item - positive_item - positive_item_num = len(positive_item) - self.uid2items_num[uid] = positive_item_num - swap_idx = torch.tensor(sorted(set(range(positive_item_num)) ^ positive_item)) - self.uid2swap_idx[uid] = swap_idx - self.uid2rev_swap_idx[uid] = swap_idx.flip(0) + self.uid2positive_item[uid] 
= torch.tensor(list(positive_item), dtype=torch.int64) + self.uid2items_num[uid] = len(positive_item) self.uid2history_item[uid] = torch.tensor(list(history_item), dtype=torch.int64) - def _batch_size_adaptation(self): - batch_num = max(self.batch_size // self.dataset.item_num, 1) - new_batch_size = batch_num * self.dataset.item_num - self.step = batch_num - self.upgrade_batch_size(new_batch_size) + def _init_batch_size_and_step(self): + batch_size = self.config['eval_batch_size'] + if not self.is_sequential: + batch_num = max(batch_size // self.dataset.item_num, 1) + new_batch_size = batch_num * self.dataset.item_num + self.step = batch_num + self.set_batch_size(new_batch_size) + else: + self.step = batch_size + self.set_batch_size(batch_size) @property def pr_end(self): - return len(self.uid_list) + if not self.is_sequential: + return len(self.uid_list) + else: + return len(self.dataset) def _shuffle(self): - self.logger.warnning('GeneralFullDataLoader can\'t shuffle') + self.logger.warning('FullSortEvalDataLoader can\'t shuffle') def _next_batch_data(self): - user_df = self.user_df[self.pr:self.pr + self.step] - cur_data = self._neg_sampling(user_df) - self.pr += self.step - return cur_data + if not self.is_sequential: + user_df = self.user_df[self.pr:self.pr + self.step] + uid_list = list(user_df[self.uid_field]) + + history_item = self.uid2history_item[uid_list] + positive_item = self.uid2positive_item[uid_list] - def _neg_sampling(self, user_df): - uid_list = list(user_df[self.dataset.uid_field]) - pos_len_list = self.uid2items_num[uid_list] - user_len_list = np.full(len(uid_list), self.dataset.item_num) - user_df.set_additional_info(pos_len_list, user_len_list) + history_u = torch.cat([torch.full_like(hist_iid, i) for i, hist_iid in enumerate(history_item)]) + history_i = torch.cat(list(history_item)) - history_item = self.uid2history_item[uid_list] - history_row = torch.cat([torch.full_like(hist_iid, i) for i, hist_iid in enumerate(history_item)]) - history_col = torch.cat(list(history_item)) + positive_u = torch.cat([torch.full_like(pos_iid, i) for i, pos_iid in enumerate(positive_item)]) + positive_i = torch.cat(list(positive_item)) - swap_idx = self.uid2swap_idx[uid_list] - rev_swap_idx = self.uid2rev_swap_idx[uid_list] - swap_row = torch.cat([torch.full_like(swap, i) for i, swap in enumerate(swap_idx)]) - swap_col_after = torch.cat(list(swap_idx)) - swap_col_before = torch.cat(list(rev_swap_idx)) - return user_df, (history_row, history_col), swap_row, swap_col_after, swap_col_before + self.pr += self.step + return user_df, (history_u, history_i), positive_u, positive_i + else: + interaction = self.dataset[self.pr:self.pr + self.step] + inter_num = len(interaction) + positive_u = torch.arange(inter_num) + positive_i = interaction[self.iid_field] + + self.pr += self.step + return interaction, None, positive_u, positive_i def get_pos_len_list(self): """ Returns: numpy.ndarray: Number of positive items for each user in a training/evaluating epoch. 
""" - return self.uid2items_num[self.uid_list] + if not self.is_sequential: + return self.uid2items_num[self.uid_list] + else: + return np.ones(self.pr_end, dtype=np.int64) def get_user_len_list(self): """ diff --git a/recbole/data/dataloader/knowledge_dataloader.py b/recbole/data/dataloader/knowledge_dataloader.py index 6b6bb00ac..db09c2bde 100644 --- a/recbole/data/dataloader/knowledge_dataloader.py +++ b/recbole/data/dataloader/knowledge_dataloader.py @@ -11,8 +11,8 @@ recbole.data.dataloader.knowledge_dataloader ################################################ """ - -from recbole.data.dataloader import AbstractDataLoader, GeneralNegSampleDataLoader +from recbole.data.dataloader.abstract_dataloader import AbstractDataLoader +from recbole.data.dataloader.general_dataloader import TrainDataLoader from recbole.data.interaction import Interaction from recbole.utils import InputType, KGDataLoaderState @@ -25,9 +25,6 @@ class KGDataLoader(AbstractDataLoader): config (Config): The config of dataloader. dataset (Dataset): The dataset of dataloader. sampler (KGSampler): The knowledge graph sampler of dataloader. - batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``. - dl_format (InputType, optional): The input type of dataloader. Defaults to - :obj:`~recbole.utils.InputType.PAIRWISE`. shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. Attributes: @@ -35,8 +32,11 @@ class KGDataLoader(AbstractDataLoader): However, in :class:`KGDataLoader`, it's guaranteed to be ``True``. """ - def __init__(self, config, dataset, sampler, batch_size=1, dl_format=InputType.PAIRWISE, shuffle=False): - self.sampler = sampler + def __init__(self, config, dataset, sampler, shuffle=False): + if shuffle is False: + shuffle = True + self.logger.warning('kg based dataloader must shuffle the data') + self.neg_sample_num = 1 self.neg_prefix = config['NEG_PREFIX'] @@ -47,15 +47,12 @@ def __init__(self, config, dataset, sampler, batch_size=1, dl_format=InputType.P self.neg_tid_field = self.neg_prefix + self.tid_field dataset.copy_field_property(self.neg_tid_field, self.tid_field) - super().__init__(config, dataset, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle) + super().__init__(config, dataset, sampler, shuffle=shuffle) - def setup(self): - """Make sure that the :attr:`shuffle` is True. If :attr:`shuffle` is False, it will be changed to True - and give a warning to user. 
- """ - if self.shuffle is False: - self.shuffle = True - self.logger.warning('kg based dataloader must shuffle the data') + def _init_batch_size_and_step(self): + batch_size = self.config['train_batch_size'] + self.step = batch_size + self.set_batch_size(batch_size) @property def pr_end(self): @@ -65,16 +62,13 @@ def _shuffle(self): self.dataset.kg_feat.shuffle() def _next_batch_data(self): - cur_data = self._neg_sampling(self.dataset.kg_feat[self.pr:self.pr + self.step]) + cur_data = self.dataset.kg_feat[self.pr:self.pr + self.step] + head_ids = cur_data[self.hid_field] + neg_tail_ids = self.sampler.sample_by_entity_ids(head_ids, self.neg_sample_num) + cur_data.update(Interaction({self.neg_tid_field: neg_tail_ids})) self.pr += self.step return cur_data - def _neg_sampling(self, kg_feat): - hids = kg_feat[self.hid_field] - neg_tids = self.sampler.sample_by_entity_ids(hids, self.neg_sample_num) - kg_feat.update(Interaction({self.neg_tid_field: neg_tids})) - return kg_feat - class KnowledgeBasedDataLoader(AbstractDataLoader): """:class:`KnowledgeBasedDataLoader` is used for knowledge based model. @@ -88,21 +82,18 @@ class KnowledgeBasedDataLoader(AbstractDataLoader): dataset (Dataset): The dataset of dataloader. sampler (Sampler): The sampler of dataloader. kg_sampler (KGSampler): The knowledge graph sampler of dataloader. - neg_sample_args (dict): The neg_sample_args of dataloader. - batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``. - dl_format (InputType, optional): The input type of dataloader. Defaults to - :obj:`~recbole.utils.enum_type.InputType.POINTWISE`. shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. Attributes: - state (KGDataLoaderState): + state (KGDataLoaderState): This dataloader has three states: - :obj:`~recbole.utils.enum_type.KGDataLoaderState.RS` - :obj:`~recbole.utils.enum_type.KGDataLoaderState.KG` - :obj:`~recbole.utils.enum_type.KGDataLoaderState.RSKG` - In the first state, this dataloader would only return the triplets with negative examples in a knowledge graph. + In the first state, this dataloader would only return the triplets with negative + examples in a knowledge graph. In the second state, this dataloader would only return the user-item interaction. @@ -110,37 +101,20 @@ class KnowledgeBasedDataLoader(AbstractDataLoader): and user-item interaction information. 
""" - def __init__( - self, - config, - dataset, - sampler, - kg_sampler, - neg_sample_args, - batch_size=1, - dl_format=InputType.POINTWISE, - shuffle=False - ): + def __init__(self, config, dataset, sampler, kg_sampler, shuffle=False): # using sampler - self.general_dataloader = GeneralNegSampleDataLoader( - config=config, - dataset=dataset, - sampler=sampler, - neg_sample_args=neg_sample_args, - batch_size=batch_size, - dl_format=dl_format, - shuffle=shuffle - ) - - # using kg_sampler and dl_format is pairwise - self.kg_dataloader = KGDataLoader( - config, dataset, kg_sampler, batch_size=batch_size, dl_format=InputType.PAIRWISE, shuffle=True - ) + self.general_dataloader = TrainDataLoader(config, dataset, sampler, shuffle=shuffle) + + # using kg_sampler + self.kg_dataloader = KGDataLoader(config, dataset, kg_sampler, shuffle=True) self.state = None - super().__init__(config, dataset, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle) + super().__init__(config, dataset, sampler, shuffle=shuffle) + + def _init_batch_size_and_step(self): + pass def __iter__(self): if self.state is None: diff --git a/recbole/data/dataloader/neg_sample_mixin.py b/recbole/data/dataloader/neg_sample_mixin.py deleted file mode 100644 index e21d614ac..000000000 --- a/recbole/data/dataloader/neg_sample_mixin.py +++ /dev/null @@ -1,140 +0,0 @@ -# @Time : 2020/7/7 -# @Author : Yupeng Hou -# @Email : houyupeng@ruc.edu.cn - -# UPDATE -# @Time : 2020/9/9, 2020/9/17 -# @Author : Yupeng Hou, Yushuo Chen -# @email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn - -""" -recbole.data.dataloader.neg_sample_mixin -################################################ -""" - -from recbole.data.dataloader.abstract_dataloader import AbstractDataLoader -from recbole.utils import DataLoaderType, EvaluatorType, FeatureSource, FeatureType, InputType - - -class NegSampleMixin(AbstractDataLoader): - """:class:`NegSampleMixin` is a abstract class, all dataloaders that need negative sampling should inherit - this class. This class provides some necessary parameters and method for negative sampling, such as - :attr:`neg_sample_args` and :meth:`_neg_sampling()` and so on. - - Args: - config (Config): The config of dataloader. - dataset (Dataset): The dataset of dataloader. - sampler (Sampler): The sampler of dataloader. - neg_sample_args (dict): The neg_sample_args of dataloader. - batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``. - dl_format (InputType, optional): The input type of dataloader. Defaults to - :obj:`~recbole.utils.InputType.POINTWISE`. - shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. - """ - dl_type = DataLoaderType.NEGSAMPLE - - def __init__( - self, config, dataset, sampler, neg_sample_args, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False - ): - if neg_sample_args['strategy'] not in ['by', 'full']: - raise ValueError(f"Neg_sample strategy [{neg_sample_args['strategy']}] has not been implemented.") - - self.sampler = sampler - self.neg_sample_args = neg_sample_args - - super().__init__(config, dataset, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle) - - def setup(self): - """Do batch size adaptation. - """ - self._batch_size_adaptation() - - def _batch_size_adaptation(self): - """Adjust the batch size to ensure that each positive and negative interaction can be in a batch. 
- """ - raise NotImplementedError('Method [batch_size_adaptation] should be implemented.') - - def _neg_sampling(self, inter_feat): - """ - Args: - inter_feat: The origin user-item interaction table. - - Returns: - The user-item interaction table with negative example. - """ - raise NotImplementedError('Method [neg_sampling] should be implemented.') - - def get_pos_len_list(self): - """ - Returns: - numpy.ndarray: Number of positive item for each user in a training/evaluating epoch. - """ - raise NotImplementedError('Method [get_pos_len_list] should be implemented.') - - def get_user_len_list(self): - """ - Returns: - numpy.ndarray: Number of all item for each user in a training/evaluating epoch. - """ - raise NotImplementedError('Method [get_user_len_list] should be implemented.') - - -class NegSampleByMixin(NegSampleMixin): - """:class:`NegSampleByMixin` is an abstract class which can sample negative examples by ratio. - It has two neg-sampling method, the one is 1-by-1 neg-sampling (pair wise), - and the other is 1-by-multi neg-sampling (point wise). - - Args: - config (Config): The config of dataloader. - dataset (Dataset): The dataset of dataloader. - sampler (Sampler): The sampler of dataloader. - neg_sample_args (dict): The neg_sample_args of dataloader. - batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``. - dl_format (InputType, optional): The input type of dataloader. Defaults to - :obj:`~recbole.utils.enum_type.InputType.POINTWISE`. - shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. - """ - - def __init__( - self, config, dataset, sampler, neg_sample_args, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False - ): - if neg_sample_args['strategy'] != 'by': - raise ValueError('neg_sample strategy in GeneralInteractionBasedDataLoader() should be `by`') - - self.user_inter_in_one_batch = (sampler.phase != 'train') and (config['eval_type'] != EvaluatorType.INDIVIDUAL) - self.neg_sample_by = neg_sample_args['by'] - - if dl_format == InputType.POINTWISE: - self.times = 1 + self.neg_sample_by - self.sampling_func = self._neg_sample_by_point_wise_sampling - - self.label_field = config['LABEL_FIELD'] - dataset.set_field_property(self.label_field, FeatureType.FLOAT, FeatureSource.INTERACTION, 1) - elif dl_format == InputType.PAIRWISE: - self.times = self.neg_sample_by - self.sampling_func = self._neg_sample_by_pair_wise_sampling - - self.neg_prefix = config['NEG_PREFIX'] - iid_field = config['ITEM_ID_FIELD'] - self.neg_item_id = self.neg_prefix + iid_field - - columns = [iid_field] if dataset.item_feat is None else dataset.item_feat.columns - for item_feat_col in columns: - neg_item_feat_col = self.neg_prefix + item_feat_col - dataset.copy_field_property(neg_item_feat_col, item_feat_col) - else: - raise ValueError(f'`neg sampling by` with dl_format [{dl_format}] not been implemented.') - - super().__init__( - config, dataset, sampler, neg_sample_args, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle - ) - - def _neg_sample_by_pair_wise_sampling(self, *args): - """Pair-wise sampling. - """ - raise NotImplementedError('Method [neg_sample_by_pair_wise_sampling] should be implemented.') - - def _neg_sample_by_point_wise_sampling(self, *args): - """Point-wise sampling. 
- """ - raise NotImplementedError('Method [neg_sample_by_point_wise_sampling] should be implemented.') diff --git a/recbole/data/dataloader/sequential_dataloader.py b/recbole/data/dataloader/sequential_dataloader.py deleted file mode 100644 index ad962a329..000000000 --- a/recbole/data/dataloader/sequential_dataloader.py +++ /dev/null @@ -1,294 +0,0 @@ -# @Time : 2020/7/7 -# @Author : Yupeng Hou -# @Email : houyupeng@ruc.edu.cn - -# UPDATE -# @Time : 2020/10/6, 2020/9/17 -# @Author : Yupeng Hou, Yushuo Chen -# @email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn - -""" -recbole.data.dataloader.sequential_dataloader -################################################ -""" - -import numpy as np -import torch - -from recbole.data.dataloader.abstract_dataloader import AbstractDataLoader -from recbole.data.dataloader.neg_sample_mixin import NegSampleByMixin, NegSampleMixin -from recbole.data.interaction import Interaction, cat_interactions -from recbole.utils import DataLoaderType, FeatureSource, FeatureType, InputType - - -class SequentialDataLoader(AbstractDataLoader): - """:class:`SequentialDataLoader` is used for sequential model. It will do data augmentation for the origin data. - And its returned data contains the following: - - - user id - - history items list - - history items' interaction time list - - item to be predicted - - the interaction time of item to be predicted - - history list length - - other interaction information of item to be predicted - - Args: - config (Config): The config of dataloader. - dataset (Dataset): The dataset of dataloader. - batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``. - dl_format (InputType, optional): The input type of dataloader. Defaults to - :obj:`~recbole.utils.enum_type.InputType.POINTWISE`. - shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. - """ - dl_type = DataLoaderType.ORIGIN - - def __init__(self, config, dataset, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False): - self.uid_field = dataset.uid_field - self.iid_field = dataset.iid_field - self.time_field = dataset.time_field - self.max_item_list_len = config['MAX_ITEM_LIST_LENGTH'] - - list_suffix = config['LIST_SUFFIX'] - for field in dataset.inter_feat: - if field != self.uid_field: - list_field = field + list_suffix - setattr(self, f'{field}_list_field', list_field) - ftype = dataset.field2type[field] - - if ftype in [FeatureType.TOKEN, FeatureType.TOKEN_SEQ]: - list_ftype = FeatureType.TOKEN_SEQ - else: - list_ftype = FeatureType.FLOAT_SEQ - - if ftype in [FeatureType.TOKEN_SEQ, FeatureType.FLOAT_SEQ]: - list_len = (self.max_item_list_len, dataset.field2seqlen[field]) - else: - list_len = self.max_item_list_len - - dataset.set_field_property(list_field, list_ftype, FeatureSource.INTERACTION, list_len) - - self.item_list_length_field = config['ITEM_LIST_LENGTH_FIELD'] - dataset.set_field_property(self.item_list_length_field, FeatureType.TOKEN, FeatureSource.INTERACTION, 1) - - self.uid_list = dataset.uid_list - self.item_list_index = dataset.item_list_index - self.target_index = dataset.target_index - self.item_list_length = dataset.item_list_length - self.pre_processed_data = None - - super().__init__(config, dataset, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle) - - def data_preprocess(self): - """Do data augmentation before training/evaluation. 
- """ - self.pre_processed_data = self.augmentation(self.item_list_index, self.target_index, self.item_list_length) - - @property - def pr_end(self): - return len(self.uid_list) - - def _shuffle(self): - if self.real_time: - new_index = torch.randperm(self.pr_end) - self.uid_list = self.uid_list[new_index] - self.item_list_index = self.item_list_index[new_index] - self.target_index = self.target_index[new_index] - self.item_list_length = self.item_list_length[new_index] - else: - self.pre_processed_data.shuffle() - - def _next_batch_data(self): - cur_data = self._get_processed_data(slice(self.pr, self.pr + self.step)) - self.pr += self.step - return cur_data - - def _get_processed_data(self, index): - if self.real_time: - cur_data = self.augmentation( - self.item_list_index[index], self.target_index[index], self.item_list_length[index] - ) - else: - cur_data = self.pre_processed_data[index] - return cur_data - - def augmentation(self, item_list_index, target_index, item_list_length): - """Data augmentation. - - Args: - item_list_index (numpy.ndarray): the index of history items list in interaction. - target_index (numpy.ndarray): the index of items to be predicted in interaction. - item_list_length (numpy.ndarray): history list length. - - Returns: - dict: the augmented data. - """ - new_length = len(item_list_index) - new_data = self.dataset.inter_feat[target_index] - new_dict = { - self.item_list_length_field: torch.tensor(item_list_length), - } - - for field in self.dataset.inter_feat: - if field != self.uid_field: - list_field = getattr(self, f'{field}_list_field') - list_len = self.dataset.field2seqlen[list_field] - shape = (new_length, list_len) if isinstance(list_len, int) else (new_length,) + list_len - list_ftype = self.dataset.field2type[list_field] - dtype = torch.int64 if list_ftype in [FeatureType.TOKEN, FeatureType.TOKEN_SEQ] else torch.float64 - new_dict[list_field] = torch.zeros(shape, dtype=dtype) - - value = self.dataset.inter_feat[field] - for i, (index, length) in enumerate(zip(item_list_index, item_list_length)): - new_dict[list_field][i][:length] = value[index] - - new_data.update(Interaction(new_dict)) - return new_data - - -class SequentialNegSampleDataLoader(NegSampleByMixin, SequentialDataLoader): - """:class:`SequentialNegSampleDataLoader` is sequential-dataloader with negative sampling. - Like :class:`~recbole.data.dataloader.general_dataloader.GeneralNegSampleDataLoader`, for the result of every batch, - we permit that every positive interaction and its negative interaction must be in the same batch. Beside this, - when it is in the evaluation stage, and evaluator is topk-like function, we also permit that all the interactions - corresponding to each user are in the same batch and positive interactions are before negative interactions. - - Args: - config (Config): The config of dataloader. - dataset (Dataset): The dataset of dataloader. - sampler (Sampler): The sampler of dataloader. - neg_sample_args (dict): The neg_sample_args of dataloader. - batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``. - dl_format (InputType, optional): The input type of dataloader. Defaults to - :obj:`~recbole.utils.enum_type.InputType.POINTWISE`. - shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. 
- """ - - def __init__( - self, config, dataset, sampler, neg_sample_args, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False - ): - super().__init__( - config, dataset, sampler, neg_sample_args, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle - ) - - def _batch_size_adaptation(self): - batch_num = max(self.batch_size // self.times, 1) - new_batch_size = batch_num * self.times - self.step = batch_num - self.upgrade_batch_size(new_batch_size) - - def _next_batch_data(self): - cur_data = self._get_processed_data(slice(self.pr, self.pr + self.step)) - cur_data = self._neg_sampling(cur_data) - self.pr += self.step - - if self.user_inter_in_one_batch: - cur_data_len = len(cur_data[self.uid_field]) - pos_len_list = np.ones(cur_data_len // self.times, dtype=np.int64) - user_len_list = pos_len_list * self.times - cur_data.set_additional_info(list(pos_len_list), list(user_len_list)) - return cur_data - - def _neg_sampling(self, data): - if self.user_inter_in_one_batch: - data_len = len(data[self.uid_field]) - data_list = [] - for i in range(data_len): - uids = data[self.uid_field][i:i + 1] - neg_iids = self.sampler.sample_by_user_ids(uids, self.neg_sample_by) - cur_data = data[i:i + 1] - data_list.append(self.sampling_func(cur_data, neg_iids)) - return cat_interactions(data_list) - else: - uids = data[self.uid_field] - neg_iids = self.sampler.sample_by_user_ids(uids, self.neg_sample_by) - return self.sampling_func(data, neg_iids) - - def _neg_sample_by_pair_wise_sampling(self, data, neg_iids): - new_data = data.repeat(self.times) - new_data.update(Interaction({self.neg_item_id: neg_iids})) - return new_data - - def _neg_sample_by_point_wise_sampling(self, data, neg_iids): - pos_inter_num = len(data) - new_data = data.repeat(self.times) - new_data[self.iid_field][pos_inter_num:] = neg_iids - labels = torch.zeros(pos_inter_num * self.times) - labels[:pos_inter_num] = 1.0 - new_data.update(Interaction({self.label_field: labels})) - return new_data - - def get_pos_len_list(self): - """ - Returns: - numpy.ndarray: Number of positive item for each user in a training/evaluating epoch. - """ - return np.ones(self.pr_end, dtype=np.int64) - - def get_user_len_list(self): - """ - Returns: - numpy.ndarray: Number of all item for each user in a training/evaluating epoch. - """ - return np.full(self.pr_end, self.times) - - -class SequentialFullDataLoader(NegSampleMixin, SequentialDataLoader): - """:class:`SequentialFullDataLoader` is a sequential-dataloader with full sort. In order to speed up calculation, - this dataloader would only return then user part of interactions, positive items and used items. - It would not return negative items. - - Args: - config (Config): The config of dataloader. - dataset (Dataset): The dataset of dataloader. - sampler (Sampler): The sampler of dataloader. - neg_sample_args (dict): The neg_sample_args of dataloader. - batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``. - dl_format (InputType, optional): The input type of dataloader. Defaults to - :obj:`~recbole.utils.enum_type.InputType.POINTWISE`. - shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. 
- """ - dl_type = DataLoaderType.FULL - - def __init__( - self, config, dataset, sampler, neg_sample_args, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False - ): - super().__init__( - config, dataset, sampler, neg_sample_args, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle - ) - - def _batch_size_adaptation(self): - pass - - def _neg_sampling(self, inter_feat): - pass - - def _shuffle(self): - self.logger.warnning('SequentialFullDataLoader can\'t shuffle') - - def _next_batch_data(self): - interaction = super()._next_batch_data() - inter_num = len(interaction) - pos_len_list = np.ones(inter_num, dtype=np.int64) - user_len_list = np.full(inter_num, self.dataset.item_num) - interaction.set_additional_info(pos_len_list, user_len_list) - scores_row = torch.arange(inter_num).repeat(2) - padding_idx = torch.zeros(inter_num, dtype=torch.int64) - positive_idx = interaction[self.iid_field] - scores_col_after = torch.cat((padding_idx, positive_idx)) - scores_col_before = torch.cat((positive_idx, padding_idx)) - return interaction, None, scores_row, scores_col_after, scores_col_before - - def get_pos_len_list(self): - """ - Returns: - numpy.ndarray or list: Number of positive item for each user in a training/evaluating epoch. - """ - return np.ones(self.pr_end, dtype=np.int64) - - def get_user_len_list(self): - """ - Returns: - numpy.ndarray: Number of all item for each user in a training/evaluating epoch. - """ - return np.full(self.pr_end, self.dataset.item_num) diff --git a/recbole/data/dataloader/user_dataloader.py b/recbole/data/dataloader/user_dataloader.py index 2d2fd62a0..767be29a7 100644 --- a/recbole/data/dataloader/user_dataloader.py +++ b/recbole/data/dataloader/user_dataloader.py @@ -13,7 +13,7 @@ """ import torch -from recbole.data.dataloader import AbstractDataLoader +from recbole.data.dataloader.abstract_dataloader import AbstractDataLoader from recbole.data.interaction import Interaction from recbole.utils.enum_type import DataLoaderType, InputType @@ -24,30 +24,30 @@ class UserDataLoader(AbstractDataLoader): Args: config (Config): The config of dataloader. dataset (Dataset): The dataset of dataloader. - batch_size (int, optional): The batch_size of dataloader. Defaults to ``1``. - dl_format (InputType, optional): The input type of dataloader. Defaults to - :obj:`~recbole.utils.enum_type.InputType.POINTWISE`. + sampler (Sampler): The sampler of dataloader. shuffle (bool, optional): Whether the dataloader will be shuffle after a round. Defaults to ``False``. Attributes: shuffle (bool): Whether the dataloader will be shuffle after a round. However, in :class:`UserDataLoader`, it's guaranteed to be ``True``. """ + dl_type = DataLoaderType.ORIGIN - def __init__(self, config, dataset, batch_size=1, dl_format=InputType.POINTWISE, shuffle=False): + def __init__(self, config, dataset, sampler, shuffle=False): + if shuffle is False: + shuffle = True + self.logger.warning('UserDataLoader must shuffle the data.') + self.uid_field = dataset.uid_field self.user_list = Interaction({self.uid_field: torch.arange(dataset.user_num)}) - super().__init__(config=config, dataset=dataset, batch_size=batch_size, dl_format=dl_format, shuffle=shuffle) + super().__init__(config, dataset, sampler, shuffle=shuffle) - def setup(self): - """Make sure that the :attr:`shuffle` is True. If :attr:`shuffle` is False, it will be changed to True - and give a warning to user. 
- """ - if self.shuffle is False: - self.shuffle = True - self.logger.warning('UserDataLoader must shuffle the data') + def _init_batch_size_and_step(self): + batch_size = self.config['train_batch_size'] + self.step = batch_size + self.set_batch_size(batch_size) @property def pr_end(self): diff --git a/recbole/data/dataset/customized_dataset.py b/recbole/data/dataset/customized_dataset.py index d1e34bab2..3dd7f0546 100644 --- a/recbole/data/dataset/customized_dataset.py +++ b/recbole/data/dataset/customized_dataset.py @@ -2,6 +2,11 @@ # @Author : Yupeng Hou # @Email : houyupeng@ruc.edu.cn +# UPDATE +# @Time : 2021/7/9 +# @Author : Yupeng Hou +# @Email : houyupeng@ruc.edu.cn + """ recbole.data.customized_dataset ################################## @@ -11,7 +16,13 @@ Customized datasets named ``[Model Name]Dataset`` can be automatically called. """ -from recbole.data.dataset import Kg_Seq_Dataset +import numpy as np +import torch + +from recbole.data.dataset import Kg_Seq_Dataset, SequentialDataset +from recbole.data.interaction import Interaction +from recbole.sampler import SeqSampler +from recbole.utils.enum_type import FeatureType class GRU4RecKGDataset(Kg_Seq_Dataset): @@ -24,3 +35,100 @@ class KSRDataset(Kg_Seq_Dataset): def __init__(self, config): super().__init__(config) + + +class DIENDataset(SequentialDataset): + """:class:`DIENDataset` is based on :class:`~recbole.data.dataset.sequential_dataset.SequentialDataset`. + It is different from :class:`SequentialDataset` in `data_augmentation`. + It add users' negative item list to interaction. + + The original version of sampling negative item list is implemented by Zhichao Feng (fzcbupt@gmail.com) in 2021/2/25, + and he updated the codes in 2021/3/19. In 2021/7/9, Yupeng refactored SequentialDataset & SequentialDataLoader, + then refactored DIENDataset, either. + + Attributes: + augmentation (bool): Whether the interactions should be augmented in RecBole. + seq_sample (recbole.sampler.SeqSampler): A sampler used to sample negative item sequence. + neg_item_list_field (str): Field name for negative item sequence. + neg_item_list (torch.tensor): all users' negative item history sequence. + """ + def __init__(self, config): + super().__init__(config) + + list_suffix = config['LIST_SUFFIX'] + neg_prefix = config['NEG_PREFIX'] + self.seq_sampler = SeqSampler(self) + self.neg_item_list_field = neg_prefix + self.iid_field + list_suffix + self.neg_item_list = self.seq_sampler.sample_neg_sequence(self.inter_feat[self.iid_field]) + + def data_augmentation(self): + """Augmentation processing for sequential dataset. + + E.g., ``u1`` has purchase sequence ````, + then after augmentation, we will generate three cases. + + ``u1, | i2`` + + (Which means given user_id ``u1`` and item_seq ````, + we need to predict the next item ``i2``.) 
+
+ The other cases are below:
+
+ ``u1, <i1, i2> | i3``
+
+ ``u1, <i1, i2, i3> | i4``
+ """
+ self.logger.debug('data_augmentation')
+
+ self._aug_presets()
+
+ self._check_field('uid_field', 'time_field')
+ max_item_list_len = self.config['MAX_ITEM_LIST_LENGTH']
+ self.sort(by=[self.uid_field, self.time_field], ascending=True)
+ last_uid = None
+ uid_list, item_list_index, target_index, item_list_length = [], [], [], []
+ seq_start = 0
+ for i, uid in enumerate(self.inter_feat[self.uid_field].numpy()):
+ if last_uid != uid:
+ last_uid = uid
+ seq_start = i
+ else:
+ if i - seq_start > max_item_list_len:
+ seq_start += 1
+ uid_list.append(uid)
+ item_list_index.append(slice(seq_start, i))
+ target_index.append(i)
+ item_list_length.append(i - seq_start)
+
+ uid_list = np.array(uid_list)
+ item_list_index = np.array(item_list_index)
+ target_index = np.array(target_index)
+ item_list_length = np.array(item_list_length, dtype=np.int64)
+
+ new_length = len(item_list_index)
+ new_data = self.inter_feat[target_index]
+ new_dict = {
+ self.item_list_length_field: torch.tensor(item_list_length),
+ }
+
+ for field in self.inter_feat:
+ if field != self.uid_field:
+ list_field = getattr(self, f'{field}_list_field')
+ list_len = self.field2seqlen[list_field]
+ shape = (new_length, list_len) if isinstance(list_len, int) else (new_length,) + list_len
+ list_ftype = self.field2type[list_field]
+ dtype = torch.int64 if list_ftype in [FeatureType.TOKEN, FeatureType.TOKEN_SEQ] else torch.float64
+ new_dict[list_field] = torch.zeros(shape, dtype=dtype)
+
+ value = self.inter_feat[field]
+ for i, (index, length) in enumerate(zip(item_list_index, item_list_length)):
+ new_dict[list_field][i][:length] = value[index]
+
+ # DIEN
+ if field == self.iid_field:
+ new_dict[self.neg_item_list_field] = torch.zeros(shape, dtype=dtype)
+ for i, (index, length) in enumerate(zip(item_list_index, item_list_length)):
+ new_dict[self.neg_item_list_field][i][:length] = self.neg_item_list[index]
+
+ new_data.update(Interaction(new_dict))
+ self.inter_feat = new_data
diff --git a/recbole/data/dataset/dataset.py b/recbole/data/dataset/dataset.py
index a0e4a82ef..63b7751cc 100644
--- a/recbole/data/dataset/dataset.py
+++ b/recbole/data/dataset/dataset.py
@@ -3,9 +3,9 @@
# @Email : houyupeng@ruc.edu.cn
# UPDATE:
-# @Time : 2020/10/28 2020/10/13, 2020/11/10
+# @Time : 2021/7/11 2021/7/1, 2020/11/10
# @Author : Yupeng Hou, Xingyu Pan, Yushuo Chen
-# @Email : houyupeng@ruc.edu.cn, panxy@ruc.edu.cn, chenyushuo@ruc.edu.cn
+# @Email : houyupeng@ruc.edu.cn, xy_pan@foxmail.com, chenyushuo@ruc.edu.cn
"""
recbole.data.dataset
@@ -26,8 +26,7 @@
from scipy.sparse import coo_matrix
from recbole.data.interaction import Interaction
-from recbole.utils import FeatureSource, FeatureType, get_local_time
-from recbole.utils.utils import set_color
+from recbole.utils import FeatureSource, FeatureType, get_local_time, set_color
from recbole.utils.url import decide_download, download_url, extract_zip, makedirs, rename_atomic_files
@@ -105,6 +104,7 @@ def _from_scratch(self):
self._get_preset()
self._get_field_from_config()
self._load_data(self.dataset_name, self.dataset_path)
+ self._init_alias()
self._data_processing()
def _get_preset(self):
@@ -117,6 +117,7 @@ def _get_preset(self):
self.field2id_token = {}
self.field2token_id = {}
self.field2seqlen = self.config['seq_len'] or {}
+ self.alias = {}
self._preloaded_weight = {}
self.benchmark_filename_list = self.config['benchmark_filename']
@@ -441,6 +442,34 @@ def _load_feat(self, filepath, source):
self.field2seqlen[field] = max(map(len, df[field].values))
return df
+ def _set_alias(self, alias_name, default_value):
+ alias = self.config[f'alias_of_{alias_name}'] or []
+ alias = np.array(default_value + alias)
+ _, idx = np.unique(alias, return_index=True)
+ self.alias[alias_name] = alias[np.sort(idx)]
+
+ def _init_alias(self):
+ """Set :attr:`alias_of_user_id` and :attr:`alias_of_item_id`, and set :attr:`_rest_fields`.
+ """
+ self._set_alias('user_id', [self.uid_field])
+ self._set_alias('item_id', [self.iid_field])
+
+ for alias_name_1, alias_1 in self.alias.items():
+ for alias_name_2, alias_2 in self.alias.items():
+ if alias_name_1 != alias_name_2:
+ intersect = np.intersect1d(alias_1, alias_2, assume_unique=True)
+ if len(intersect) > 0:
+ raise ValueError(f'`alias_of_{alias_name_1}` and `alias_of_{alias_name_2}` '
+ f'should not have the same field {list(intersect)}.')
+
+ self._rest_fields = self.token_like_fields
+ for alias_name, alias in self.alias.items():
+ isin = np.isin(alias, self._rest_fields, assume_unique=True)
+ if not isin.all():
+ raise ValueError(f'`alias_of_{alias_name}` should not contain '
+ f'non-token-like field {list(alias[~isin])}.')
+ self._rest_fields = np.setdiff1d(self._rest_fields, alias, assume_unique=True)
+
def _user_item_feat_preparation(self):
"""Sort :attr:`user_feat` and :attr:`item_feat` by ``user_id`` or ``item_id``.
Missing values will be filled later.
@@ -639,46 +668,32 @@ def _remove_duplication(self):
def _filter_by_inter_num(self):
"""Filter by number of interaction.
- Upper/Lower bounds can be set, only users/items between upper/lower bounds can be remained.
+ The interval of the number of interactions can be set, and only users/items whose number
+ of interactions is in the specified interval can be retained.
See :doc:`../user_guide/data/data_args` for detail arg setting.
Note:
- Lower bound is also called k-core filtering, which means this method will filter loops
- until all the users and items has at least k interactions.
+ Lower bound of the interval is also called k-core filtering, which means this method
+ will filter loops until all the users and items have at least k interactions.
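+
+ E.g. (editor's illustration), ``user_inter_num_interval: "[10,inf)"`` keeps only
+ the users with at least 10 interactions, i.e. classical 10-core filtering.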
""" if self.uid_field is None or self.iid_field is None: return - max_user_inter_num = self.config['max_user_inter_num'] - min_user_inter_num = self.config['min_user_inter_num'] - max_item_inter_num = self.config['max_item_inter_num'] - min_item_inter_num = self.config['min_item_inter_num'] + user_inter_num_interval = self._parse_intervals_str(self.config['user_inter_num_interval']) + item_inter_num_interval = self._parse_intervals_str(self.config['item_inter_num_interval']) - if max_user_inter_num is None and min_user_inter_num is None: - user_inter_num = Counter() - else: - user_inter_num = Counter(self.inter_feat[self.uid_field].values) - - if max_item_inter_num is None and min_item_inter_num is None: - item_inter_num = Counter() - else: - item_inter_num = Counter(self.inter_feat[self.iid_field].values) + user_inter_num = Counter(self.inter_feat[self.uid_field].values) if user_inter_num_interval else Counter() + item_inter_num = Counter(self.inter_feat[self.iid_field].values) if item_inter_num_interval else Counter() while True: - ban_users = self._get_illegal_ids_by_inter_num( - field=self.uid_field, - feat=self.user_feat, - inter_num=user_inter_num, - max_num=max_user_inter_num, - min_num=min_user_inter_num - ) - ban_items = self._get_illegal_ids_by_inter_num( - field=self.iid_field, - feat=self.item_feat, - inter_num=item_inter_num, - max_num=max_item_inter_num, - min_num=min_item_inter_num - ) + ban_users = self._get_illegal_ids_by_inter_num(field=self.uid_field, + feat=self.user_feat, + inter_num=user_inter_num, + inter_interval=user_inter_num_interval) + ban_items = self._get_illegal_ids_by_inter_num(field=self.iid_field, + feat=self.item_feat, + inter_num=item_inter_num, + inter_interval=item_inter_num_interval) if len(ban_users) == 0 and len(ban_items) == 0: break @@ -704,80 +719,113 @@ def _filter_by_inter_num(self): self.logger.debug(f'[{len(dropped_index)}] dropped interactions.') self.inter_feat.drop(dropped_index, inplace=True) - def _get_illegal_ids_by_inter_num(self, field, feat, inter_num, max_num=None, min_num=None): + def _get_illegal_ids_by_inter_num(self, field, feat, inter_num, inter_interval=None): """Given inter feat, return illegal ids, whose inter num out of [min_num, max_num] Args: field (str): field name of user_id or item_id. feat (pandas.DataFrame): interaction feature. inter_num (Counter): interaction number counter. - max_num (int, optional): max number of interaction. Defaults to ``None``. - min_num (int, optional): min number of interaction. Defaults to ``None``. + inter_interval (list, optional): the allowed interval(s) of the number of interactions. + Defaults to ``None``. Returns: - set: illegal ids, whose inter num out of [min_num, max_num] + set: illegal ids, whose inter num out of inter_intervals. 
""" self.logger.debug( - set_color('get_illegal_ids_by_inter_num', 'blue') + - f': field=[{field}], max_num=[{max_num}], min_num=[{min_num}]' - ) - - max_num = max_num or np.inf - min_num = min_num or -1 + set_color('get_illegal_ids_by_inter_num', 'blue') + f': field=[{field}], inter_interval=[{inter_interval}]') + + if inter_interval is not None: + if len(inter_interval) > 1: + self.logger.warning(f'More than one interval of interaction number are given!') - ids = {id_ for id_ in inter_num if inter_num[id_] < min_num or inter_num[id_] > max_num} + ids = {id_ for id_ in inter_num if not self._within_intervals(inter_num[id_], inter_interval)} if feat is not None: + min_num = inter_interval[0][1] if inter_interval else -1 for id_ in feat[field].values: if inter_num[id_] < min_num: ids.add(id_) self.logger.debug(f'[{len(ids)}] illegal_ids_by_inter_num, field=[{field}]') return ids - def _filter_by_field_value(self): - """Filter features according to its values. - """ - filter_field = [] - filter_field += self._drop_by_value(self.config['lowest_val'], lambda x, y: x < y) - filter_field += self._drop_by_value(self.config['highest_val'], lambda x, y: x > y) - filter_field += self._drop_by_value(self.config['equal_val'], lambda x, y: x != y) - filter_field += self._drop_by_value(self.config['not_equal_val'], lambda x, y: x == y) + def _parse_intervals_str(self, intervals_str): + """Given string of intervals, return the list of endpoints tuple, where a tuple corresponds to an interval. - def _reset_index(self): - """Reset index for all feats in :attr:`feat_name_list`. + Args: + intervals_str (str): the string of intervals, such as "(0,1];[3,4)". + + Returns: + list of endpoint tuple, such as [('(', 0, 1.0 , ']'), ('[', 3.0, 4.0 , ')')]. """ - for feat_name in self.feat_name_list: - feat = getattr(self, feat_name) - if feat.empty: - raise ValueError('Some feat is empty, please check the filtering settings.') - feat.reset_index(drop=True, inplace=True) + if intervals_str is None: + return None - def _drop_by_value(self, val, cmp): - """Drop illegal rows by value. + endpoints = [] + for endpoint_pair_str in str(intervals_str).split(';'): + endpoint_pair_str = endpoint_pair_str.strip() + left_bracket, right_bracket = endpoint_pair_str[0], endpoint_pair_str[-1] + endpoint_pair = endpoint_pair_str[1:-1].split(',') + if not (len(endpoint_pair) == 2 and left_bracket in ['(', '['] and right_bracket in [')', ']']): + self.logger.warning(f'{endpoint_pair_str} is an illegal interval!') + continue - Args: - val (dict): value that compared to. - cmp (Callable): return False if a row need to be dropped + def str2npnum(num_str): + if num_str.lower() in ["inf", "-inf"]: + return np.inf if num_str.lower() == "inf" else -np.inf + else: + try: + return float(num_str) + except ValueError: + raise ValueError(f'Str {num_str} in interval can not be converted to numeric.') - Returns: - field names that used to compare with val. + left_point, right_point = str2npnum(endpoint_pair[0]), str2npnum(endpoint_pair[1]) + if left_point > right_point: + self.logger.warning(f'{endpoint_pair_str} is an illegal interval!') + + endpoints.append((left_bracket, left_point, right_point, right_bracket)) + return endpoints + + def _within_intervals(self, num, intervals): + """ return Ture if the num is in the intervals. + + Note: + return true when the intervals is None. 
+ """ + result = True + for i, (left_bracket, left_point, right_point, right_bracket) in enumerate(intervals): + temp_result = num >= left_point if left_bracket == '[' else num > left_point + temp_result &= num <= right_point if right_bracket == ']' else num < right_point + result = temp_result if i == 0 else result | temp_result + return result + + def _filter_by_field_value(self): + """Filter features according to its values. """ - if val is None: - return [] - self.logger.debug(set_color('drop_by_value', 'blue') + f': val={val}') - filter_field = [] - for field in val: + val_intervals = [] if self.config['val_interval'] is None else self.config['val_interval'] + self.logger.debug(set_color('drop_by_value', 'blue') + f': val={val_intervals}') + + for field in val_intervals: if field not in self.field2type: raise ValueError(f'Field [{field}] not defined in dataset.') if self.field2type[field] not in {FeatureType.FLOAT, FeatureType.FLOAT_SEQ}: raise ValueError(f'Field [{field}] is not float-like field in dataset, which can\'t be filter.') + + field_val_interval = self._parse_intervals_str(val_intervals[field]) for feat_name in self.feat_name_list: feat = getattr(self, feat_name) if field in feat: - feat.drop(feat.index[cmp(feat[field].values, val[field])], inplace=True) - filter_field.append(field) - return filter_field + feat.drop(feat.index[~self._within_intervals(feat[field].values, field_val_interval)], inplace=True) + + def _reset_index(self): + """Reset index for all feats in :attr:`feat_name_list`. + """ + for feat_name in self.feat_name_list: + feat = getattr(self, feat_name) + if feat.empty: + raise ValueError('Some feat is empty, please check the filtering settings.') + feat.reset_index(drop=True, inplace=True) def _del_col(self, feat, field): """Delete columns @@ -841,41 +889,7 @@ def _set_label_by_threshold(self): raise ValueError(f'Field [{field}] not in inter_feat.') self._del_col(self.inter_feat, field) - def _get_fields_in_same_space(self): - """Parsing ``config['fields_in_same_space']``. See :doc:`../user_guide/data/data_args` for detail arg setting. - - Note: - - Each field can only exist ONCE in ``config['fields_in_same_space']``. - - user_id and item_id can not exist in ``config['fields_in_same_space']``. - - only token-like fields can exist in ``config['fields_in_same_space']``. - """ - fields_in_same_space = self.config['fields_in_same_space'] or [] - fields_in_same_space = [set(_) for _ in fields_in_same_space] - additional = [] - token_like_fields = self.token_like_fields - for field in token_like_fields: - count = 0 - for field_set in fields_in_same_space: - if field in field_set: - count += 1 - if count == 0: - additional.append({field}) - elif count == 1: - continue - else: - raise ValueError(f'Field [{field}] occurred in `fields_in_same_space` more than one time.') - - for field_set in fields_in_same_space: - if self.uid_field in field_set and self.iid_field in field_set: - raise ValueError('uid_field and iid_field can\'t in the same ID space') - for field in field_set: - if field not in token_like_fields: - raise ValueError(f'Field [{field}] is not a token-like field.') - - fields_in_same_space.extend(additional) - return fields_in_same_space - - def _get_remap_list(self, field_set): + def _get_remap_list(self, field_list): """Transfer set of fields in the same remapping space into remap list. 
If ``uid_field`` or ``iid_field`` in ``field_set``, @@ -883,7 +897,7 @@ def _get_remap_list(self, field_set): then field in :attr:`user_feat` or :attr:`item_feat` will be remapped next, finally others. Args: - field_set (set): Set of fields in the same remapping space + field_list (numpy.ndarray): List of fields in the same remapping space. Returns: list: @@ -893,29 +907,23 @@ def _get_remap_list(self, field_set): They will be concatenated in order, and remapped together. """ + remap_list = [] - for field, feat in zip([self.uid_field, self.iid_field], [self.user_feat, self.item_feat]): - if field in field_set: - field_set.remove(field) - remap_list.append((self.inter_feat, field, FeatureType.TOKEN)) - if feat is not None: - remap_list.append((feat, field, FeatureType.TOKEN)) - for field in field_set: - source = self.field2source[field] - if isinstance(source, FeatureSource): - source = source.value - feat = getattr(self, f'{source}_feat') + for field in field_list: ftype = self.field2type[field] - remap_list.append((feat, field, ftype)) + for feat in self.field2feats(field): + remap_list.append((feat, field, ftype)) return remap_list def _remap_ID_all(self): - """Get ``config['fields_in_same_space']`` firstly, and remap each. + """Remap all token-like fields. """ - fields_in_same_space = self._get_fields_in_same_space() - self.logger.debug(set_color('fields_in_same_space', 'blue') + f': {fields_in_same_space}') - for field_set in fields_in_same_space: - remap_list = self._get_remap_list(field_set) + for alias in self.alias.values(): + remap_list = self._get_remap_list(alias) + self._remap(remap_list) + + for field in self._rest_fields: + remap_list = self._get_remap_list(np.array([field])) self._remap(remap_list) def _concat_remaped_tokens(self, remap_list): @@ -934,7 +942,7 @@ def _concat_remaped_tokens(self, remap_list): if ftype == FeatureType.TOKEN: tokens.append(feat[field].values) elif ftype == FeatureType.TOKEN_SEQ: - tokens.append(feat[field].agg(np.concatenate)) + tokens.append(feat[field].reset_index(drop=True).agg(np.concatenate)) split_point = np.cumsum(list(map(len, tokens)))[:-1] tokens = np.concatenate(tokens) return tokens, split_point @@ -1068,6 +1076,24 @@ def copy_field_property(self, dest_field, source_field): self.field2source[dest_field] = self.field2source[source_field] self.field2seqlen[dest_field] = self.field2seqlen[source_field] + def field2feats(self, field): + if field not in self.field2source: + raise ValueError(f'Field [{field}] not defined in dataset.') + if field == self.uid_field: + feats = [self.inter_feat] + if self.user_feat is not None: + feats.append(self.user_feat) + elif field == self.iid_field: + feats = [self.inter_feat] + if self.item_feat is not None: + feats.append(self.item_feat) + else: + source = self.field2source[field] + if not isinstance(source, str): + source = source.value + feats = [getattr(self, f'{source}_feat')] + return feats + def token2id(self, field, tokens): """Map external tokens to internal ids. @@ -1421,7 +1447,7 @@ def sort(self, by, ascending=True): """ self.inter_feat.sort(by=by, ascending=ascending) - def build(self, eval_setting): + def build(self): """Processing dataset according to evaluation setting, including Group, Order and Split. See :class:`~recbole.config.eval_setting.EvalSetting` for details. 
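(Editor's note: a minimal sketch of the new-style ``eval_args`` consumed by the
rewritten ``build()`` in the following hunk; the key names follow the hunk, while the
concrete values are illustrative only.)

```python
# Illustrative eval_args for the refactored Dataset.build() (example values):
config = {
    'eval_args': {
        'order': 'RO',                     # 'RO' = random ordering, 'TO' = temporal ordering
        'split': {'RS': [0.8, 0.1, 0.1]},  # ratio split; {'LS': 1} would mean leave-one-out
        'group_by': 'user',                # or 'none' to split the interactions globally
    }
}
```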
@@ -1439,24 +1465,41 @@
datasets = [self.copy(self.inter_feat[start:end]) for start, end in zip([0] + cumsum[:-1], cumsum)]
return datasets
- ordering_args = eval_setting.ordering_args
- if ordering_args['strategy'] == 'shuffle':
+ # ordering
+ ordering_args = self.config['eval_args']['order']
+ if ordering_args == 'RO':
self.shuffle()
- elif ordering_args['strategy'] == 'by':
- self.sort(by=ordering_args['field'], ascending=ordering_args['ascending'])
-
- group_field = eval_setting.group_field
-
- split_args = eval_setting.split_args
- if split_args['strategy'] == 'by_ratio':
- datasets = self.split_by_ratio(split_args['ratios'], group_by=group_field)
- elif split_args['strategy'] == 'by_value':
- raise NotImplementedError()
- elif split_args['strategy'] == 'loo':
- datasets = self.leave_one_out(group_by=group_field, leave_one_num=split_args['leave_one_num'])
+ elif ordering_args == 'TO':
+ self.sort(by=self.config['TIME_FIELD'])
else:
- datasets = self
-
+ raise NotImplementedError(f'The ordering_method [{ordering_args}] has not been implemented.')
+
+ # splitting & grouping
+ split_args = self.config['eval_args']['split']
+ if split_args is None:
+ raise ValueError('The split_args in eval_args should not be None.')
+ if not isinstance(split_args, dict):
+ raise ValueError(f'The split_args [{split_args}] should be a dict.')
+
+ split_mode = list(split_args.keys())[0]
+ assert len(split_args.keys()) == 1
+ group_by = self.config['eval_args']['group_by']
+ if split_mode == 'RS':
+ if not isinstance(split_args['RS'], list):
+ raise ValueError(f'The value of "RS" [{split_args}] should be a list.')
+ if group_by == 'none':
+ datasets = self.split_by_ratio(split_args['RS'], group_by=None)
+ elif group_by == 'user':
+ datasets = self.split_by_ratio(split_args['RS'], group_by=self.config['USER_ID_FIELD'])
+ else:
+ raise NotImplementedError(f'The grouping method [{group_by}] has not been implemented.')
+ elif split_mode == 'LS':
+ if not isinstance(split_args['LS'], int):
+ raise ValueError(f'The value of "LS" [{split_args}] should be an int.')
+ datasets = self.leave_one_out(group_by=self.config['USER_ID_FIELD'], leave_one_num=split_args['LS'])
+ else:
+ raise NotImplementedError(f'The splitting_method [{split_mode}] has not been implemented.')
+ return datasets
def save(self, filepath):
diff --git a/recbole/data/dataset/kg_dataset.py b/recbole/data/dataset/kg_dataset.py
index 5bf4ab1c9..e630946fc 100644
--- a/recbole/data/dataset/kg_dataset.py
+++ b/recbole/data/dataset/kg_dataset.py
@@ -20,8 +20,7 @@
from scipy.sparse import coo_matrix
from recbole.data.dataset import Dataset
-from recbole.utils import FeatureSource, FeatureType
-from recbole.utils.utils import set_color
+from recbole.utils import FeatureSource, FeatureType, set_color
from recbole.utils.url import decide_download, download_url, extract_zip
@@ -59,7 +58,7 @@ class KnowledgeBasedDataset(Dataset):
Note:
:attr:`entity_field` doesn't exist exactly. It's only a symbol,
- representing entity features. E.g. it can be written into ``config['fields_in_same_space']``.
+ representing entity features.
``[UI-Relation]`` is a special relation token.
""" @@ -67,10 +66,6 @@ class KnowledgeBasedDataset(Dataset): def __init__(self, config): super().__init__(config) - def _get_preset(self): - super()._get_preset() - self.field2ent_level = {} - def _get_field_from_config(self): super()._get_field_from_config() @@ -84,10 +79,6 @@ def _get_field_from_config(self): self.logger.debug(set_color('relation_field', 'blue') + f': {self.relation_field}') self.logger.debug(set_color('entity_field', 'blue') + f': {self.entity_field}') - def _data_processing(self): - self._set_field2ent_level() - super()._data_processing() - def _data_filtering(self): super()._data_filtering() self._filter_link() @@ -157,9 +148,6 @@ def _build_feat_name_list(self): feat_name_list.append('kg_feat') return feat_name_list - def _restore_saved_dataset(self, saved_dataset): - raise NotImplementedError() - def save(self, filepath): raise NotImplementedError() @@ -197,192 +185,86 @@ def _check_link(self, link): assert self.entity_field in link, link_warn_message.format(self.entity_field) assert self.iid_field in link, link_warn_message.format(self.iid_field) - def _get_fields_in_same_space(self): - """Parsing ``config['fields_in_same_space']``. See :doc:`../user_guide/data/data_args` for detail arg setting. - - Note: - - Each field can only exist ONCE in ``config['fields_in_same_space']``. - - user_id and item_id can not exist in ``config['fields_in_same_space']``. - - only token-like fields can exist in ``config['fields_in_same_space']``. - - ``head_entity_id`` and ``target_entity_id`` should be remapped with ``item_id``. - """ - fields_in_same_space = super()._get_fields_in_same_space() - fields_in_same_space = [_ for _ in fields_in_same_space if not self._contain_ent_field(_)] - ent_fields = self._get_ent_fields_in_same_space() - for field_set in fields_in_same_space: - if self.iid_field in field_set: - field_set.update(ent_fields) - return fields_in_same_space - - def _contain_ent_field(self, field_set): - """Return True if ``field_set`` contains entity fields. - """ - flag = False - flag |= self.head_entity_field in field_set - flag |= self.tail_entity_field in field_set - flag |= self.entity_field in field_set - return flag - - def _get_ent_fields_in_same_space(self): - """Return ``field_set`` that should be remapped together with entities. - """ - fields_in_same_space = super()._get_fields_in_same_space() - - ent_fields = {self.head_entity_field, self.tail_entity_field} - for field_set in fields_in_same_space: - if self._contain_ent_field(field_set): - field_set = self._remove_ent_field(field_set) - ent_fields.update(field_set) - self.logger.debug(set_color('ent_fields', 'blue') + f': {fields_in_same_space}') - return ent_fields - - def _remove_ent_field(self, field_set): - """Delete entity fields from ``field_set``. - """ - for field in [self.head_entity_field, self.tail_entity_field, self.entity_field]: - if field in field_set: - field_set.remove(field) - return field_set - - def _set_field2ent_level(self): - """For fields that remapped together with ``item_id``, - set their levels as ``rec``, otherwise as ``ent``. - """ - fields_in_same_space = self._get_fields_in_same_space() - for field_set in fields_in_same_space: - if self.iid_field in field_set: - for field in field_set: - self.field2ent_level[field] = 'rec' - ent_fields = self._get_ent_fields_in_same_space() - for field in ent_fields: - self.field2ent_level[field] = 'ent' - - def _fields_by_ent_level(self, ent_level): - """Given ``ent_level``, return all the field name of this level. 
- """ - ret = [] - for field in self.field2ent_level: - if self.field2ent_level[field] == ent_level: - ret.append(field) - return ret - - @property - def rec_level_ent_fields(self): - """Get entity fields remapped together with ``item_id``. - - Returns: - list: List of field names. + def _init_alias(self): + """Add :attr:`alias_of_entity_id`, :attr:`alias_of_relation_id` and update :attr:`_rest_fields`. """ - return self._fields_by_ent_level('rec') + self._set_alias('entity_id', [self.head_entity_field, self.tail_entity_field]) + self._set_alias('relation_id', [self.relation_field]) - @property - def ent_level_ent_fields(self): - """Get entity fields remapped together with ``entity_id``. + super()._init_alias() - Returns: - list: List of field names. - """ - return self._fields_by_ent_level('ent') - - def _remap_entities_by_link(self): - """Map entity tokens from fields in ``ent`` level - to item tokens according to ``.link``. - """ - for ent_field in self.ent_level_ent_fields: - source = self.field2source[ent_field] - if not isinstance(source, str): - source = source.value - feat = getattr(self, f'{source}_feat') - entity_list = feat[ent_field].values - for i, entity_id in enumerate(entity_list): - if entity_id in self.entity2item: - entity_list[i] = self.entity2item[entity_id] - feat[ent_field] = entity_list + self._rest_fields = np.setdiff1d(self._rest_fields, [self.entity_field], assume_unique=True) def _get_rec_item_token(self): """Get set of entity tokens from fields in ``rec`` level. """ - field_set = set(self.rec_level_ent_fields) - remap_list = self._get_remap_list(field_set) + remap_list = self._get_remap_list(self.alias['item_id']) tokens, _ = self._concat_remaped_tokens(remap_list) return set(tokens) def _get_entity_token(self): """Get set of entity tokens from fields in ``ent`` level. 
""" - field_set = set(self.ent_level_ent_fields) - remap_list = self._get_remap_list(field_set) + remap_list = self._get_remap_list(self.alias['entity_id']) tokens, _ = self._concat_remaped_tokens(remap_list) return set(tokens) - def _reset_ent_remapID(self, field, new_id_token): - token2id = {} - for i, token in enumerate(new_id_token): - token2id[token] = i - idmap = {} - for i, token in enumerate(self.field2id_token[field]): - if token not in token2id: - continue - new_idx = token2id[token] - idmap[i] = new_idx - source = self.field2source[field] - if not isinstance(source, str): - source = source.value - if source == 'item_id': - feats = [self.inter_feat] - if self.item_feat is not None: - feats.append(self.item_feat) - else: - feats = [getattr(self, f'{source}_feat')] - for feat in feats: - old_idx = feat[field].values - new_idx = np.array([idmap[_] for _ in old_idx]) - feat[field] = new_idx - - def _sort_remaped_entities(self, item_tokens): - item2order = {} - for token in self.field2id_token[self.iid_field]: - if token == '[PAD]': - item2order[token] = 0 - elif token in item_tokens and token not in self.item2entity: - item2order[token] = 1 - elif token in self.item2entity or token in self.entity2item: - item2order[token] = 2 + def _reset_ent_remapID(self, field, idmap, id2token, token2id): + self.field2id_token[field] = id2token + self.field2token_id[field] = token2id + for feat in self.field2feats(field): + ftype = self.field2type[field] + if ftype == FeatureType.TOKEN: + old_idx = feat[field].values else: - item2order[token] = 3 - item_ent_token_list = list(self.field2id_token[self.iid_field]) - item_ent_token_list.sort(key=lambda t: item2order[t]) - item_ent_token_list = np.array(item_ent_token_list) - order_list = [item2order[_] for _ in item_ent_token_list] - order_cnt = Counter(order_list) - layered_num = [] - for i in range(4): - layered_num.append(order_cnt[i]) - layered_num = np.cumsum(np.array(layered_num)) - new_id_token = item_ent_token_list[:layered_num[-2]] - new_token_id = {t: i for i, t in enumerate(new_id_token)} - for field in self.rec_level_ent_fields: - self._reset_ent_remapID(field, new_id_token) - self.field2id_token[field] = new_id_token - self.field2token_id[field] = new_token_id - new_id_token = item_ent_token_list[:layered_num[-1]] - new_id_token = [self.item2entity[_] if _ in self.item2entity else _ for _ in new_id_token] - new_token_id = {t: i for i, t in enumerate(new_id_token)} - for field in self.ent_level_ent_fields: - self._reset_ent_remapID(field, item_ent_token_list[:layered_num[-1]]) - self.field2id_token[field] = new_id_token - self.field2token_id[field] = new_token_id - self.field2id_token[self.entity_field] = new_id_token - self.field2token_id[self.entity_field] = new_token_id + old_idx = feat[field].agg(np.concatenate) - def _remap_ID_all(self): - """Firstly, remap entities and items all together. Then sort entity tokens, - then three kinds of entities can be apart away from each other. + new_idx = idmap[old_idx] + + if ftype == FeatureType.TOKEN: + feat[field] = new_idx + else: + split_point = np.cumsum(feat[field].agg(len))[:-1] + feat[field] = np.split(new_idx, split_point) + + def _merge_item_and_entity(self): + """Merge item-id and entity-id into the same id-space. 
""" - self._remap_entities_by_link() - item_tokens = self._get_rec_item_token() + item_token = self.field2id_token[self.iid_field] + entity_token = self.field2id_token[self.head_entity_field] + item_num = len(item_token) + link_num = len(self.item2entity) + entity_num = len(entity_token) + + # reset item id + item_priority = np.array([token in self.item2entity for token in item_token]) + item_order = np.argsort(item_priority, kind='stable') + item_id_map = np.zeros_like(item_order) + item_id_map[item_order] = np.arange(item_num) + new_item_id2token = item_token[item_order] + new_item_token2id = {t: i for i, t in enumerate(new_item_id2token)} + for field in self.alias['item_id']: + self._reset_ent_remapID(field, item_id_map, new_item_id2token, new_item_token2id) + + # reset entity id + entity_priority = np.array([token != '[PAD]' and token not in self.entity2item for token in entity_token]) + entity_order = np.argsort(entity_priority, kind='stable') + entity_id_map = np.zeros_like(entity_order) + for i in entity_order[1:link_num + 1]: + entity_id_map[i] = new_item_token2id[self.entity2item[entity_token[i]]] + entity_id_map[entity_order[link_num + 1:]] = np.arange(item_num, item_num + entity_num - link_num - 1) + new_entity_id2token = np.concatenate([new_item_id2token, entity_token[entity_order[link_num + 1:]]]) + for i in range(item_num - link_num, item_num): + new_entity_id2token[i] = self.item2entity[new_entity_id2token[i]] + new_entity_token2id = {t: i for i, t in enumerate(new_entity_id2token)} + for field in self.alias['entity_id']: + self._reset_ent_remapID(field, entity_id_map, new_entity_id2token, new_entity_token2id) + self.field2id_token[self.entity_field] = new_entity_id2token + self.field2token_id[self.entity_field] = new_entity_token2id + + def _remap_ID_all(self): super()._remap_ID_all() - self._sort_remaped_entities(item_tokens) + self._merge_item_and_entity() self.field2token_id[self.relation_field]['[UI-Relation]'] = len(self.field2id_token[self.relation_field]) self.field2id_token[self.relation_field] = np.append(self.field2id_token[self.relation_field], '[UI-Relation]') diff --git a/recbole/data/dataset/sequential_dataset.py b/recbole/data/dataset/sequential_dataset.py index 103c81296..45327c123 100644 --- a/recbole/data/dataset/sequential_dataset.py +++ b/recbole/data/dataset/sequential_dataset.py @@ -3,20 +3,21 @@ # @Email : chenyushuo@ruc.edu.cn # UPDATE: -# @Time : 2020/9/16 -# @Author : Yushuo Chen -# @Email : chenyushuo@ruc.edu.cn +# @Time : 2020/9/16, 2021/7/1, 2021/7/11 +# @Author : Yushuo Chen, Xingyu Pan, Yupeng Hou +# @Email : chenyushuo@ruc.edu.cn, xy_pan@foxmail.com, houyupeng@ruc.edu.cn """ recbole.data.sequential_dataset ############################### """ -import copy - import numpy as np +import torch from recbole.data.dataset import Dataset +from recbole.data.interaction import Interaction +from recbole.utils.enum_type import FeatureType, FeatureSource class SequentialDataset(Dataset): @@ -25,20 +26,51 @@ class SequentialDataset(Dataset): which can accelerate the data loader. Attributes: - uid_list (numpy.ndarray): List of user id after augmentation. + max_item_list_len (int): Max length of historical item list. + item_list_length_field (str): Field name for item lists' length. + """ - item_list_index (numpy.ndarray): List of indexes of item sequence after augmentation. 
+ def __init__(self, config): + self.max_item_list_len = config['MAX_ITEM_LIST_LENGTH'] + self.item_list_length_field = config['ITEM_LIST_LENGTH_FIELD'] + super().__init__(config) + if config['benchmark_filename'] is not None: + self._benchmark_presets() - target_index (numpy.ndarray): List of indexes of target item id after augmentation. + def _change_feat_format(self): + """Change feat format from :class:`pandas.DataFrame` to :class:`Interaction`, + then perform data augmentation. + """ + super()._change_feat_format() - item_list_length (numpy.ndarray): List of item sequences' length after augmentation. + if self.config['benchmark_filename'] is not None: + return + self.logger.debug('Augmentation for sequential recommendation.') + self.data_augmentation() - """ + def _aug_presets(self): + list_suffix = self.config['LIST_SUFFIX'] + for field in self.inter_feat: + if field != self.uid_field: + list_field = field + list_suffix + setattr(self, f'{field}_list_field', list_field) + ftype = self.field2type[field] - def __init__(self, config): - super().__init__(config) + if ftype in [FeatureType.TOKEN, FeatureType.TOKEN_SEQ]: + list_ftype = FeatureType.TOKEN_SEQ + else: + list_ftype = FeatureType.FLOAT_SEQ - def prepare_data_augmentation(self): + if ftype in [FeatureType.TOKEN_SEQ, FeatureType.FLOAT_SEQ]: + list_len = (self.max_item_list_len, self.field2seqlen[field]) + else: + list_len = self.max_item_list_len + + self.set_field_property(list_field, list_ftype, FeatureSource.INTERACTION, list_len) + + self.set_field_property(self.item_list_length_field, FeatureType.TOKEN, FeatureSource.INTERACTION, 1) + + def data_augmentation(self): """Augmentation processing for sequential dataset. E.g., ``u1`` has purchase sequence ````, @@ -54,14 +86,10 @@ def prepare_data_augmentation(self): ``u1, | i3`` ``u1, | i4`` - - Note: - Actually, we do not really generate these new item sequences. - One user's item sequence is stored only once in memory. - We store the index (slice) of each item sequence after augmentation, - which saves memory and accelerates a lot. 
""" - self.logger.debug('prepare_data_augmentation') + self.logger.debug('data_augmentation') + + self._aug_presets() self._check_field('uid_field', 'time_field') max_item_list_len = self.config['MAX_ITEM_LIST_LENGTH'] @@ -81,34 +109,41 @@ def prepare_data_augmentation(self): target_index.append(i) item_list_length.append(i - seq_start) - self.uid_list = np.array(uid_list) - self.item_list_index = np.array(item_list_index) - self.target_index = np.array(target_index) - self.item_list_length = np.array(item_list_length, dtype=np.int64) - self.mask = np.ones(len(self.inter_feat), dtype=np.bool) - - def leave_one_out(self, group_by, leave_one_num=1): - self.logger.debug(f'Leave one out, group_by=[{group_by}], leave_one_num=[{leave_one_num}].') - if group_by is None: - raise ValueError('Leave one out strategy require a group field.') - if group_by != self.uid_field: - raise ValueError('Sequential models require group by user.') - - self.prepare_data_augmentation() - grouped_index = self._grouped_index(self.uid_list) - next_index = self._split_index_by_leave_one_out(grouped_index, leave_one_num) - - self._drop_unused_col() - next_ds = [] - for index in next_index: - ds = copy.copy(self) - for field in ['uid_list', 'item_list_index', 'target_index', 'item_list_length']: - setattr(ds, field, np.array(getattr(ds, field)[index])) - setattr(ds, 'mask', np.ones(len(self.inter_feat), dtype=np.bool)) - next_ds.append(ds) - next_ds[0].mask[self.target_index[next_index[1] + next_index[2]]] = False - next_ds[1].mask[self.target_index[next_index[2]]] = False - return next_ds + uid_list = np.array(uid_list) + item_list_index = np.array(item_list_index) + target_index = np.array(target_index) + item_list_length = np.array(item_list_length, dtype=np.int64) + + new_length = len(item_list_index) + new_data = self.inter_feat[target_index] + new_dict = { + self.item_list_length_field: torch.tensor(item_list_length), + } + + for field in self.inter_feat: + if field != self.uid_field: + list_field = getattr(self, f'{field}_list_field') + list_len = self.field2seqlen[list_field] + shape = (new_length, list_len) if isinstance(list_len, int) else (new_length,) + list_len + list_ftype = self.field2type[list_field] + dtype = torch.int64 if list_ftype in [FeatureType.TOKEN, FeatureType.TOKEN_SEQ] else torch.float64 + new_dict[list_field] = torch.zeros(shape, dtype=dtype) + + value = self.inter_feat[field] + for i, (index, length) in enumerate(zip(item_list_index, item_list_length)): + new_dict[list_field][i][:length] = value[index] + + new_data.update(Interaction(new_dict)) + self.inter_feat = new_data + + def _benchmark_presets(self): + list_suffix = self.config['LIST_SUFFIX'] + for field in self.inter_feat: + if field + list_suffix in self.inter_feat: + list_field = field + list_suffix + setattr(self, f'{field}_list_field', list_field) + self.set_field_property(self.item_list_length_field, FeatureType.TOKEN, FeatureSource.INTERACTION, 1) + self.inter_feat[self.item_list_length_field] = self.inter_feat[self.item_id_list_field].agg(len) def inter_matrix(self, form='coo', value_field=None): """Get sparse matrix that describe interactions between user_id and item_id. 
@@ -127,29 +162,33 @@ def inter_matrix(self, form='coo', value_field=None):
         if not self.uid_field or not self.iid_field:
             raise ValueError('dataset does not exist uid/iid, thus can not converted to sparse matrix.')
-        self.logger.warning(
-            'Load interaction matrix may lead to label leakage from testing phase, this implementation '
-            'only provides the interactions corresponding to specific phase'
-        )
-        local_inter_feat = self.inter_feat[self.mask]  # TODO: self.mask will applied to _history_matrix() in future
+        l1_idx = (self.inter_feat[self.item_list_length_field] == 1)
+        l1_inter_dict = self.inter_feat[l1_idx].interaction
+        new_dict = {}
+        list_suffix = self.config['LIST_SUFFIX']
+        candidate_field_set = set()
+        for field in l1_inter_dict:
+            if field != self.uid_field and field + list_suffix in l1_inter_dict:
+                candidate_field_set.add(field)
+                new_dict[field] = torch.cat([self.inter_feat[field], l1_inter_dict[field + list_suffix][:, 0]])
+            elif (not field.endswith(list_suffix)) and (field != self.item_list_length_field):
+                new_dict[field] = torch.cat([self.inter_feat[field], l1_inter_dict[field]])
+        local_inter_feat = Interaction(new_dict)
         return self._create_sparse_matrix(local_inter_feat, self.uid_field, self.iid_field, form, value_field)

-    def build(self, eval_setting):
-        self._change_feat_format()
-
-        ordering_args = eval_setting.ordering_args
-        if ordering_args['strategy'] == 'shuffle':
-            raise ValueError('Ordering strategy `shuffle` is not supported in sequential models.')
-        elif ordering_args['strategy'] == 'by':
-            if ordering_args['field'] != self.time_field:
-                raise ValueError('Sequential models require `TO` (time ordering) strategy.')
-            if ordering_args['ascending'] is not True:
-                raise ValueError('Sequential models require `time_field` to sort in ascending order.')
-
-        group_field = eval_setting.group_field
-
-        split_args = eval_setting.split_args
-        if split_args['strategy'] == 'loo':
-            return self.leave_one_out(group_by=group_field, leave_one_num=split_args['leave_one_num'])
-        else:
-            ValueError('Sequential models require `loo` (leave one out) split strategy.')
+    def build(self):
+        """Process the dataset according to the evaluation setting, including Group, Order and Split.
+        The procedure is controlled by ``config['eval_args']``.
+
+        Returns:
+            list: List of built :class:`Dataset`.
+ """ + ordering_args = self.config['eval_args']['order'] + if ordering_args != 'TO': + raise ValueError(f'The ordering args for sequential recommendation has to be \'TO\'') + + return super().build() diff --git a/recbole/data/utils.py b/recbole/data/utils.py index 197d85810..4828ba644 100644 --- a/recbole/data/utils.py +++ b/recbole/data/utils.py @@ -3,7 +3,7 @@ # @Email : houyupeng@ruc.edu.cn # UPDATE: -# @Time : 2020/10/19, 2020/9/17, 2020/8/31, 2021/2/20, 2021/3/1 +# @Time : 2021/7/9, 2020/9/17, 2020/8/31, 2021/2/20, 2021/3/1 # @Author : Yupeng Hou, Yushuo Chen, Kaiyuan Li, Haoran Cheng, Jiawei Guan # @Email : houyupeng@ruc.edu.cn, chenyushuo@ruc.edu.cn, tsotfsk@outlook.com, chenghaoran29@foxmail.com, guanjw@ruc.edu.cn @@ -17,11 +17,9 @@ import os import pickle -from recbole.config import EvalSetting from recbole.data.dataloader import * from recbole.sampler import KGSampler, Sampler, RepeatableSampler -from recbole.utils import ModelType, ensure_dir, get_local_time -from recbole.utils.utils import set_color +from recbole.utils import ModelType, ensure_dir, get_local_time, set_color def create_dataset(config): @@ -55,125 +53,6 @@ def create_dataset(config): return Dataset(config) -def data_preparation(config, dataset, save=False): - """Split the dataset by :attr:`config['eval_setting']` and call :func:`dataloader_construct` to create - corresponding dataloader. - - Args: - config (Config): An instance object of Config, used to record parameter information. - dataset (Dataset): An instance object of Dataset, which contains all interaction records. - save (bool, optional): If ``True``, it will call :func:`save_datasets` to save split dataset. - Defaults to ``False``. - - Returns: - tuple: - - train_data (AbstractDataLoader): The dataloader for training. - - valid_data (AbstractDataLoader): The dataloader for validation. - - test_data (AbstractDataLoader): The dataloader for testing. - """ - model_type = config['MODEL_TYPE'] - - es = EvalSetting(config) - - built_datasets = dataset.build(es) - train_dataset, valid_dataset, test_dataset = built_datasets - phases = ['train', 'valid', 'test'] - sampler = None - logger = getLogger() - train_neg_sample_args = config['train_neg_sample_args'] - eval_neg_sample_args = es.neg_sample_args - - # Training - train_kwargs = { - 'config': config, - 'dataset': train_dataset, - 'batch_size': config['train_batch_size'], - 'dl_format': config['MODEL_INPUT_TYPE'], - 'shuffle': True, - } - if train_neg_sample_args['strategy'] != 'none': - if dataset.label_field in dataset.inter_feat: - raise ValueError( - f'`training_neg_sample_num` should be 0 ' - f'if inter_feat have label_field [{dataset.label_field}].' 
- ) - if model_type != ModelType.SEQUENTIAL: - sampler = Sampler(phases, built_datasets, train_neg_sample_args['distribution']) - else: - sampler = RepeatableSampler(phases, dataset, train_neg_sample_args['distribution']) - train_kwargs['sampler'] = sampler.set_phase('train') - train_kwargs['neg_sample_args'] = train_neg_sample_args - if model_type == ModelType.KNOWLEDGE: - kg_sampler = KGSampler(dataset, train_neg_sample_args['distribution']) - train_kwargs['kg_sampler'] = kg_sampler - - dataloader = get_data_loader('train', config, train_neg_sample_args) - logger.info( - set_color('Build', 'pink') + set_color(f' [{dataloader.__name__}]', 'yellow') + ' for ' + - set_color('[train]', 'yellow') + ' with format ' + set_color(f'[{train_kwargs["dl_format"]}]', 'yellow') - ) - if train_neg_sample_args['strategy'] != 'none': - logger.info( - set_color('[train]', 'pink') + set_color(' Negative Sampling', 'blue') + f': {train_neg_sample_args}' - ) - else: - logger.info(set_color('[train]', 'pink') + set_color(' No Negative Sampling', 'yellow')) - logger.info( - set_color('[train]', 'pink') + set_color(' batch_size', 'cyan') + ' = ' + - set_color(f'[{train_kwargs["batch_size"]}]', 'yellow') + ', ' + set_color('shuffle', 'cyan') + ' = ' + - set_color(f'[{train_kwargs["shuffle"]}]\n', 'yellow') - ) - train_data = dataloader(**train_kwargs) - - # Evaluation - eval_kwargs = { - 'config': config, - 'batch_size': config['eval_batch_size'], - 'dl_format': InputType.POINTWISE, - 'shuffle': False, - } - valid_kwargs = {'dataset': valid_dataset} - test_kwargs = {'dataset': test_dataset} - if eval_neg_sample_args['strategy'] != 'none': - if dataset.label_field in dataset.inter_feat: - raise ValueError( - f'It can not validate with `{es.es_str[1]}` ' - f'when inter_feat have label_field [{dataset.label_field}].' - ) - if sampler is None: - if model_type != ModelType.SEQUENTIAL: - sampler = Sampler(phases, built_datasets, eval_neg_sample_args['distribution']) - else: - sampler = RepeatableSampler(phases, dataset, eval_neg_sample_args['distribution']) - else: - sampler.set_distribution(eval_neg_sample_args['distribution']) - eval_kwargs['neg_sample_args'] = eval_neg_sample_args - valid_kwargs['sampler'] = sampler.set_phase('valid') - test_kwargs['sampler'] = sampler.set_phase('test') - valid_kwargs.update(eval_kwargs) - test_kwargs.update(eval_kwargs) - - dataloader = get_data_loader('evaluation', config, eval_neg_sample_args) - logger.info( - set_color('Build', 'pink') + set_color(f' [{dataloader.__name__}]', 'yellow') + ' for ' + - set_color('[evaluation]', 'yellow') + ' with format ' + set_color(f'[{eval_kwargs["dl_format"]}]', 'yellow') - ) - logger.info(es) - logger.info( - set_color('[evaluation]', 'pink') + set_color(' batch_size', 'cyan') + ' = ' + - set_color(f'[{eval_kwargs["batch_size"]}]', 'yellow') + ', ' + set_color('shuffle', 'cyan') + ' = ' + - set_color(f'[{eval_kwargs["shuffle"]}]\n', 'yellow') - ) - - valid_data = dataloader(**valid_kwargs) - test_data = dataloader(**test_kwargs) - - if save: - save_split_dataloaders(config, dataloaders=(train_data, valid_data, test_data)) - - return train_data, valid_data, test_data - - def save_split_dataloaders(config, dataloaders): """Save split dataloaders. @@ -204,125 +83,137 @@ def load_split_dataloaders(saved_dataloaders_file): return dataloaders -def get_data_loader(name, config, neg_sample_args): - """Return a dataloader class according to :attr:`config` and :attr:`eval_setting`. 
+def data_preparation(config, dataset, save=False): + """Split the dataset by :attr:`config['eval_args']` and create training, validation and test dataloader. Args: - name (str): The stage of dataloader. It can only take two values: 'train' or 'evaluation'. config (Config): An instance object of Config, used to record parameter information. - neg_sample_args (dict) : Settings of negative sampling. + dataset (Dataset): An instance object of Dataset, which contains all interaction records. + save (bool, optional): If ``True``, it will call :func:`save_datasets` to save split dataset. + Defaults to ``False``. Returns: - type: The dataloader class that meets the requirements in :attr:`config` and :attr:`eval_setting`. + tuple: + - train_data (AbstractDataLoader): The dataloader for training. + - valid_data (AbstractDataLoader): The dataloader for validation. + - test_data (AbstractDataLoader): The dataloader for testing. """ - register_table = { - 'DIN': _get_DIN_data_loader, - 'DIEN': _get_DIEN_data_loader, - "MultiDAE": _get_AE_data_loader, - "MultiVAE": _get_AE_data_loader, - 'MacridVAE': _get_AE_data_loader, - 'CDAE': _get_AE_data_loader, - 'ENMF': _get_AE_data_loader, - 'RaCT': _get_AE_data_loader, - 'RecVAE': _get_AE_data_loader - } - - if config['model'] in register_table: - return register_table[config['model']](name, config, neg_sample_args) - - model_type_table = { - ModelType.GENERAL: 'General', - ModelType.TRADITIONAL: 'General', - ModelType.CONTEXT: 'Context', - ModelType.SEQUENTIAL: 'Sequential', - ModelType.DECISIONTREE: 'DecisionTree', - } - neg_sample_strategy_table = { - 'none': 'DataLoader', - 'by': 'NegSampleDataLoader', - 'full': 'FullDataLoader', - } model_type = config['MODEL_TYPE'] - neg_sample_strategy = neg_sample_args['strategy'] - dataloader_module = importlib.import_module('recbole.data.dataloader') - - if model_type in model_type_table and neg_sample_strategy in neg_sample_strategy_table: - dataloader_name = model_type_table[model_type] + neg_sample_strategy_table[neg_sample_strategy] - return getattr(dataloader_module, dataloader_name) - elif model_type == ModelType.KNOWLEDGE: - if neg_sample_strategy == 'by': - if name == 'train': - return KnowledgeBasedDataLoader - else: - return GeneralNegSampleDataLoader - elif neg_sample_strategy == 'full': - return GeneralFullDataLoader - elif neg_sample_strategy == 'none': - raise NotImplementedError( - 'The use of external negative sampling for knowledge model has not been implemented' - ) + built_datasets = dataset.build() + + train_dataset, valid_dataset, test_dataset = built_datasets + train_sampler, valid_sampler, test_sampler = create_samplers(config, dataset, built_datasets) + + if model_type != ModelType.KNOWLEDGE: + train_data = get_dataloader(config, 'train')(config, train_dataset, train_sampler, shuffle=True) else: - raise NotImplementedError(f'Model_type [{model_type}] has not been implemented.') + kg_sampler = KGSampler(dataset, config['train_neg_sample_args']['distribution']) + train_data = get_dataloader(config, 'train')(config, train_dataset, train_sampler, kg_sampler, shuffle=True) + + valid_data = get_dataloader(config, 'valid')(config, valid_dataset, valid_sampler, shuffle=False) + test_data = get_dataloader(config, 'test')(config, test_dataset, test_sampler, shuffle=False) + if save: + save_split_dataloaders(config, dataloaders=(train_data, valid_data, test_data)) + + return train_data, valid_data, test_data -def _get_DIN_data_loader(name, config, neg_sample_args): - """Customized function for DIN to 
get correct dataloader class.
+
+def get_dataloader(config, phase):
+    """Return a dataloader class according to :attr:`config` and :attr:`phase`.

     Args:
-        name (str): The stage of dataloader. It can only take two values: 'train' or 'evaluation'.
         config (Config): An instance object of Config, used to record parameter information.
-        neg_sample_args : Settings of negative sampling.
+        phase (str): The stage of the dataloader. It can only take three values: 'train', 'valid' or 'test'.

     Returns:
-        type: The dataloader class that meets the requirements in :attr:`config` and :attr:`eval_setting`.
+        type: The dataloader class that meets the requirements in :attr:`config` and :attr:`phase`.
     """
-    neg_sample_strategy = neg_sample_args['strategy']
-    if neg_sample_strategy == 'none':
-        return SequentialDataLoader
-    elif neg_sample_strategy == 'by':
-        return SequentialNegSampleDataLoader
-    elif neg_sample_strategy == 'full':
-        return SequentialFullDataLoader
+    register_table = {
+        "MultiDAE": _get_AE_dataloader,
+        "MultiVAE": _get_AE_dataloader,
+        'MacridVAE': _get_AE_dataloader,
+        'CDAE': _get_AE_dataloader,
+        'ENMF': _get_AE_dataloader,
+        'RaCT': _get_AE_dataloader,
+        'RecVAE': _get_AE_dataloader,
+    }
+
+    if config['model'] in register_table:
+        return register_table[config['model']](config, phase)
+
+    model_type = config['MODEL_TYPE']
+    if phase == 'train':
+        if model_type != ModelType.KNOWLEDGE:
+            return TrainDataLoader
+        else:
+            return KnowledgeBasedDataLoader
+    else:
+        eval_strategy = config['eval_neg_sample_args']['strategy']
+        if eval_strategy in {'none', 'by'}:
+            return NegSampleEvalDataLoader
+        elif eval_strategy == 'full':
+            return FullSortEvalDataLoader

-def _get_DIEN_data_loader(name, config, neg_sample_args):
-    """Customized function for DIEN to get correct dataloader class.
+def _get_AE_dataloader(config, phase):
+    """Customized function for VAE models to get correct dataloader class.

     Args:
-        name (str): The stage of dataloader. It can only take two values: 'train' or 'evaluation'.
         config (Config): An instance object of Config, used to record parameter information.
-        neg_sample_args : Settings of negative sampling.
+        phase (str): The stage of the dataloader. It can only take three values: 'train', 'valid' or 'test'.

     Returns:
-        type: The dataloader class that meets the requirements in :attr:`config` and :attr:`eval_setting`.
+        type: The dataloader class that meets the requirements in :attr:`config` and :attr:`phase`.
     """
-    neg_sample_strategy = neg_sample_args['strategy']
-    if neg_sample_strategy == 'none':
-        return DIENDataLoader
-    elif neg_sample_strategy == 'by':
-        return DIENNegSampleDataLoader
-    elif neg_sample_strategy == 'full':
-        return DIENFullDataLoader
+    if phase == 'train':
+        return UserDataLoader
+    else:
+        eval_strategy = config['eval_neg_sample_args']['strategy']
+        if eval_strategy in {'none', 'by'}:
+            return NegSampleEvalDataLoader
+        elif eval_strategy == 'full':
+            return FullSortEvalDataLoader

-def _get_AE_data_loader(name, config, neg_sample_args):
-    """Customized function for Multi-DAE and Multi-VAE to get correct dataloader class.
+def create_samplers(config, dataset, built_datasets):
+    """Create sampler for training, validation and testing.

     Args:
-        name (str): The stage of dataloader. It can only take two values: 'train' or 'evaluation'.
         config (Config): An instance object of Config, used to record parameter information.
-        neg_sample_args (dict): Settings of negative sampling.
+        dataset (Dataset): An instance object of Dataset, which contains all interaction records.
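# Illustrative sketch (not part of the diff): how the refactored entry points compose.
# Assumes RecBole is installed and the ml-100k atomic files are available locally.
from recbole.config import Config
from recbole.data import create_dataset, data_preparation

config = Config(model='BPR', dataset='ml-100k')
dataset = create_dataset(config)
# Internally, dataset.build() splits the data, create_samplers() builds the shared sampler,
# and get_dataloader() picks TrainDataLoader / NegSampleEvalDataLoader / FullSortEvalDataLoader.
train_data, valid_data, test_data = data_preparation(config, dataset)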
+ built_datasets (list of Dataset): A list of split Dataset, which contains dataset for + training, validation and testing. Returns: - type: The dataloader class that meets the requirements in :attr:`config` and :attr:`eval_setting`. + tuple: + - train_sampler (AbstractSampler): The sampler for training. + - valid_sampler (AbstractSampler): The sampler for validation. + - test_sampler (AbstractSampler): The sampler for testing. """ - neg_sample_strategy = neg_sample_args['strategy'] - if name == "train": - return UserDataLoader - else: - if neg_sample_strategy == 'none': - return GeneralDataLoader - elif neg_sample_strategy == 'by': - return GeneralNegSampleDataLoader - elif neg_sample_strategy == 'full': - return GeneralFullDataLoader + model_type = config['MODEL_TYPE'] + phases = ['train', 'valid', 'test'] + train_neg_sample_args = config['train_neg_sample_args'] + eval_neg_sample_args = config['eval_neg_sample_args'] + sampler = None + train_sampler, valid_sampler, test_sampler = None, None, None + + if train_neg_sample_args['strategy'] != 'none': + if model_type != ModelType.SEQUENTIAL: + sampler = Sampler(phases, built_datasets, train_neg_sample_args['distribution']) + else: + sampler = RepeatableSampler(phases, dataset, train_neg_sample_args['distribution']) + train_sampler = sampler.set_phase('train') + + if eval_neg_sample_args['strategy'] != 'none': + if sampler is None: + if model_type != ModelType.SEQUENTIAL: + sampler = Sampler(phases, built_datasets, eval_neg_sample_args['distribution']) + else: + sampler = RepeatableSampler(phases, dataset, eval_neg_sample_args['distribution']) + else: + sampler.set_distribution(eval_neg_sample_args['distribution']) + valid_sampler = sampler.set_phase('valid') + test_sampler = sampler.set_phase('test') + + return train_sampler, valid_sampler, test_sampler diff --git a/recbole/evaluator/abstract_evaluator.py b/recbole/evaluator/abstract_evaluator.py new file mode 100644 index 000000000..f49ef7b49 --- /dev/null +++ b/recbole/evaluator/abstract_evaluator.py @@ -0,0 +1,136 @@ +# -*- encoding: utf-8 -*- +# @Time : 2020/10/21 +# @Author : Kaiyuan Li +# @email : tsotfsk@outlook.com + +# UPDATE +# @Time : 2020/10/21, 2020/12/18, 2021/7/1 +# @Author : Kaiyuan Li, Zhichao Feng, Xingyu Pan +# @email : tsotfsk@outlook.com, fzcbupt@gmail.com, xy_oan@foxmail.com + +""" +recbole.evaluator.abstract_evaluator +##################################### +""" + +import numpy as np +import torch +from torch.nn.utils.rnn import pad_sequence + + +class BaseEvaluator(object): + """:class:`BaseEvaluator` is an object which supports + the evaluation of the model. It is called by :class:`Trainer`. + + Note: + If you want to inherit this class and implement your own evaluator class, + you must implement the following functions. + + Args: + config (Config): The config of evaluator. + + """ + + def __init__(self, config, metrics): + self.metrics = metrics + self.full = ('full' in config['eval_args']['mode']) + self.precision = config['metric_decimal_place'] + + def collect(self, *args): + """get the intermediate results for each batch, it is called at the end of each batch""" + raise NotImplementedError + + def evaluate(self, *args): + """calculate the metrics of all batches, it is called at the end of each epoch""" + raise NotImplementedError + + def _calculate_metrics(self, *args): + """ to calculate the metrics""" + raise NotImplementedError + + +class GroupedEvaluator(BaseEvaluator): + """:class:`GroupedEvaluator` is an object which supports the evaluation of the model. 
+
+    Note:
+        If you want to implement a new group-based metric,
+        you may need to inherit this class.
+
+    """
+
+    def __init__(self, config, metrics):
+        super().__init__(config, metrics)
+
+    def sample_collect(self, scores_tensor, user_len_list):
+        """Pad ``scores_tensor``. It is called when the evaluation sample distribution is `uniform` or `popularity`.
+
+        """
+        scores_list = torch.split(scores_tensor, user_len_list, dim=0)
+        padding_score = pad_sequence(scores_list, batch_first=True, padding_value=-np.inf)  # n_users x items
+        return padding_score
+
+    def full_sort_collect(self, scores_tensor, user_len_list):
+        """It is called when the evaluation sample distribution is `full`.
+
+        """
+        return scores_tensor.view(len(user_len_list), -1)
+
+    def get_score_matrix(self, scores_tensor, user_len_list):
+        """Get the score matrix.
+
+        Args:
+            scores_tensor (tensor): the tensor of model output with size of `(N, )`
+            user_len_list (list): the number of candidate items for each user
+
+        """
+        if self.full:
+            scores_matrix = self.full_sort_collect(scores_tensor, user_len_list)
+        else:
+            scores_matrix = self.sample_collect(scores_tensor, user_len_list)
+        return scores_matrix
+
+
+class IndividualEvaluator(BaseEvaluator):
+    """:class:`IndividualEvaluator` is an object which supports the evaluation of the model.
+
+    Note:
+        If you want to implement a new non-group-based metric,
+        you may need to inherit this class.
+
+    """
+
+    def __init__(self, config, metrics):
+        super().__init__(config, metrics)
+        self._check_args()
+
+    def sample_collect(self, true_scores, pred_scores):
+        """It is called when the evaluation sample distribution is `uniform` or `popularity`.
+
+        """
+        return torch.stack((true_scores, pred_scores.detach()), dim=1)
+
+    def full_sort_collect(self, true_scores, pred_scores):
+        """It is called when the evaluation sample distribution is `full`.
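# Illustrative sketch (not part of the diff): `sample_collect` pads per-user score chunks
# with -inf so shorter candidate lists can never win a top-k comparison. Toy values:
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

scores_tensor = torch.tensor([0.3, 0.9, 0.5, 0.2, 0.8])  # flattened scores of two users
user_len_list = [2, 3]                                   # user 0 has 2 candidates, user 1 has 3
scores_list = torch.split(scores_tensor, user_len_list, dim=0)
padded = pad_sequence(scores_list, batch_first=True, padding_value=-np.inf)
# padded -> [[0.3, 0.9, -inf],
#            [0.5, 0.2,  0.8]]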
+
+        """
+        raise NotImplementedError('full sort can\'t use IndividualEvaluator')
+
+    def get_score_matrix(self, true_scores, pred_scores):
+        """Get the score matrix.
+
+        Args:
+            true_scores (tensor): the label of predicted items
+            pred_scores (tensor): the tensor of model output with a size of `(N, )`
+
+        """
+        if self.full:
+            scores_matrix = self.full_sort_collect(true_scores, pred_scores)
+        else:
+            scores_matrix = self.sample_collect(true_scores, pred_scores)
+
+        return scores_matrix
+
+    def _check_args(self):
+        if self.full:
+            raise NotImplementedError('full sort can\'t use IndividualEvaluator')
diff --git a/recbole/evaluator/collector.py b/recbole/evaluator/collector.py
index 75c2430c2..075e7a6c8 100644
--- a/recbole/evaluator/collector.py
+++ b/recbole/evaluator/collector.py
@@ -65,7 +65,7 @@ def __init__(self, config):
         self.config = config
         self.data_struct = DataStruct()
         self.register = Register(config)
-        self.full = ('full' in config['eval_setting'])
+        self.full = ('full' in config['eval_args']['mode'])
         self.topk = self.config['topk']
         self.topk_idx = None
diff --git a/recbole/evaluator/evaluator.py b/recbole/evaluator/evaluator.py
index 005b2866b..ee34f1660 100644
--- a/recbole/evaluator/evaluator.py
+++ b/recbole/evaluator/evaluator.py
@@ -60,7 +60,7 @@ def _check_args(self):
         # Check Loss
         if set(self.metrics) & set(loss_metrics):
-            is_full = 'full' in self.config['eval_setting']
+            is_full = 'full' in self.config['eval_args']['mode']
             if is_full:
                 raise NotImplementedError('Full sort evaluation do not match the metrics!')
diff --git a/recbole/evaluator/metrics.py b/recbole/evaluator/metrics.py
index 009b5f835..ca1a055aa 100644
--- a/recbole/evaluator/metrics.py
+++ b/recbole/evaluator/metrics.py
@@ -570,6 +570,70 @@ def get_gini(self, item_matrix, num_items):
         return gini_index
+
+class TailPercentage:
+    r"""It computes the percentage of long-tail items in the recommendation lists.
+
+    For further details, please refer to the `paper `
+
+    .. math::
+        \mathrm{TailPercentage}=\frac{1}{|U|} \sum_{u \in U} \frac{\sum_{i \in R_{u}} \mathbb{1}(i \in T)}{|R_{u}|}
+
+    :math:`R_{u}` is the list of items recommended to user :math:`u`.
+    :math:`T` is the set of long-tail items,
+    which is a portion of items that rarely appear in the training data.
+    :math:`\mathbb{1}(\cdot)` is the indicator function.
+
+    Note:
+        If you want to use this metric, please set the parameter ``tail_ratio`` in the config.
+        It can be an integer greater than 1 (an interaction-count threshold: items with no more
+        interactions than it are tail items) or a float in (0,1] (the fraction of least-interacted
+        items treated as the tail). If it is not set, it defaults to 0.1.
+ """ + + def __init__(self, config): + self.topk = config['topk'] + self.decimal_place = config['metric_decimal_place'] + self.tail = config['tail_ratio'] + if self.tail is None or self.tail <= 0: + self.tail = 0.1 + + def used_info(self, dataobject): + """get the matrix of recommendation items and number of items in total item set""" + item_matrix = dataobject.get('rec.items') + count_items = dataobject.get('data.count_items') + return item_matrix.numpy(), dict(count_items) + + def get_tail(self, item_matrix, count_items): + if self.tail > 1: + tail_items = [item for item, cnt in count_items.items() if cnt <= self.tail] + else: + count_items = sorted(count_items.items(), key=lambda kv: (kv[1], kv[0])) + cut = max(int(len(count_items) * self.tail), 1) + count_items = count_items[:cut] + tail_items = [item for item, cnt in count_items] + value = np.zeros_like(item_matrix) + for i in range(item_matrix.shape[0]): + row = item_matrix[i, :] + for j in range(row.shape[0]): + value[i][j] = 1 if row[j] in tail_items else 0 + return value + + def calculate_metric(self, dataobject): + item_matrix, count_items = self.used_info(dataobject) + result = self.metric_info(self.get_tail(item_matrix, count_items)) + metric_dict = self.topk_result('tailpercentage', result) + return metric_dict + + def metric_info(self, values): + return values.cumsum(axis=1) / np.arange(1, values.shape[1] + 1) + + def topk_result(self, metric, value): + """match the metric value to the `k` and put them in `dictionary` form""" + metric_dict = {} + avg_result = value.mean(axis=0) + for k in self.topk: + key = '{}@{}'.format(metric, k) + metric_dict[key] = round(avg_result[k-1], self.decimal_place) + return metric_dict + + metrics_dict = { 'ndcg': NDCG, 'hit': Hit, @@ -585,5 +649,6 @@ def get_gini(self, item_matrix, num_items): 'itemcoverage': ItemCoverage, 'averagepopularity': AveragePopularity, 'giniindex': GiniIndex, - 'shannonentropy': ShannonEntropy + 'shannonentropy': ShannonEntropy, + 'tailpercentage': TailPercentage } diff --git a/recbole/evaluator/register.py b/recbole/evaluator/register.py index f3c2ed839..66a275d26 100644 --- a/recbole/evaluator/register.py +++ b/recbole/evaluator/register.py @@ -25,7 +25,7 @@ 'averagepopularity': ['rec.topk', 'rec.items', 'data.count_items'], 'giniindex': ['rec.topk', 'rec.items', 'data.num_items'], 'shannonentropy': ['rec.topk', 'rec.items'], - + 'tailpercentage': ['rec.topk', 'rec.items', 'data.count_items'], 'gauc': ['rec.meanrank'], # Sign in for full ranking metrics diff --git a/recbole/model/abstract_recommender.py b/recbole/model/abstract_recommender.py index 92547725d..5dd6661a1 100644 --- a/recbole/model/abstract_recommender.py +++ b/recbole/model/abstract_recommender.py @@ -19,8 +19,7 @@ import torch.nn as nn from recbole.model.layers import FMEmbedding, FMFirstOrderLinear -from recbole.utils import ModelType, InputType, FeatureSource, FeatureType -from recbole.utils.utils import set_color +from recbole.utils import ModelType, InputType, FeatureSource, FeatureType, set_color class AbstractRecommender(nn.Module): diff --git a/recbole/model/general_recommender/macridvae.py b/recbole/model/general_recommender/macridvae.py index c382b34ac..eadc8a359 100644 --- a/recbole/model/general_recommender/macridvae.py +++ b/recbole/model/general_recommender/macridvae.py @@ -3,6 +3,11 @@ # @Author : Yihong Guo # @Email : gyihong@hotmail.com +# UPDATE +# @Time : 2021/6/30, +# @Author : Xingyu Pan +# @email : xy_pan@foxmail.com + r""" MacridVAE 
################################################ @@ -16,7 +21,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.distributions.one_hot_categorical import OneHotCategorical from recbole.model.abstract_recommender import GeneralRecommender from recbole.model.init import xavier_normal_initialization @@ -106,11 +110,10 @@ def forward(self, rating_matrix): cates_logits = torch.matmul(items, cores.transpose(0, 1)) / self.tau if self.nogb: - cates = torch.softmax(cates_logits, dim=1) + cates = torch.softmax(cates_logits, dim=-1) else: - cates_dist = OneHotCategorical(logits=cates_logits) - cates_sample = cates_dist.sample() - cates_mode = torch.softmax(cates_logits, dim=1) + cates_sample = F.gumbel_softmax(cates_logits, tau=1, hard=False, dim=-1) + cates_mode = torch.softmax(cates_logits, dim=-1) cates = (self.training * cates_sample + (1 - self.training) * cates_mode) probs = None diff --git a/recbole/properties/dataset/ml-100k.yaml b/recbole/properties/dataset/ml-100k.yaml index 9e0beff79..7c229f814 100644 --- a/recbole/properties/dataset/ml-100k.yaml +++ b/recbole/properties/dataset/ml-100k.yaml @@ -31,18 +31,16 @@ unused_col: ~ # Filtering rm_dup_inter: ~ -lowest_val: ~ -highest_val: ~ -equal_val: ~ -not_equal_val: ~ +val_interval: ~ filter_inter_by_user_or_item: True -max_user_inter_num: ~ -min_user_inter_num: ~ -max_item_inter_num: ~ -min_item_inter_num: ~ +user_inter_num_interval: ~ +item_inter_num_interval: ~ # Preprocessing -fields_in_same_space: ~ +alias_of_user_id: ~ +alias_of_item_id: ~ +alias_of_entity_id: ~ +alias_of_relation_id: ~ preload_weight: ~ normalize_field: ~ normalize_all: True diff --git a/recbole/properties/dataset/sample.yaml b/recbole/properties/dataset/sample.yaml index d9869e5e2..301e5b3c8 100644 --- a/recbole/properties/dataset/sample.yaml +++ b/recbole/properties/dataset/sample.yaml @@ -26,18 +26,16 @@ additional_feat_suffix: ~ # Filtering rm_dup_inter: ~ -lowest_val: ~ -highest_val: ~ -equal_val: ~ -not_equal_val: ~ +val_interval: ~ filter_inter_by_user_or_item: True -max_user_inter_num: ~ -min_user_inter_num: 0 -max_item_inter_num: ~ -min_item_inter_num: 0 +user_inter_num_interval: "[0,inf)" +item_inter_num_interval: "[0,inf)" # Preprocessing -fields_in_same_space: ~ +alias_of_user_id: ~ +alias_of_item_id: ~ +alias_of_entity_id: ~ +alias_of_relation_id: ~ preload_weight: ~ normalize_field: ~ normalize_all: ~ @@ -54,9 +52,5 @@ TAIL_ENTITY_ID_FIELD: tail_id RELATION_ID_FIELD: relation_id ENTITY_ID_FIELD: entity_id -# Social Model Needed -SOURCE_ID_FIELD: source_id -TARGET_ID_FIELD: target_id - # Benchmark .inter benchmark_filename: ~ diff --git a/recbole/properties/overall.yaml b/recbole/properties/overall.yaml index a99b33176..11b2ca08f 100644 --- a/recbole/properties/overall.yaml +++ b/recbole/properties/overall.yaml @@ -23,11 +23,11 @@ weight_decay: 0.0 draw_loss_pic: False # evaluation settings -eval_setting: RO_RS,full -group_by_user: True -split_ratio: [0.8,0.1,0.1] -leave_one_num: 2 -real_time_process: False +eval_args: + split: {'RS':[0.8,0.1,0.1]} + group_by: user + order: RO + mode: full metrics: ["Recall","MRR","NDCG","Hit","Precision"] topk: [10] valid_metric: MRR@10 diff --git a/recbole/properties/quick_start_config/context-aware.yaml b/recbole/properties/quick_start_config/context-aware.yaml index cdb71f098..360ec3fbb 100644 --- a/recbole/properties/quick_start_config/context-aware.yaml +++ b/recbole/properties/quick_start_config/context-aware.yaml @@ -1,5 +1,8 @@ -eval_setting: RO_RS -group_by_user: False +eval_args: 
+ split: {'RS':[0.8,0.1,0.1]} + order: RO + group_by: none + mode: none training_neg_sample_num: 0 metrics: ['AUC', 'LogLoss'] valid_metric: AUC \ No newline at end of file diff --git a/recbole/properties/quick_start_config/sequential.yaml b/recbole/properties/quick_start_config/sequential.yaml index 87c0fa053..d9e144449 100644 --- a/recbole/properties/quick_start_config/sequential.yaml +++ b/recbole/properties/quick_start_config/sequential.yaml @@ -1 +1,4 @@ -eval_setting: TO_LS,full \ No newline at end of file +eval_args: + split: {'LS': 2} + order: TO + mode: full diff --git a/recbole/properties/quick_start_config/sequential_DIN.yaml b/recbole/properties/quick_start_config/sequential_DIN.yaml index 58b8db955..8d6b01edc 100644 --- a/recbole/properties/quick_start_config/sequential_DIN.yaml +++ b/recbole/properties/quick_start_config/sequential_DIN.yaml @@ -1,3 +1,6 @@ -eval_setting: TO_LS, uni100 +eval_args: + split: {'LS': 2} + order: TO + mode: uni100 metrics: ['AUC', 'LogLoss'] -valid_metric: AUC \ No newline at end of file +valid_metric: AUC diff --git a/recbole/quick_start/quick_start.py b/recbole/quick_start/quick_start.py index 89836de99..720c23e3b 100644 --- a/recbole/quick_start/quick_start.py +++ b/recbole/quick_start/quick_start.py @@ -11,8 +11,7 @@ from recbole.config import Config from recbole.data import create_dataset, data_preparation -from recbole.utils import init_logger, get_model, get_trainer, init_seed -from recbole.utils.utils import set_color +from recbole.utils import init_logger, get_model, get_trainer, init_seed, set_color def run_recbole(model=None, dataset=None, config_file_list=None, config_dict=None, saved=True): diff --git a/recbole/sampler/sampler.py b/recbole/sampler/sampler.py index 7bd282622..ef6a4995c 100644 --- a/recbole/sampler/sampler.py +++ b/recbole/sampler/sampler.py @@ -222,7 +222,7 @@ def get_used_ids(self): raise ValueError( 'Some users have interacted with all items, ' 'which we can not sample negative items for them. ' - 'Please set `max_user_inter_num` to filter those users.' + 'Please set `user_inter_num_interval` to filter those users.' ) return used_item_id @@ -243,11 +243,12 @@ def set_phase(self, phase): new_sampler.used_ids = new_sampler.used_ids[phase] return new_sampler - def sample_by_user_ids(self, user_ids, num): + def sample_by_user_ids(self, user_ids, item_ids, num): """Sampling by user_ids. Args: user_ids (numpy.ndarray or list): Input user_ids. + item_ids (numpy.ndarray or list): Input item_ids. num (int): Number of sampled item_ids for each user_id. Returns: @@ -383,11 +384,12 @@ def get_used_ids(self): """ return np.array([set() for _ in range(self.n_users)]) - def sample_by_user_ids(self, user_ids, num): + def sample_by_user_ids(self, user_ids, item_ids, num): """Sampling by user_ids. Args: user_ids (numpy.ndarray or list): Input user_ids. + item_ids (numpy.ndarray or list): Input item_ids. num (int): Number of sampled item_ids for each user_id. Returns: @@ -398,7 +400,8 @@ def sample_by_user_ids(self, user_ids, num): item_id[len(user_ids) * (num - 1) + 1] is sampled for user_ids[1]; ...; and so on. 
""" try: - return self.sample_by_key_ids(user_ids, num) + self.used_ids = np.array([{i} for i in item_ids]) + return self.sample_by_key_ids(np.arange(len(user_ids)), num) except IndexError: for user_id in user_ids: if user_id < 0 or user_id >= self.n_users: diff --git a/recbole/trainer/trainer.py b/recbole/trainer/trainer.py index d92e677ca..b69d2fdab 100644 --- a/recbole/trainer/trainer.py +++ b/recbole/trainer/trainer.py @@ -30,8 +30,7 @@ from recbole.data.interaction import Interaction from recbole.evaluator import Evaluator, Collector from recbole.utils import ensure_dir, get_local_time, early_stopping, calculate_valid_score, dict2str, \ - DataLoaderType, KGDataLoaderState -from recbole.utils.utils import set_color + DataLoaderType, KGDataLoaderState, get_tensorboard, set_color class AbstractTrainer(object): @@ -77,6 +76,7 @@ def __init__(self, config, model): super(Trainer, self).__init__(config, model) self.logger = getLogger() + self.tensorboard = get_tensorboard(self.logger) self.learner = config['learner'] self.learning_rate = config['learning_rate'] self.epochs = config['epochs'] @@ -92,7 +92,6 @@ def __init__(self, config, model): saved_model_file = '{}-{}.pth'.format(self.config['model'], get_local_time()) self.saved_model_file = os.path.join(self.checkpoint_dir, saved_model_file) self.weight_decay = config['weight_decay'] - self.draw_loss_pic = config['draw_loss_pic'] self.start_epoch = 0 self.cur_step = 0 @@ -112,6 +111,12 @@ def _build_optimizer(self, params): Returns: torch.optim: the optimizer """ + if self.config['reg_weight'] and self.weight_decay and self.weight_decay * self.config['reg_weight'] > 0: + self.logger.warning( + 'The parameters [weight_decay] and [reg_weight] are specified simultaneously, ' + 'which may lead to double regularization.' + ) + if self.learner.lower() == 'adam': optimizer = optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay) elif self.learner.lower() == 'sgd': @@ -246,6 +251,13 @@ def _generate_train_loss_output(self, epoch_idx, s_time, e_time, losses): train_loss_output += set_color('train loss', 'blue') + ': ' + des % losses return train_loss_output + ']' + def _add_train_loss_to_tensorboard(self, epoch_idx, losses, tag='Loss/Train'): + if isinstance(losses, tuple): + for idx, loss in enumerate(losses): + self.tensorboard.add_scalar(tag + str(idx), loss, epoch_idx) + else: + self.tensorboard.add_scalar(tag, losses, epoch_idx) + def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None): r"""Train the model based on the train data and the valid data. 
@@ -277,6 +289,7 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progre
             self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
             if verbose:
                 self.logger.info(train_loss_output)
+            self._add_train_loss_to_tensorboard(epoch_idx, train_loss)

             # eval
             if self.eval_step <= 0 or not valid_data:
@@ -304,6 +317,8 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progre
                 if verbose:
                     self.logger.info(valid_score_output)
                     self.logger.info(valid_result_output)
+                self.tensorboard.add_scalar('Valid_score', valid_score, epoch_idx)
+
                 if update_flag:
                     if saved:
                         self._save_checkpoint(epoch_idx)
@@ -321,9 +336,6 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progre
                     if verbose:
                         self.logger.info(stop_output)
                     break
-        if self.draw_loss_pic:
-            save_path = '{}-{}-train_loss.pdf'.format(self.config['model'], get_local_time())
-            self.plot_train_loss(save_path=os.path.join(save_path))
         return self.best_valid_score, self.best_valid_result

     def _full_sort_batch_eval(self, batched_data):
@@ -429,30 +441,6 @@ def _spilt_predict(self, interaction, batch_size):
             result_list.append(result)
         return torch.cat(result_list, dim=0)

-    def plot_train_loss(self, show=True, save_path=None):
-        r"""Plot the train loss in each epoch
-
-        Args:
-            show (bool, optional): Whether to show this figure, default: True
-            save_path (str, optional): The data path to save the figure, default: None.
-                                       If it's None, it will not be saved.
-        """
-        import matplotlib.pyplot as plt
-        import time
-        epochs = list(self.train_loss_dict.keys())
-        epochs.sort()
-        values = [float(self.train_loss_dict[epoch]) for epoch in epochs]
-        plt.plot(epochs, values)
-        my_x_ticks = np.arange(0, len(epochs), int(len(epochs) / 10))
-        plt.xticks(my_x_ticks)
-        plt.xlabel('Epoch')
-        plt.ylabel('Loss')
-        plt.title(self.config['model'] + ' ' + time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time())))
-        if show:
-            plt.show()
-        if save_path:
-            plt.savefig(save_path)
-

 class KGTrainer(Trainer):
     r"""KGTrainer is designed for Knowledge-aware recommendation methods.
Some of these models need to train the
@@ -547,6 +535,7 @@ def pretrain(self, train_data, verbose=True, show_progress=False):
             self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
             if verbose:
                 self.logger.info(train_loss_output)
+            self._add_train_loss_to_tensorboard(epoch_idx, train_loss)

             if (epoch_idx + 1) % self.config['save_step'] == 0:
                 saved_model_file = os.path.join(
@@ -618,6 +607,7 @@ def __init__(self, config, model):
         super(DecisionTreeTrainer, self).__init__(config, model)

         self.logger = getLogger()
+        self.tensorboard = get_tensorboard(self.logger)
         self.label_field = config['LABEL_FIELD']
         self.convert_token_to_onehot = self.config['convert_token_to_onehot']
@@ -722,6 +712,7 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progre
             if verbose:
                 self.logger.info(valid_score_output)
                 self.logger.info(valid_result_output)
+            self.tensorboard.add_scalar('Valid_score', valid_score, epoch_idx)

             self.best_valid_score = valid_score
             self.best_valid_result = valid_result
@@ -841,6 +832,7 @@ def pretrain(self, train_data, verbose=True, show_progress=False):
             self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
             if verbose:
                 self.logger.info(train_loss_output)
+            self._add_train_loss_to_tensorboard(epoch_idx, train_loss)

             if (epoch_idx + 1) % self.pretrain_epochs == 0:
                 saved_model_file = os.path.join(
@@ -1016,6 +1008,7 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progre
             self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
             if verbose:
                 self.logger.info(train_loss_output)
+            self._add_train_loss_to_tensorboard(epoch_idx, train_loss)

             # eval
             if self.eval_step <= 0 or not valid_data:
@@ -1043,6 +1036,7 @@ def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progre
                 if verbose:
                     self.logger.info(valid_score_output)
                     self.logger.info(valid_result_output)
+                self.tensorboard.add_scalar('Valid_score', valid_score, epoch_idx)
                 if update_flag:
                     if saved:
                         self._save_checkpoint(epoch_idx)
diff --git a/recbole/utils/__init__.py b/recbole/utils/__init__.py
index 22240e0aa..7dc6e7680 100644
--- a/recbole/utils/__init__.py
+++ b/recbole/utils/__init__.py
@@ -1,6 +1,6 @@
-from recbole.utils.logger import init_logger
+from recbole.utils.logger import init_logger, set_color
 from recbole.utils.utils import get_local_time, ensure_dir, get_model, get_trainer, \
-    early_stopping, calculate_valid_score, dict2str, init_seed
+    early_stopping, calculate_valid_score, dict2str, init_seed, get_tensorboard
 from recbole.utils.enum_type import *
 from recbole.utils.argument_list import *

@@ -8,5 +8,5 @@
     'init_logger', 'get_local_time', 'ensure_dir', 'get_model', 'get_trainer', 'early_stopping',
     'calculate_valid_score', 'dict2str', 'Enum', 'ModelType', 'DataLoaderType', 'KGDataLoaderState',
     'EvaluatorType', 'InputType', 'FeatureType', 'FeatureSource', 'init_seed', 'general_arguments', 'training_arguments',
-    'evaluation_arguments', 'dataset_arguments'
+    'evaluation_arguments', 'dataset_arguments', 'get_tensorboard', 'set_color'
 ]
diff --git a/recbole/utils/argument_list.py b/recbole/utils/argument_list.py
index 632f442bd..ab6f17c14 100644
--- a/recbole/utils/argument_list.py
+++ b/recbole/utils/argument_list.py
@@ -10,7 +10,9 @@
     'reproducibility',
     'state',
     'data_path',
+    'benchmark_filename',
     'show_progress',
+    'config_file'
 ]

 training_arguments = [
@@ -27,11 +29,8 @@
 ]

 evaluation_arguments = [
-    'eval_setting',
-    'group_by_user',
-
'split_ratio', 'leave_one_num', - 'real_time_process', - 'metrics', 'topk', 'valid_metric', + 'eval_args', + 'metrics', 'topk', 'valid_metric', 'valid_metric_bigger', 'eval_batch_size', 'metric_decimal_place' ] @@ -45,9 +44,9 @@ 'ITEM_LIST_LENGTH_FIELD', 'LIST_SUFFIX', 'MAX_ITEM_LIST_LENGTH', 'POSITION_FIELD', 'HEAD_ENTITY_ID_FIELD', 'TAIL_ENTITY_ID_FIELD', 'RELATION_ID_FIELD', 'ENTITY_ID_FIELD', 'load_col', 'unload_col', 'unused_col', 'additional_feat_suffix', - 'max_user_inter_num', 'min_user_inter_num', 'max_item_inter_num', 'min_item_inter_num', - 'lowest_val', 'highest_val', 'equal_val', 'not_equal_val', - 'fields_in_same_space', + 'filter_inter_by_user_or_item', 'rm_dup_inter', + 'val_interval', 'user_inter_num_interval', 'item_inter_num_interval', + 'alias_of_user_id', 'alias_of_item_id', 'alias_of_entity_id', 'alias_of_relation_id', 'preload_weight', 'normalize_field', 'normalize_all' ] diff --git a/recbole/utils/case_study.py b/recbole/utils/case_study.py index bf96be0d4..d76f5b036 100644 --- a/recbole/utils/case_study.py +++ b/recbole/utils/case_study.py @@ -15,9 +15,6 @@ import numpy as np import torch -from recbole.data.dataloader.general_dataloader import GeneralFullDataLoader -from recbole.data.dataloader.sequential_dataloader import SequentialFullDataLoader - @torch.no_grad() def full_sort_scores(uid_series, model, test_data): @@ -38,21 +35,19 @@ def full_sort_scores(uid_series, model, test_data): dataset = test_data.dataset model.eval() - if isinstance(test_data, GeneralFullDataLoader): + if not test_data.is_sequential: index = np.isin(test_data.user_df[uid_field].numpy(), uid_series) input_interaction = test_data.user_df[index] history_item = test_data.uid2history_item[input_interaction[uid_field].numpy()] history_row = torch.cat([torch.full_like(hist_iid, i) for i, hist_iid in enumerate(history_item)]) history_col = torch.cat(list(history_item)) history_index = history_row, history_col - elif isinstance(test_data, SequentialFullDataLoader): + else: index = np.isin(test_data.uid_list, uid_series) input_interaction = test_data.augmentation( test_data.item_list_index[index], test_data.target_index[index], test_data.item_list_length[index] ) history_index = None - else: - raise NotImplementedError # Get scores of all items try: diff --git a/recbole/utils/logger.py b/recbole/utils/logger.py index f3fb17185..8f4f70355 100644 --- a/recbole/utils/logger.py +++ b/recbole/utils/logger.py @@ -16,6 +16,7 @@ import logging import os import colorlog +import re from recbole.utils.utils import get_local_time, ensure_dir from colorama import init @@ -28,6 +29,29 @@ } +class RemoveColorFilter(logging.Filter): + def filter(self, record): + if record: + ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') + record.msg = ansi_escape.sub('', str(record.msg)) + return True + + +def set_color(log, color, highlight=True): + color_set = ['black', 'red', 'green', 'yellow', 'blue', 'pink', 'cyan', 'white'] + try: + index = color_set.index(color) + except: + index = len(color_set) - 1 + prev_log = '\033[' + if highlight: + prev_log += '1;3' + else: + prev_log += '0;3' + prev_log += str(index) + 'm' + return prev_log + log + '\033[0m' + + def init_logger(config): """ A logger that can show a message on standard output and write it into the @@ -70,12 +94,15 @@ def init_logger(config): level = logging.CRITICAL else: level = logging.INFO + fh = logging.FileHandler(logfilepath) fh.setLevel(level) fh.setFormatter(fileformatter) + remove_color_filter = RemoveColorFilter() + 
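# Illustrative sketch (not part of the diff): `set_color` wraps a message in ANSI escapes and
# `RemoveColorFilter` strips them again before the message reaches the plain-text log file.
# Assumes set_color from recbole.utils.logger is in scope; the message is made up.
import re

colored = set_color('epoch 3', 'yellow')   # -> '\033[1;33mepoch 3\033[0m'
ansi_escape = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
plain = ansi_escape.sub('', colored)       # -> 'epoch 3', safe for the file handler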
fh.addFilter(remove_color_filter)

     sh = logging.StreamHandler()
     sh.setLevel(level)
     sh.setFormatter(sformatter)

-    logging.basicConfig(level=level, handlers=[fh, sh])
+    logging.basicConfig(level=level, handlers=[sh, fh])
diff --git a/recbole/utils/url.py b/recbole/utils/url.py
index 15ec4d333..73c190512 100644
--- a/recbole/utils/url.py
+++ b/recbole/utils/url.py
@@ -50,10 +50,10 @@ def download_url(url, folder):
     logger = getLogger()

     if osp.exists(path) and osp.getsize(path) > 0:  # pragma: no cover
-        logger.info('Using exist file', filename)
+        logger.info(f'Using existing file {filename}')
         return path

-    logger.info('Downloading', url)
+    logger.info(f'Downloading {url}')

     makedirs(folder)
     data = ur.urlopen(url)
@@ -90,7 +90,7 @@ def extract_zip(path, folder):
         folder (string): The folder.
     '''
     logger = getLogger()
-    logger.info('Extracting', path)
+    logger.info(f'Extracting {path}')
     with zipfile.ZipFile(path, 'r') as f:
         f.extractall(folder)
@@ -106,7 +106,8 @@ def rename_atomic_files(folder, old_name, new_name):
     files = os.listdir(folder)
     for f in files:
         base, suf = os.path.splitext(f)
-        assert base == old_name
+        if base != old_name:
+            continue
         assert suf in {'.inter', '.user', '.item'}
         os.rename(os.path.join(folder, f), os.path.join(folder, new_name + suf))
diff --git a/recbole/utils/utils.py b/recbole/utils/utils.py
index 6070f2e57..bdba6bf56 100644
--- a/recbole/utils/utils.py
+++ b/recbole/utils/utils.py
@@ -20,6 +20,7 @@
 import numpy as np
 import torch
+from torch.utils.tensorboard import SummaryWriter

 from recbole.utils.enum_type import ModelType
@@ -192,16 +193,28 @@ def init_seed(seed, reproducibility):
         torch.backends.cudnn.deterministic = False

-def set_color(log, color, highlight=True):
-    color_set = ['black', 'red', 'green', 'yellow', 'blue', 'pink', 'cyan', 'white']
-    try:
-        index = color_set.index(color)
-    except:
-        index = len(color_set) - 1
-    prev_log = '\033['
-    if highlight:
-        prev_log += '1;3'
-    else:
-        prev_log += '0;3'
-    prev_log += str(index) + 'm'
-    return prev_log + log + '\033[0m'
+def get_tensorboard(logger):
+    r"""Creates a TensorBoard SummaryWriter that can log PyTorch models and metrics into a directory
+    for visualization within the TensorBoard UI.
+    For the user's convenience, the SummaryWriter's log_dir follows the same naming rule as the logger.
+
+    Args:
+        logger: its output filename is used to name the SummaryWriter's log_dir.
+            If the filename is not available, we will name the log_dir according to the current time.
+
+    Returns:
+        SummaryWriter: it will write out events and summaries to the event file.
+ """ + base_path = 'log_tensorboard' + + dir_name = None + for handler in logger.handlers: + if hasattr(handler, "baseFilename"): + dir_name = os.path.basename(getattr(handler, 'baseFilename')).split('.')[0] + break + if dir_name is None: + dir_name = '{}-{}'.format('model', get_local_time()) + + dir_path = os.path.join(base_path, dir_name) + writer = SummaryWriter(dir_path) + return writer \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 9cf90921e..eccaef92e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -matplotlib>=3.1.3 torch>=1.7.0 numpy>=1.17.2 scipy==1.6.0 @@ -8,4 +7,5 @@ tqdm>=4.48.2 scikit_learn>=0.23.2 pyyaml>=5.1.0 colorlog==4.7.2 -colorama==0.4.4 \ No newline at end of file +colorama==0.4.4 +tensorboard>=2.5.0 \ No newline at end of file diff --git a/run_test.sh b/run_test.sh index 6d70c5171..e035d878e 100644 --- a/run_test.sh +++ b/run_test.sh @@ -7,7 +7,7 @@ echo "metrics tests finished" python -m pytest -v tests/config/test_config.py python -m pytest -v tests/config/test_overall.py export PYTHONPATH=. -python tests/config/test_command_line.py --use_gpu=False --valid_metric=Recall@10 --split_ratio=[0.7,0.2,0.1] --metrics=['Recall@10'] --epochs=200 --eval_setting='LO_RS' --learning_rate=0.3 +python tests/config/test_command_line.py --use_gpu=False --valid_metric=Recall@10 --metrics=['Recall@10'] --epochs=200 --learning_rate=0.3 echo "config tests finished" python -m pytest -v tests/evaluation_setting diff --git a/setup.py b/setup.py index 9e7b017ed..f99a82b63 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ install_requires = ['numpy>=1.17.2', 'torch>=1.7.0', 'scipy==1.6.0', 'pandas>=1.0.5', 'tqdm>=4.48.2', 'colorlog==4.7.2','colorama==0.4.4', - 'scikit_learn>=0.23.2', 'pyyaml>=5.1.0', 'matplotlib>=3.1.3'] + 'scikit_learn>=0.23.2', 'pyyaml>=5.1.0', 'tensorboard>=2.5.0'] setup_requires = [] diff --git a/tests/config/test_command_line.py b/tests/config/test_command_line.py index 627ba95f1..c4ea6ac01 100644 --- a/tests/config/test_command_line.py +++ b/tests/config/test_command_line.py @@ -2,6 +2,11 @@ # @Author : Shanlei Mu # @Email : slmu@ruc.edu.cn +# UPDATE: +# @Time : 2021/7/1 +# @Author : Xingyu Pan +# @Email : xy_pan@foxmail.com + from recbole.config import Config @@ -22,12 +27,10 @@ # command line assert config['use_gpu'] == False assert config['valid_metric'] == 'Recall@10' - assert config['split_ratio'] == [0.7, 0.2, 0.1] # assert config['metrics'] == ['Recall@10'] # bug # priority assert config['epochs'] == 200 - assert config['eval_setting'] == 'LO_RS' assert config['learning_rate'] == 0.3 print('------------------------------------------------------------') diff --git a/tests/config/test_config.py b/tests/config/test_config.py index 681eb1355..865fb5894 100644 --- a/tests/config/test_config.py +++ b/tests/config/test_config.py @@ -2,6 +2,10 @@ # @Author : Shanlei Mu # @Email : slmu@ruc.edu.cn +# UPDATE: +# @Time : 2021/7/1 +# @Author : Xingyu Pan +# @Email : xy_pan@foxmail.com import os import unittest @@ -43,11 +47,7 @@ def test_default_settings(self): self.assertIsInstance(config['stopping_step'], int) self.assertIsInstance(config['checkpoint_dir'], str) - self.assertIsInstance(config['eval_setting'], str) - self.assertIsInstance(config['group_by_user'], bool) - self.assertIsInstance(config['split_ratio'], list) - self.assertIsInstance(config['leave_one_num'], int) - self.assertIsInstance(config['real_time_process'], bool) + self.assertIsInstance(config['eval_args'], dict) 
         self.assertIsInstance(config['metrics'], list)
         self.assertIsInstance(config['topk'], list)
         self.assertIsInstance(config['valid_metric'], str)
@@ -56,8 +56,11 @@ def test_default_context_settings(self):
         config = Config(model='FM', dataset='ml-100k')
 
-        self.assertEqual(config['eval_setting'], 'RO_RS')
-        self.assertEqual(config['group_by_user'], False)
+        self.assertEqual(config['eval_args']['split'], {'RS': [0.8, 0.1, 0.1]})
+        self.assertEqual(config['eval_args']['order'], 'RO')
+        self.assertEqual(config['eval_args']['mode'], 'none')
+        self.assertEqual(config['eval_args']['group_by'], 'none')
+
         self.assertEqual(config['metrics'], ['AUC', 'LogLoss'])
         self.assertEqual(config['valid_metric'], 'AUC')
         self.assertEqual(config['training_neg_sample_num'], 0)
@@ -67,15 +70,21 @@ def test_default_sequential_settings(self):
             'training_neg_sample_num': 0
         }
         config = Config(model='SASRec', dataset='ml-100k', config_dict=para_dict)
-        self.assertEqual(config['eval_setting'], 'TO_LS,full')
-
+        self.assertEqual(config['eval_args']['split'], {'LS': 2})
+        self.assertEqual(config['eval_args']['order'], 'TO')
+        self.assertEqual(config['eval_args']['mode'], 'full')
+        self.assertEqual(config['eval_args']['group_by'], 'user')
+
     def test_config_file_list(self):
         config = Config(model='BPR', dataset='ml-100k', config_file_list=config_file_list)
 
         self.assertEqual(config['model'], 'BPR')
         self.assertEqual(config['learning_rate'], 0.1)
         self.assertEqual(config['topk'], [5, 20])
-        self.assertEqual(config['eval_setting'], 'TO_LS,full')
+        self.assertEqual(config['eval_args']['split'], {'LS': 2})
+        self.assertEqual(config['eval_args']['order'], 'TO')
+        self.assertEqual(config['eval_args']['mode'], 'full')
+        self.assertEqual(config['eval_args']['group_by'], 'user')
 
     def test_config_dict(self):
         config = Config(model='BPR', dataset='ml-100k', config_dict=parameters_dict)
@@ -83,7 +92,10 @@ def test_config_dict(self):
         self.assertEqual(config['model'], 'BPR')
         self.assertEqual(config['learning_rate'], 0.2)
         self.assertEqual(config['topk'], [50, 100])
-        self.assertEqual(config['eval_setting'], 'RO_RS,full')
+        self.assertEqual(config['eval_args']['split'], {'RS': [0.8, 0.1, 0.1]})
+        self.assertEqual(config['eval_args']['order'], 'RO')
+        self.assertEqual(config['eval_args']['mode'], 'full')
+        self.assertEqual(config['eval_args']['group_by'], 'user')
 
     # todo: add command line test examples
     def test_priority(self):
@@ -92,7 +104,10 @@ def test_priority(self):
 
         self.assertEqual(config['learning_rate'], 0.2)  # default, file, dict
         self.assertEqual(config['topk'], [50, 100])  # default, file, dict
-        self.assertEqual(config['eval_setting'], 'TO_LS,full')  # default, file
+        self.assertEqual(config['eval_args']['split'], {'LS': 2})
+        self.assertEqual(config['eval_args']['order'], 'TO')
+        self.assertEqual(config['eval_args']['mode'], 'full')
+        self.assertEqual(config['eval_args']['group_by'], 'user')
 
         self.assertEqual(config['epochs'], 100)  # default, dict
diff --git a/tests/config/test_config_example.yaml b/tests/config/test_config_example.yaml
index 741d51ad7..febce74e5 100644
--- a/tests/config/test_config_example.yaml
+++ b/tests/config/test_config_example.yaml
@@ -1,4 +1,7 @@
 model: FM
 learning_rate: 0.1
 topk: [5, 20]
-eval_setting: 'TO_LS,full'
\ No newline at end of file
+eval_args:
+    split: {'LS': 2}
+    mode: full
+    order: TO
\ No newline at end of file
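
As a reading aid for the migrations above: each old compact eval_setting string decomposes into the structured eval_args dict. A sketch of the correspondence these tests rely on (the [0.8, 0.1, 0.1] ratio shown for 'RS' is the default):

    # 'TO_LS,full'  -> temporal order, leave-one-out split, full ranking over all items
    config_dict = {'eval_args': {'split': {'LS': 2}, 'order': 'TO', 'mode': 'full', 'group_by': 'user'}}

    # 'RO_RS,uni100' -> random order, ratio split, each positive ranked against 100 sampled negatives
    config_dict = {'eval_args': {'split': {'RS': [0.8, 0.1, 0.1]}, 'order': 'RO', 'mode': 'uni100'}}
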
diff --git a/tests/config/test_overall.py b/tests/config/test_overall.py
index 78b820cca..2ebcc345d 100644
--- a/tests/config/test_overall.py
+++ b/tests/config/test_overall.py
@@ -4,9 +4,10 @@
 # @email  : tsotfsk@outlook.com
 
 # UPDATE:
-# @Time   : 2020/11/17
+# @Time   : 2021/7/1
 # @Author : Xingyu Pan
-# @Email  : panxy@ruc.edu.cn
+# @Email  : xy_pan@foxmail.com
+
 import os
 import sys
 import unittest
@@ -22,7 +23,7 @@ def run_parms(parm_dict, extra_dict=None):
     config_dict = {
         'epochs': 1,
-        'state': 'CRITICAL'
+        'state': 'INFO'
     }
     for name, parms in parm_dict.items():
         for parm in parms:
@@ -91,9 +92,6 @@ def test_checkpoint_dir(self):
     def test_eval_batch_size(self):
         self.assertTrue(run_parms({'eval_batch_size': [1, 100]}))
 
-    def test_real_time_process(self):
-        self.assertTrue(run_parms({'real_time_process': [False, True]}))
-
     def test_topk(self):
         settings = {
             'metrics': ["Recall", "MRR", "NDCG", "Hit", "Precision"],
@@ -105,7 +103,7 @@ def test_loss(self):
         settings = {
             'metrics': ["MAE", "RMSE", "LOGLOSS", "AUC"],
             'valid_metric': 'auc',
-            'eval_setting': 'RO_RS, uni100'
+            'eval_args': {'split': {'RS': [0.8, 0.1, 0.1]}, 'order': 'RO', 'mode': 'uni100'}
         }
         self.assertTrue(run_parms({'topk': {None, 1}}, extra_dict=settings))
 
@@ -117,22 +115,17 @@ def test_metric(self):
         self.assertTrue(run_parms({'metrics': ["Recall", ["Recall", "MRR", "NDCG", "Hit", "Precision"]]}, extra_dict=settings))
 
     def test_split_ratio(self):
-        settings = {
-            'leave_one_num': None
-        }
-
-        self.assertTrue(run_parms({'split_ratio': [  # [0.8, 0.2],
-            [0.8, 0.1, 0.1], [16, 2, 2]]}))
+        self.assertTrue(run_parms({'eval_args': [{'split': {'RS': [0.8, 0.1, 0.1]}}, {'split': {'RS': [16, 2, 2]}}]}))
 
-    def test_leave_one_num(self):
-        settings = {
-            'split_ratio': None
-        }
-
-        self.assertTrue(run_parms({'leave_one_num': [1, 2, 3]}))
+    # def test_leave_one_num(self):
+    #     settings = {
+    #         'split_ratio': None
+    #     }
+
+    #     self.assertTrue(run_parms({'leave_one_num': [1, 2, 3]}))
 
     def test_group_by_user(self):
-        self.assertTrue(run_parms({'group_by_user': [True, False]}))
+        self.assertTrue(run_parms({'eval_args': [{'group_by': 'user'}, {'group_by': 'none'}]}))
diff --git a/tests/data/kg_remap_id/kg_remap_id.inter b/tests/data/kg_remap_id/kg_remap_id.inter
new file mode 100644
index 000000000..f8ab8a690
--- /dev/null
+++ b/tests/data/kg_remap_id/kg_remap_id.inter
@@ -0,0 +1,5 @@
+user_id:token	item_id:token
+ua	ia
+ub	ib
+uc	ic
+ud	id
\ No newline at end of file
diff --git a/tests/data/kg_remap_id/kg_remap_id.kg b/tests/data/kg_remap_id/kg_remap_id.kg
new file mode 100644
index 000000000..95f6d4189
--- /dev/null
+++ b/tests/data/kg_remap_id/kg_remap_id.kg
@@ -0,0 +1,5 @@
+head_id:token	tail_id:token	relation_id:token
+eb	ea	ra
+ec	eb	rb
+ed	ec	rc
+ee	ed	rd
\ No newline at end of file
diff --git a/tests/data/kg_remap_id/kg_remap_id.link b/tests/data/kg_remap_id/kg_remap_id.link
new file mode 100644
index 000000000..3328eccb5
--- /dev/null
+++ b/tests/data/kg_remap_id/kg_remap_id.link
@@ -0,0 +1,5 @@
+item_id:token	entity_id:token
+ib	eb
+ic	ec
+id	ed
+ie	ee
\ No newline at end of file
diff --git a/tests/data/seq_benchmark/seq_benchmark.test.inter b/tests/data/seq_benchmark/seq_benchmark.test.inter
new file mode 100644
index 000000000..e58c2927d
--- /dev/null
+++ b/tests/data/seq_benchmark/seq_benchmark.test.inter
@@ -0,0 +1,5 @@
+user_id:token	item_id:token	timestamp:float	item_id_list:token_seq	timestamp_list:float_seq
+1	7	7	1 2 3 4 5 6	1 2 3 4 5 6
+1	8	8	1 2 3 4 5 6 7	1 2 3 4 5 6 7
+2	8	8	4 5 6 7	4 5 6 7
+3	6	6	4 5	4 5
diff --git a/tests/data/seq_benchmark/seq_benchmark.train.inter b/tests/data/seq_benchmark/seq_benchmark.train.inter
new file mode 100644
index 000000000..24b14ddd8
--- /dev/null
+++ b/tests/data/seq_benchmark/seq_benchmark.train.inter
@@ -0,0 +1,8 @@
+user_id:token	item_id:token	timestamp:float	item_id_list:token_seq	timestamp_list:float_seq
+1	2	2	1	1
+1	3	3	1 2	1 2
+1	4	4	1 2 3	1 2 3
+4	4	4	3	3
+2	5	5	4	4
+2	6	6	4 5	4 5
+3	5	5	4	4
diff --git a/tests/data/seq_benchmark/seq_benchmark.valid.inter b/tests/data/seq_benchmark/seq_benchmark.valid.inter
new file mode 100644
index 000000000..ce15c506c
--- /dev/null
+++ b/tests/data/seq_benchmark/seq_benchmark.valid.inter
@@ -0,0 +1,4 @@
+user_id:token	item_id:token	timestamp:float	item_id_list:token_seq	timestamp_list:float_seq
+1	5	5	1 2 3 4	1 2 3 4
+1	6	6	1 2 3 4 5	1 2 3 4 5
+2	7	7	4 5 6	4 5 6
diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py
index e79e86030..6f82b357f 100644
--- a/tests/data/test_dataloader.py
+++ b/tests/data/test_dataloader.py
@@ -4,9 +4,9 @@
 # @Email  : chenyushuo@ruc.edu.cn
 
 # UPDATE
-# @Time   : 2020/1/5
-# @Author : Yushuo Chen
-# @email  : chenyushuo@ruc.edu.cn
+# @Time   : 2020/1/5, 2021/7/1
+# @Author : Yushuo Chen, Xingyu Pan
+# @email  : chenyushuo@ruc.edu.cn, xy_pan@foxmail.com
 
 import logging
 import os
@@ -37,9 +37,8 @@ def test_general_dataloader(self):
             'dataset': 'general_dataloader',
             'data_path': current_path,
             'load_col': None,
-            'eval_setting': 'TO_RS',
+            'eval_args': {'split': {'RS': [0.8, 0.1, 0.1]}, 'order': 'TO', 'mode': 'none'},
             'training_neg_sample_num': 0,
-            'split_ratio': [0.8, 0.1, 0.1],
             'train_batch_size': train_batch_size,
             'eval_batch_size': eval_batch_size,
         }
@@ -54,8 +53,8 @@ def check_dataloader(data, item_list, batch_size):
                 pr += batch_size
 
         check_dataloader(train_data, list(range(1, 41)), train_batch_size)
-        check_dataloader(valid_data, list(range(41, 46)), eval_batch_size)
-        check_dataloader(test_data, list(range(46, 51)), eval_batch_size)
+        check_dataloader(valid_data, list(range(41, 46)), max(eval_batch_size, 5))
+        check_dataloader(test_data, list(range(46, 51)), max(eval_batch_size, 5))
 
     def test_general_neg_sample_dataloader_in_pair_wise(self):
         train_batch_size = 6
@@ -65,9 +64,8 @@
             'dataset': 'general_dataloader',
             'data_path': current_path,
             'load_col': None,
-            'eval_setting': 'TO_RS,full',
             'training_neg_sample_num': 1,
-            'split_ratio': [0.8, 0.1, 0.1],
+            'eval_args': {'split': {'RS': [0.8, 0.1, 0.1]}, 'order': 'TO', 'mode': 'full'},
             'train_batch_size': train_batch_size,
             'eval_batch_size': eval_batch_size,
         }
@@ -93,9 +91,8 @@ def test_general_neg_sample_dataloader_in_point_wise(self):
             'dataset': 'general_dataloader',
            'data_path': current_path,
             'load_col': None,
-            'eval_setting': 'TO_RS,full',
             'training_neg_sample_num': 1,
-            'split_ratio': [0.8, 0.1, 0.1],
+            'eval_args': {'split': {'RS': [0.8, 0.1, 0.1]}, 'order': 'TO', 'mode': 'full'},
             'train_batch_size': train_batch_size,
             'eval_batch_size': eval_batch_size,
         }
@@ -121,9 +118,8 @@ def test_general_full_dataloader(self):
             'dataset': 'general_full_dataloader',
             'data_path': current_path,
             'load_col': None,
-            'eval_setting': 'TO_RS,full',
             'training_neg_sample_num': 1,
-            'split_ratio': [0.8, 0.1, 0.1],
+            'eval_args': {'split': {'RS': [0.8, 0.1, 0.1]}, 'order': 'TO', 'mode': 'full'},
             'train_batch_size': train_batch_size,
             'eval_batch_size': eval_batch_size,
         }
@@ -240,9 +236,8 @@ def test_general_uni100_dataloader_with_batch_size_in_101(self):
             'dataset': 'general_uni100_dataloader',
             'data_path': current_path,
             'load_col': None,
-            'eval_setting': 'TO_RS,uni100',
             'training_neg_sample_num': 1,
-            'split_ratio': [0.8, 0.1, 0.1],
+            'eval_args': {'split': {'RS': [0.8, 0.1, 0.1]}, 'order': 'TO', 'mode': 'uni100'},
             'train_batch_size': train_batch_size,
             'eval_batch_size': eval_batch_size,
         }
@@ -312,9 +307,8 @@ def test_general_uni100_dataloader_with_batch_size_in_303(self):
             'dataset': 'general_uni100_dataloader',
             'data_path': current_path,
             'load_col': None,
-            'eval_setting': 'TO_RS,uni100',
             'training_neg_sample_num': 1,
-            'split_ratio': [0.8, 0.1, 0.1],
+            'eval_args': {'split': {'RS': [0.8, 0.1, 0.1]}, 'order': 'TO', 'mode': 'uni100'},
             'train_batch_size': train_batch_size,
             'eval_batch_size': eval_batch_size,
         }
diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py
index 3550f00d6..0b1a2830e 100644
--- a/tests/data/test_dataset.py
+++ b/tests/data/test_dataset.py
@@ -4,16 +4,16 @@
 # @Email  : chenyushuo@ruc.edu.cn
 
 # UPDATE
-# @Time   : 2020/1/3
-# @Author : Yushuo Chen
-# @email  : chenyushuo@ruc.edu.cn
+# @Time   : 2020/1/3, 2021/7/1, 2021/7/11
+# @Author : Yushuo Chen, Xingyu Pan, Yupeng Hou
+# @email  : chenyushuo@ruc.edu.cn, xy_pan@foxmail.com, houyupeng@ruc.edu.cn
 
 import logging
 import os
 
 import pytest
 
-from recbole.config import Config, EvalSetting
+from recbole.config import Config
 from recbole.data import create_dataset
 from recbole.utils import init_seed
 
@@ -29,11 +29,7 @@ def new_dataset(config_dict=None, config_file_list=None):
 
 def split_dataset(config_dict=None, config_file_list=None):
     dataset = new_dataset(config_dict=config_dict, config_file_list=config_file_list)
-    config = dataset.config
-    es_str = [_.strip() for _ in config['eval_setting'].split(',')]
-    es = EvalSetting(config)
-    es.set_ordering_and_splitting(es_str[0])
-    return dataset.build(es)
+    return dataset.build()
 
 
 class TestDataset:
@@ -77,8 +73,8 @@ def test_filter_by_field_value_with_lowest_val(self):
             'dataset': 'filter_by_field_value',
             'data_path': current_path,
             'load_col': None,
-            'lowest_val': {
-                'timestamp': 4,
+            'val_interval': {
+                'timestamp': "[4,inf)",
             },
         }
         dataset = new_dataset(config_dict=config_dict)
@@ -90,8 +86,8 @@ def test_filter_by_field_value_with_highest_val(self):
             'dataset': 'filter_by_field_value',
             'data_path': current_path,
             'load_col': None,
-            'highest_val': {
-                'timestamp': 4,
+            'val_interval': {
+                'timestamp': "(-inf,4]",
             },
         }
         dataset = new_dataset(config_dict=config_dict)
@@ -103,8 +99,8 @@ def test_filter_by_field_value_with_equal_val(self):
             'dataset': 'filter_by_field_value',
             'data_path': current_path,
             'load_col': None,
-            'equal_val': {
-                'rating': 0,
+            'val_interval': {
+                'rating': "[0,0]",
             },
         }
         dataset = new_dataset(config_dict=config_dict)
@@ -116,8 +112,8 @@ def test_filter_by_field_value_with_not_equal_val(self):
             'dataset': 'filter_by_field_value',
             'data_path': current_path,
             'load_col': None,
-            'not_equal_val': {
-                'rating': 4,
+            'val_interval': {
+                'rating': "(-inf,4);(4,inf)",
             },
         }
         dataset = new_dataset(config_dict=config_dict)
@@ -129,11 +125,8 @@ def test_filter_by_field_value_in_same_field(self):
             'dataset': 'filter_by_field_value',
             'data_path': current_path,
             'load_col': None,
-            'lowest_val': {
-                'timestamp': 3,
-            },
-            'highest_val': {
-                'timestamp': 8,
+            'val_interval': {
+                'timestamp': "[3,8]",
             },
         }
         dataset = new_dataset(config_dict=config_dict)
@@ -145,14 +138,9 @@ def test_filter_by_field_value_in_different_field(self):
             'dataset': 'filter_by_field_value',
             'data_path': current_path,
             'load_col': None,
-            'lowest_val': {
-                'timestamp': 3,
-            },
-            'highest_val': {
-                'timestamp': 8,
-            },
-            'not_equal_val': {
-                'rating': 4,
+            'val_interval': {
+                'timestamp': "[3,8]",
+                'rating': "(-inf,4);(4,inf)",
             }
         }
         dataset = new_dataset(config_dict=config_dict)
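
The val_interval strings above use mathematical interval notation, with multiple intervals joined by ';'. A rough sketch of the retention rule these tests rely on (illustrative only, not RecBole's actual parser):

    import re

    def in_interval(value, interval_str):
        # True if value lies in any ';'-separated interval, e.g. '[3,8]' or '(-inf,4);(4,inf)'
        for part in interval_str.split(';'):
            left, lo, hi, right = re.match(
                r'([\[(])\s*(-inf|-?[\d.]+)\s*,\s*(inf|-?[\d.]+)\s*([\])])', part).groups()
            lo, hi = float(lo), float(hi)
            if (value >= lo if left == '[' else value > lo) and \
               (value <= hi if right == ']' else value < hi):
                return True
        return False

    assert in_interval(4, '[4,inf)')
    assert not in_interval(4, '(-inf,4);(4,inf)')
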
@@ -186,7 +174,7 @@ def test_filter_by_inter_num_in_min_user_inter_num(self):
             'dataset': 'filter_by_inter_num',
             'data_path': current_path,
             'load_col': None,
-            'min_user_inter_num': 2,
+            'user_inter_num_interval': "[2,inf)",
         }
         dataset = new_dataset(config_dict=config_dict)
         assert dataset.user_num == 6
@@ -198,7 +186,7 @@ def test_filter_by_inter_num_in_min_item_inter_num(self):
             'dataset': 'filter_by_inter_num',
             'data_path': current_path,
             'load_col': None,
-            'min_item_inter_num': 2,
+            'item_inter_num_interval': "[2,inf)",
         }
         dataset = new_dataset(config_dict=config_dict)
         assert dataset.user_num == 7
@@ -210,7 +198,7 @@ def test_filter_by_inter_num_in_max_user_inter_num(self):
             'dataset': 'filter_by_inter_num',
             'data_path': current_path,
             'load_col': None,
-            'max_user_inter_num': 2,
+            'user_inter_num_interval': "(-inf,2]",
         }
         dataset = new_dataset(config_dict=config_dict)
         assert dataset.user_num == 6
@@ -222,7 +210,7 @@ def test_filter_by_inter_num_in_max_item_inter_num(self):
             'dataset': 'filter_by_inter_num',
             'data_path': current_path,
             'load_col': None,
-            'max_item_inter_num': 2,
+            'item_inter_num_interval': "(-inf,2]",
         }
         dataset = new_dataset(config_dict=config_dict)
         assert dataset.user_num == 5
@@ -234,8 +222,8 @@ def test_filter_by_inter_num_in_min_inter_num(self):
             'dataset': 'filter_by_inter_num',
             'data_path': current_path,
             'load_col': None,
-            'min_user_inter_num': 2,
-            'min_item_inter_num': 2,
+            'user_inter_num_interval': "[2,inf)",
+            'item_inter_num_interval': "[2,inf)",
         }
         dataset = new_dataset(config_dict=config_dict)
         assert dataset.user_num == 5
@@ -247,9 +235,8 @@ def test_filter_by_inter_num_in_complex_way(self):
             'dataset': 'filter_by_inter_num',
             'data_path': current_path,
             'load_col': None,
-            'max_user_inter_num': 3,
-            'min_user_inter_num': 2,
-            'min_item_inter_num': 2,
+            'user_inter_num_interval': "[2,3]",
+            'item_inter_num_interval': "[2,inf)",
         }
         dataset = new_dataset(config_dict=config_dict)
         assert dataset.user_num == 3
@@ -262,8 +249,8 @@ def test_rm_dup_by_first_and_filter_value(self):
             'data_path': current_path,
             'load_col': None,
             'rm_dup_inter': 'first',
-            'highest_val': {
-                'rating': 4,
+            'val_interval': {
+                'rating': "(-inf,4]",
             },
         }
         dataset = new_dataset(config_dict=config_dict)
@@ -276,8 +263,8 @@ def test_rm_dup_by_last_and_filter_value(self):
             'data_path': current_path,
             'load_col': None,
             'rm_dup_inter': 'last',
-            'highest_val': {
-                'rating': 4,
+            'val_interval': {
+                'rating': "(-inf,4]",
             },
         }
         dataset = new_dataset(config_dict=config_dict)
@@ -290,8 +277,8 @@ def test_rm_dup_and_filter_by_inter_num(self):
             'data_path': current_path,
             'load_col': None,
             'rm_dup_inter': 'first',
-            'min_user_inter_num': 2,
-            'min_item_inter_num': 2,
+            'user_inter_num_interval': "[2,inf)",
+            'item_inter_num_interval': "[2,inf)",
         }
         dataset = new_dataset(config_dict=config_dict)
         assert len(dataset.inter_feat) == 4
@@ -304,11 +291,9 @@ def test_filter_value_and_filter_inter_by_ui(self):
             'dataset': 'filter_value_and_filter_inter_by_ui',
             'data_path': current_path,
             'load_col': None,
-            'highest_val': {
-                'age': 2,
-            },
-            'not_equal_val': {
-                'price': 2,
+            'val_interval': {
+                'age': "(-inf,2]",
+                'price': "(-inf,2);(2,inf)",
             },
             'filter_inter_by_user_or_item': True,
         }
@@ -323,13 +308,13 @@ def test_filter_value_and_inter_num(self):
             'dataset': 'filter_value_and_inter_num',
             'data_path': current_path,
             'load_col': None,
-            'highest_val': {
-                'rating': 0,
-                'age': 0,
-                'price': 0,
+            'val_interval': {
+                'rating': "(-inf,0]",
+                'age': "(-inf,0]",
+                'price': "(-inf,0]",
             },
-            'min_user_inter_num': 2,
-            'min_item_inter_num': 2,
+            'user_inter_num_interval': "[2,inf)",
+            'item_inter_num_interval': "[2,inf)",
         }
         dataset = new_dataset(config_dict=config_dict)
         assert len(dataset.inter_feat) == 4
@@ -343,8 +328,8 @@ def test_filter_inter_by_ui_and_inter_num(self):
             'data_path': current_path,
             'load_col': None,
             'filter_inter_by_user_or_item': True,
-            'min_user_inter_num': 2,
-            'min_item_inter_num': 2,
+            'user_inter_num_interval': "[2,inf)",
+            'item_inter_num_interval': "[2,inf)",
         }
         dataset = new_dataset(config_dict=config_dict)
         assert len(dataset.inter_feat) == 4
@@ -357,7 +342,6 @@ def test_remap_id(self):
             'dataset': 'remap_id',
             'data_path': current_path,
             'load_col': None,
-            'fields_in_same_space': None,
         }
         dataset = new_dataset(config_dict=config_dict)
         user_list = dataset.token2id('user_id', ['ua', 'ub', 'uc', 'ud'])
@@ -373,16 +357,14 @@ def test_remap_id(self):
         assert (dataset.inter_feat['user_list'][2] == [3, 4, 1]).all()
         assert (dataset.inter_feat['user_list'][3] == [5]).all()
 
-    def test_remap_id_with_fields_in_same_space(self):
+    def test_remap_id_with_alias(self):
         config_dict = {
             'model': 'BPR',
             'dataset': 'remap_id',
             'data_path': current_path,
             'load_col': None,
-            'fields_in_same_space': [
-                ['user_id', 'add_user', 'user_list'],
-                ['item_id', 'add_item'],
-            ],
+            'alias_of_user_id': ['add_user', 'user_list'],
+            'alias_of_item_id': ['add_item'],
         }
         dataset = new_dataset(config_dict=config_dict)
         user_list = dataset.token2id('user_id', ['ua', 'ub', 'uc', 'ud', 'ue', 'uf'])
@@ -473,8 +455,7 @@ def test_TO_RS_811(self):
             'dataset': 'build_dataset',
             'data_path': current_path,
             'load_col': None,
-            'eval_setting': 'TO_RS',
-            'split_ratio': [0.8, 0.1, 0.1],
+            'eval_args': {'split': {'RS': [0.8, 0.1, 0.1]}, 'order': 'TO', 'mode': 'none'}
         }
         train_dataset, valid_dataset, test_dataset = split_dataset(config_dict=config_dict)
         assert (train_dataset.inter_feat['item_id'].numpy() == list(range(1, 17)) + [1] + [1] + [1] + [1, 2, 3] +
@@ -490,8 +471,7 @@ def test_TO_RS_820(self):
             'dataset': 'build_dataset',
             'data_path': current_path,
             'load_col': None,
-            'eval_setting': 'TO_RS',
-            'split_ratio': [0.8, 0.2, 0.0],
+            'eval_args': {'split': {'RS': [0.8, 0.2, 0.0]}, 'order': 'TO', 'mode': 'none'}
         }
         train_dataset, valid_dataset, test_dataset = split_dataset(config_dict=config_dict)
         assert (train_dataset.inter_feat['item_id'].numpy() == list(range(1, 17)) + [1] + [1] + [1, 2] + [1, 2, 3, 4] +
@@ -506,8 +486,7 @@ def test_TO_RS_802(self):
             'dataset': 'build_dataset',
             'data_path': current_path,
             'load_col': None,
-            'eval_setting': 'TO_RS',
-            'split_ratio': [0.8, 0.0, 0.2],
+            'eval_args': {'split': {'RS': [0.8, 0.0, 0.2]}, 'order': 'TO', 'mode': 'none'}
         }
         train_dataset, valid_dataset, test_dataset = split_dataset(config_dict=config_dict)
         assert (train_dataset.inter_feat['item_id'].numpy() == list(range(1, 17)) + [1] + [1] + [1, 2] + [1, 2, 3, 4] +
@@ -522,8 +501,7 @@ def test_TO_LS(self):
             'dataset': 'build_dataset',
             'data_path': current_path,
             'load_col': None,
-            'eval_setting': 'TO_LS',
-            'leave_one_num': 2,
+            'eval_args': {'split': {'LS': 2}, 'order': 'TO', 'mode': 'none'}
         }
         train_dataset, valid_dataset, test_dataset = split_dataset(config_dict=config_dict)
         assert (train_dataset.inter_feat['item_id'].numpy() == list(range(1, 19)) + [1] + [1] + [1] + [1, 2, 3] +
@@ -539,8 +517,7 @@ def test_RO_RS_811(self):
             'dataset': 'build_dataset',
             'data_path': current_path,
             'load_col': None,
-            'eval_setting': 'RO_RS',
-            'split_ratio': [0.8, 0.1, 0.1],
+            'eval_args': {'split': {'RS': [0.8, 0.1, 0.1]}, 'order': 'RO', 'mode': 'none'}
         }
         train_dataset, valid_dataset, test_dataset = split_dataset(config_dict=config_dict)
         assert len(train_dataset.inter_feat) == 16 + 1 + 1 + 1 + 3 + 7 + 8 + 9
@@ -553,8 +530,7 @@ def test_RO_RS_820(self):
             'dataset': 'build_dataset',
             'data_path': current_path,
             'load_col': None,
-            'eval_setting': 'RO_RS',
-            'split_ratio': [0.8, 0.2, 0.0],
+            'eval_args': {'split': {'RS': [0.8, 0.2, 0.0]}, 'order': 'RO', 'mode': 'none'}
         }
         train_dataset, valid_dataset, test_dataset = split_dataset(config_dict=config_dict)
         assert len(train_dataset.inter_feat) == 16 + 1 + 1 + 2 + 4 + 8 + 8 + 9
@@ -567,8 +543,7 @@ def test_RO_RS_802(self):
             'dataset': 'build_dataset',
             'data_path': current_path,
             'load_col': None,
-            'eval_setting': 'RO_RS',
-            'split_ratio': [0.8, 0.0, 0.2],
+            'eval_args': {'split': {'RS': [0.8, 0.0, 0.2]}, 'order': 'RO', 'mode': 'none'}
         }
         train_dataset, valid_dataset, test_dataset = split_dataset(config_dict=config_dict)
         assert len(train_dataset.inter_feat) == 16 + 1 + 1 + 2 + 4 + 8 + 8 + 9
@@ -586,21 +561,34 @@ def test_seq_leave_one_out(self):
             'training_neg_sample_num': 0
         }
         train_dataset, valid_dataset, test_dataset = split_dataset(config_dict=config_dict)
-        assert (train_dataset.uid_list == [1, 1, 1, 1, 1, 2, 2, 3, 4]).all()
-        assert (train_dataset.item_list_index == [slice(0, 1), slice(0, 2), slice(0, 3), slice(0, 4), slice(0, 5),
-                                                  slice(8, 9), slice(8, 10), slice(13, 14), slice(16, 17)]).all()
-        assert (train_dataset.target_index == [1, 2, 3, 4, 5, 9, 10, 14, 17]).all()
-        assert (train_dataset.item_list_length == [1, 2, 3, 4, 5, 1, 2, 1, 1]).all()
-
-        assert (valid_dataset.uid_list == [1, 2]).all()
-        assert (valid_dataset.item_list_index == [slice(0, 6), slice(8, 11)]).all()
-        assert (valid_dataset.target_index == [6, 11]).all()
-        assert (valid_dataset.item_list_length == [6, 3]).all()
-
-        assert (test_dataset.uid_list == [1, 2, 3]).all()
-        assert (test_dataset.item_list_index == [slice(0, 7), slice(8, 12), slice(13, 15)]).all()
-        assert (test_dataset.target_index == [7, 12, 15]).all()
-        assert (test_dataset.item_list_length == [7, 4, 2]).all()
+        assert (train_dataset.inter_feat[train_dataset.uid_field].numpy() == [1, 1, 1, 1, 1, 4, 2, 2, 3]).all()
+        assert (train_dataset.inter_feat[train_dataset.item_id_list_field][:, :5].numpy() == [
+            [1, 0, 0, 0, 0],
+            [1, 2, 0, 0, 0],
+            [1, 2, 3, 0, 0],
+            [1, 2, 3, 4, 0],
+            [1, 2, 3, 4, 5],
+            [3, 0, 0, 0, 0],
+            [4, 0, 0, 0, 0],
+            [4, 5, 0, 0, 0],
+            [4, 0, 0, 0, 0]]).all()
+        assert (train_dataset.inter_feat[train_dataset.iid_field].numpy() == [2, 3, 4, 5, 6, 4, 5, 6, 5]).all()
+        assert (train_dataset.inter_feat[train_dataset.item_list_length_field].numpy() == [1, 2, 3, 4, 5, 1, 1, 2, 1]).all()
+
+        assert (valid_dataset.inter_feat[valid_dataset.uid_field].numpy() == [1, 2]).all()
+        assert (valid_dataset.inter_feat[valid_dataset.item_id_list_field][:, :6].numpy() == [
+            [1, 2, 3, 4, 5, 6],
+            [4, 5, 6, 0, 0, 0]]).all()
+        assert (valid_dataset.inter_feat[valid_dataset.iid_field].numpy() == [7, 7]).all()
+        assert (valid_dataset.inter_feat[valid_dataset.item_list_length_field].numpy() == [6, 3]).all()
+
+        assert (test_dataset.inter_feat[test_dataset.uid_field].numpy() == [1, 2, 3]).all()
+        assert (test_dataset.inter_feat[test_dataset.item_id_list_field][:, :7].numpy() == [
+            [1, 2, 3, 4, 5, 6, 7],
+            [4, 5, 6, 7, 0, 0, 0],
+            [4, 5, 0, 0, 0, 0, 0]]).all()
+        assert (test_dataset.inter_feat[test_dataset.iid_field].numpy() == [8, 8, 6]).all()
+        assert (test_dataset.inter_feat[test_dataset.item_list_length_field].numpy() == [7, 4, 2]).all()
 
         assert (train_dataset.inter_matrix().toarray() == [
             [0., 0., 0., 0., 0., 0., 0., 0., 0.],
@@ -611,19 +599,121 @@ def test_seq_leave_one_out(self):
         ]).all()
         assert (valid_dataset.inter_matrix().toarray() == [
             [0., 0., 0., 0., 0., 0., 0., 0., 0.],
-            [0., 1., 1., 1., 1., 1., 1., 1., 0.],
-            [0., 0., 0., 0., 1., 1., 1., 1., 0.],
-            [0., 0., 0., 0., 1., 1., 0., 0., 0.],
-            [0., 0., 0., 1., 1., 0., 0., 0., 0.],
+            [0., 0., 0., 0., 0., 0., 0., 1., 0.],
+            [0., 0., 0., 0., 0., 0., 0., 1., 0.],
+            [0., 0., 0., 0., 0., 0., 0., 0., 0.],
+            [0., 0., 0., 0., 0., 0., 0., 0., 0.]
         ]).all()
         assert (test_dataset.inter_matrix().toarray() == [
             [0., 0., 0., 0., 0., 0., 0., 0., 0.],
-            [0., 1., 1., 1., 1., 1., 1., 1., 1.],
-            [0., 0., 0., 0., 1., 1., 1., 1., 1.],
-            [0., 0., 0., 0., 1., 1., 1., 0., 0.],
-            [0., 0., 0., 1., 1., 0., 0., 0., 0.],
+            [0., 0., 0., 0., 0., 0., 0., 0., 1.],
+            [0., 0., 0., 0., 0., 0., 0., 0., 1.],
+            [0., 0., 0., 0., 0., 0., 1., 0., 0.],
+            [0., 0., 0., 0., 0., 0., 0., 0., 0.]
         ]).all()
 
+    def test_seq_split_by_ratio(self):
+        config_dict = {
+            'model': 'GRU4Rec',
+            'dataset': 'seq_dataset',
+            'data_path': current_path,
+            'load_col': None,
+            'training_neg_sample_num': 0,
+            'eval_args': {
+                'split': {'RS': [0.3, 0.3, 0.4]},
+                'order': 'TO'
+            }
+        }
+        train_dataset, valid_dataset, test_dataset = split_dataset(config_dict=config_dict)
+        assert (train_dataset.inter_feat[train_dataset.uid_field].numpy() == [1, 1, 1, 4, 2, 2, 3]).all()
+        assert (train_dataset.inter_feat[train_dataset.item_id_list_field][:, :3].numpy() == [
+            [1, 0, 0],
+            [1, 2, 0],
+            [1, 2, 3],
+            [3, 0, 0],
+            [4, 0, 0],
+            [4, 5, 0],
+            [4, 0, 0]]).all()
+        assert (train_dataset.inter_feat[train_dataset.iid_field].numpy() == [2, 3, 4, 4, 5, 6, 5]).all()
+        assert (train_dataset.inter_feat[train_dataset.item_list_length_field].numpy() == [1, 2, 3, 1, 1, 2, 1]).all()
+
+        assert (valid_dataset.inter_feat[valid_dataset.uid_field].numpy() == [1, 1, 2]).all()
+        assert (valid_dataset.inter_feat[valid_dataset.item_id_list_field][:, :5].numpy() == [
+            [1, 2, 3, 4, 0],
+            [1, 2, 3, 4, 5],
+            [4, 5, 6, 0, 0]]).all()
+        assert (valid_dataset.inter_feat[valid_dataset.iid_field].numpy() == [5, 6, 7]).all()
+        assert (valid_dataset.inter_feat[valid_dataset.item_list_length_field].numpy() == [4, 5, 3]).all()
+
+        assert (test_dataset.inter_feat[test_dataset.uid_field].numpy() == [1, 1, 2, 3]).all()
+        assert (test_dataset.inter_feat[test_dataset.item_id_list_field][:, :7].numpy() == [
+            [1, 2, 3, 4, 5, 6, 0],
+            [1, 2, 3, 4, 5, 6, 7],
+            [4, 5, 6, 7, 0, 0, 0],
+            [4, 5, 0, 0, 0, 0, 0]]).all()
+        assert (test_dataset.inter_feat[test_dataset.iid_field].numpy() == [7, 8, 8, 6]).all()
+        assert (test_dataset.inter_feat[test_dataset.item_list_length_field].numpy() == [6, 7, 4, 2]).all()
+
+    def test_seq_benchmark(self):
+        config_dict = {
+            'model': 'GRU4Rec',
+            'dataset': 'seq_benchmark',
+            'data_path': current_path,
+            'load_col': None,
+            'training_neg_sample_num': 0,
+            'benchmark_filename': ['train', 'valid', 'test'],
+            'alias_of_item_id': ['item_id_list']
+        }
+        train_dataset, valid_dataset, test_dataset = split_dataset(config_dict=config_dict)
+        assert (train_dataset.inter_feat[train_dataset.uid_field].numpy() == [1, 1, 1, 2, 3, 3, 4]).all()
+        assert (train_dataset.inter_feat[train_dataset.item_id_list_field][:, :3].numpy() == [
+            [8, 0, 0],
+            [8, 1, 0],
+            [8, 1, 2],
+            [2, 0, 0],
+            [3, 0, 0],
+            [3, 4, 0],
+            [3, 0, 0]]).all()
+        assert (train_dataset.inter_feat[train_dataset.iid_field].numpy() == [1, 2, 3, 3, 4, 5, 4]).all()
+        assert (train_dataset.inter_feat[train_dataset.item_list_length_field].numpy() == [1, 2, 3, 1, 1, 2, 1]).all()
+
+        assert (valid_dataset.inter_feat[valid_dataset.uid_field].numpy() == [1, 1, 3]).all()
+        assert (valid_dataset.inter_feat[valid_dataset.item_id_list_field][:, :5].numpy() == [
+            [8, 1, 2, 3, 0],
+            [8, 1, 2, 3, 4],
+            [3, 4, 5, 0, 0]]).all()
+        assert (valid_dataset.inter_feat[valid_dataset.iid_field].numpy() == [4, 5, 6]).all()
+        assert (valid_dataset.inter_feat[valid_dataset.item_list_length_field].numpy() == [4, 5, 3]).all()
+
+        assert (test_dataset.inter_feat[test_dataset.uid_field].numpy() == [1, 1, 3, 4]).all()
+        assert (test_dataset.inter_feat[test_dataset.item_id_list_field][:, :7].numpy() == [
+            [8, 1, 2, 3, 4, 5, 0],
+            [8, 1, 2, 3, 4, 5, 6],
+            [3, 4, 5, 6, 0, 0, 0],
+            [3, 4, 0, 0, 0, 0, 0]]).all()
+        assert (test_dataset.inter_feat[test_dataset.iid_field].numpy() == [6, 7, 7, 5]).all()
+        assert (test_dataset.inter_feat[test_dataset.item_list_length_field].numpy() == [6, 7, 4, 2]).all()
+
+
+class TestKGDataset:
+    def test_kg_remap_id(self):
+        config_dict = {
+            'model': 'KGAT',
+            'dataset': 'kg_remap_id',
+            'data_path': current_path,
+            'load_col': None,
+        }
+        dataset = new_dataset(config_dict=config_dict)
+        item_list = dataset.token2id('item_id', ['ia', 'ib', 'ic', 'id'])
+        entity_list = dataset.token2id('entity_id', ['eb', 'ec', 'ed', 'ee', 'ea'])
+        assert (item_list == [1, 2, 3, 4]).all()
+        assert (entity_list == [2, 3, 4, 5, 6]).all()
+        assert (dataset.inter_feat['user_id'] == [1, 2, 3, 4]).all()
+        assert (dataset.inter_feat['item_id'] == [1, 2, 3, 4]).all()
+        assert (dataset.kg_feat['head_id'] == [2, 3, 4, 5]).all()
+        assert (dataset.kg_feat['tail_id'] == [6, 2, 3, 4]).all()
+
 
 if __name__ == "__main__":
     pytest.main()
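
For reference, the seq_benchmark fixtures above are consumed without any internal re-splitting; a minimal config sketch mirroring the new test (values illustrative):

    config_dict = {
        'model': 'GRU4Rec',
        'dataset': 'seq_benchmark',
        'benchmark_filename': ['train', 'valid', 'test'],  # load *.train/.valid/.test.inter as-is
        'alias_of_item_id': ['item_id_list'],              # remap list tokens together with item_id
    }
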
diff --git a/tests/evaluation_setting/test_evaluation_setting.py b/tests/evaluation_setting/test_evaluation_setting.py
index 059f6f010..540074e1b 100644
--- a/tests/evaluation_setting/test_evaluation_setting.py
+++ b/tests/evaluation_setting/test_evaluation_setting.py
@@ -4,9 +4,9 @@
 # @Email  : slmu@ruc.edu.cn
 
 # UPDATE:
-# @Time   : 2020/11/17
+# @Time   : 2021/7/1
 # @Author : Xingyu Pan
-# @Email  : panxy@ruc.edu.cn
+# @Email  : xy_pan@foxmail.com
 
 import os
 import unittest
@@ -20,251 +20,85 @@ class TestGeneralRecommender(unittest.TestCase):
 
     def test_rols_full(self):
         config_dict = {
-            'eval_setting': 'RO_LS,full',
+            'eval_args': {'split': {'LS': 2}, 'order': 'RO', 'mode': 'full'},
             'model': 'BPR',
         }
         objective_function(config_dict=config_dict,
                            config_file_list=config_file_list, saved=False)
-        '''
-        config_dict = {
-            'eval_setting': 'RO_LS,full',
-            'model': 'NeuMF',
-        }
-        objective_function(config_dict=config_dict,
-                           config_file_list=config_file_list, saved=False)
-        config_dict = {
-            'eval_setting': 'RO_LS,full',
-            'model': 'FISM',
-        }
-        objective_function(config_dict=config_dict,
-                           config_file_list=config_file_list, saved=False)
-        config_dict = {
-            'eval_setting': 'RO_LS,full',
-            'model': 'LightGCN',
-        }
-        objective_function(config_dict=config_dict,
-                           config_file_list=config_file_list, saved=False)
-        '''
 
     def test_tols_full(self):
         config_dict = {
-            'eval_setting': 'TO_LS,full',
+            'eval_args': {'split': {'LS': 2}, 'order': 'TO', 'mode': 'full'},
             'model': 'BPR',
         }
         objective_function(config_dict=config_dict,
                            config_file_list=config_file_list, saved=False)
-        '''
-        config_dict = {
-            'eval_setting': 'TO_LS,full',
-            'model': 'NeuMF',
-        }
-        objective_function(config_dict=config_dict,
-                           config_file_list=config_file_list, saved=False)
-        config_dict = {
-            'eval_setting': 'TO_LS,full',
-            'model': 'FISM',
-        }
-        objective_function(config_dict=config_dict,
-                           config_file_list=config_file_list, saved=False)
-        config_dict = {
-            'eval_setting': 'TO_LS,full',
-            'model': 'LightGCN',
-        }
-        objective_function(config_dict=config_dict,
-                           config_file_list=config_file_list, saved=False)
-        '''
 
     def test_tors_full(self):
         config_dict = {
-            'eval_setting': 'TO_RS,full',
+            'eval_args': {'split': {'RS': [0.8, 0.1, 0.1]}, 'order': 'TO', 'mode': 'full'},
             'model': 'BPR',
         }
         objective_function(config_dict=config_dict,
                            config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'TO_RS,full',
-        #     'model': 'NeuMF',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'TO_RS,full',
-        #     'model': 'FISM',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'TO_RS,full',
-        #     'model': 'LightGCN',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
 
     def test_rors_uni100(self):
        config_dict = {
-            'eval_setting': 'RO_RS,uni100',
+            'eval_args': {'split': {'RS': [0.8, 0.1, 0.1]}, 'order': 'RO', 'mode': 'uni100'},
             'model': 'BPR',
         }
         objective_function(config_dict=config_dict,
                            config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'RO_RS,uni100',
-        #     'model': 'NeuMF',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'RO_RS,uni100',
-        #     'model': 'FISM',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'RO_RS,uni100',
-        #     'model': 'LightGCN',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
 
     def test_tols_uni100(self):
         config_dict = {
-            'eval_setting': 'TO_LS,uni100',
+            'eval_args': {'split': {'LS': 2}, 'order': 'TO', 'mode': 'uni100'},
             'model': 'BPR',
         }
         objective_function(config_dict=config_dict,
                            config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'TO_LS,uni100',
-        #     'model': 'NeuMF',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'TO_LS,uni100',
-        #     'model': 'FISM',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'TO_LS,uni100',
-        #     'model': 'LightGCN',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
+
     def test_rols_uni100(self):
         config_dict = {
-            'eval_setting': 'RO_LS,uni100',
+            'eval_args': {'split': {'LS': 2}, 'order': 'RO', 'mode': 'uni100'},
             'model': 'BPR',
         }
         objective_function(config_dict=config_dict,
                            config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'RO_LS,uni100',
-        #     'model': 'NeuMF',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'RO_LS,uni100',
-        #     'model': 'FISM',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'RO_LS,uni100',
-        #     'model': 'LightGCN',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
 
     def test_tors_uni100(self):
         config_dict = {
-            'eval_setting': 'TO_RS,uni100',
+            'eval_args': {'split': {'RS': [0.8, 0.1, 0.1]}, 'order': 'TO', 'mode': 'uni100'},
             'model': 'BPR',
         }
         objective_function(config_dict=config_dict,
                            config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'TO_RS,uni100',
-        #     'model': 'NeuMF',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'TO_RS,uni100',
-        #     'model': 'FISM',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'TO_RS,uni100',
-        #     'model': 'LightGCN',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
 
 
 class TestContextRecommender(unittest.TestCase):
 
     def test_tors(self):
         config_dict = {
-            'eval_setting': 'TO_RS',
+            'eval_args': {'split': {'RS': [0.8, 0.1, 0.1]}, 'order': 'TO', 'mode': 'none'},
             'threshold': {'rating': 4},
             'model': 'FM',
         }
         objective_function(config_dict=config_dict,
                            config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'TO_RS',
-        #     'model': 'DeepFM',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'TO_RS',
-        #     'model': 'DSSM',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'TO_RS',
-        #     'model': 'AutoInt',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
+
 
 class TestSequentialRecommender(unittest.TestCase):
 
     def test_tols_uni100(self):
         config_dict = {
-            'eval_setting': 'TO_LS,uni100',
+            'eval_args': {'split': {'LS': 2}, 'order': 'TO', 'mode': 'uni100'},
             'model': 'FPMC',
         }
         objective_function(config_dict=config_dict,
                            config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'TO_LS,uni100',
-        #     'model': 'SASRec',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'TO_LS,uni100',
-        #     'model': 'GRU4RecF',
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
-        # config_dict = {
-        #     'eval_setting': 'TO_LS,uni100',
-        #     'model': 'Caser',
-        #     'MAX_ITEM_LIST_LENGTH': 10,
-        #     'reproducibility': False,
-        # }
-        # objective_function(config_dict=config_dict,
-        #                    config_file_list=config_file_list, saved=False)
 
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/tests/metrics/test_rank_metrics.py b/tests/metrics/test_rank_metrics.py
index 6901a3763..74851a64d 100644
--- a/tests/metrics/test_rank_metrics.py
+++ b/tests/metrics/test_rank_metrics.py
@@ -3,6 +3,10 @@
 # @Author : Zhichao Feng
 # @email  : fzcbupt@gmail.com
 
+# UPDATE:
+# @Time   : 2021/7/1
+# @Author : Xingyu Pan
+# @Email  : xy_pan@foxmail.com
 
 import os
 import sys
@@ -14,6 +18,8 @@
 from recbole.evaluator import metrics_dict, Collector
 
 parameters_dict = {
+    'model': 'BPR',
+    'eval_args': {'split': {'RS': [0.8, 0.1, 0.1]}, 'order': 'RO', 'mode': 'uni100'},
     'metric_decimal_place': 4,
 }
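
The 'uni100' mode configured above scores each ground-truth item against 100 uniformly sampled negative items rather than the full item set; a conceptual sketch of that ranking step (not the evaluator's actual implementation):

    import numpy as np

    def uni100_rank(pos_score, candidate_neg_scores, seed=0):
        # rank one positive among itself plus 100 uniformly sampled negatives (1-based)
        rng = np.random.default_rng(seed)
        neg = rng.choice(candidate_neg_scores, size=100, replace=False)
        return int((neg >= pos_score).sum()) + 1

    print(uni100_rank(0.9, np.random.rand(1000)))  # 1 if the positive outscores all sampled negatives
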
diff --git a/tests/metrics/test_topk_metrics.py b/tests/metrics/test_topk_metrics.py
index 87557354e..4726cc1c7 100644
--- a/tests/metrics/test_topk_metrics.py
+++ b/tests/metrics/test_topk_metrics.py
@@ -138,6 +138,14 @@ def test_shannonentropy(self):
             -np.mean([1/15*np.log(1/15), 2/15*np.log(2/15), 3/15*np.log(3/15), 2/15*np.log(2/15),
                       4/15*np.log(4/15), 1/15*np.log(1/15), 2/15*np.log(2/15)]))
 
+    def test_tailpercentage(self):
+        name = 'tailpercentage'
+        Metric = metrics_dict[name](config)
+        self.assertEqual(
+            Metric.metric_info(Metric.get_tail(item_matrix, item_count)).tolist(),
+            np.array([[0 / 1, 0 / 2, 0 / 3], [0 / 1, 0 / 2, 0 / 3], [0 / 1, 0 / 2, 0 / 3], [1 / 1, 1 / 2, 1 / 3],
+                      [0 / 1, 0 / 2, 0 / 3]]).tolist())
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/model/test_model.yaml b/tests/model/test_model.yaml
index 4604e8a65..6860c88fe 100644
--- a/tests/model/test_model.yaml
+++ b/tests/model/test_model.yaml
@@ -38,18 +38,11 @@ load_col:
 unload_col: ~
 
-# Filtering
-max_user_inter_num: ~
-min_user_inter_num: ~
-max_item_inter_num: ~
-min_item_inter_num: ~
-lowest_val: ~
-highest_val: ~
-equal_val: ~
-not_equal_val: ~
-
 # Preprocessing
-fields_in_same_space: ~
+alias_of_user_id: ~
+alias_of_item_id: ~
+alias_of_entity_id: ~
+alias_of_relation_id: ~
 preload_weight: ~
 normalize_field: ~
 normalize_all: True
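
Under the new scheme, the removed filtering keys compose into interval arguments; a sketch of an equivalent user config (values are illustrative, not defaults):

    config_dict = {
        'val_interval': {'rating': '[3,inf)'},  # keep rows rated at least 3
        'user_inter_num_interval': '[5,inf)',   # 5-core filtering on users
        'item_inter_num_interval': '[5,inf)',   # 5-core filtering on items
    }
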