diff --git a/docs/source/user_guide/data/atomic_files.rst b/docs/source/user_guide/data/atomic_files.rst index 9e74dd4d7..ebac34ce3 100644 --- a/docs/source/user_guide/data/atomic_files.rst +++ b/docs/source/user_guide/data/atomic_files.rst @@ -141,4 +141,4 @@ For example, if you want to map the tokens of ``ent_id`` into the same space of # inter/user/item/...: As usual ent: [ent_id, ent_emb] - fields_in_same_space: [[ent_id, entity_id]] + alias_of_entity_id: [ent_id] diff --git a/docs/source/user_guide/data/data_args.rst b/docs/source/user_guide/data/data_args.rst index 9c310f443..97ce58835 100644 --- a/docs/source/user_guide/data/data_args.rst +++ b/docs/source/user_guide/data/data_args.rst @@ -88,7 +88,10 @@ Filter by number of interactions Preprocessing ----------------- -- ``fields_in_same_space (list)`` : List of spaces. Space is a list of string similar to the fields' names. The fields in the same space will be remapped into the same index system. Note that if you want to make some fields remapped in the same space with entities, then just set ``fields_in_same_space = [entity_id, xxx, ...]``. (if ``ENTITY_ID_FIELD != 'entity_id'``, then change the ``'entity_id'`` in the above example.) Defaults to ``None``. +- ``alias_of_user_id (list)``: List of fields' names, which will be remapped into the same index system with ``USER_ID_FIELD``. Defaults to ``None``. +- ``alias_of_item_id (list)``: List of fields' names, which will be remapped into the same index system with ``ITEM_ID_FIELD``. Defaults to ``None``. +- ``alias_of_entity_id (list)``: List of fields' names, which will be remapped into the same index system with ``ENTITY_ID_FIELD``, ``HEAD_ENTITY_ID_FIELD`` and ``TAIL_ENTITY_ID_FIELD``. Defaults to ``None``. +- ``alias_of_relation_id (list)``: List of fields' names, which will be remapped into the same index system with ``RELATION_ID_FIELD``. Defaults to ``None``. - ``preload_weight (dict)`` : Has the format ``{k (str): v (float)}, ...``. 
``k`` is a token field, representing the IDs of each row of preloaded weight matrix. ``v`` is a float like field. Each pair of ``k`` and ``v`` should be from the same atomic file. This arg can be used to load pretrained vectors. Defaults to ``None``. - ``normalize_field (list)`` : List of field names to be normalized. Note that only float like fields can be normalized. Defaults to ``None``. - ``normalize_all (bool)`` : Normalize all the float like fields if ``True``. Defaults to ``True``. diff --git a/docs/source/user_guide/model/sequential/gru4reckg.rst b/docs/source/user_guide/model/sequential/gru4reckg.rst index f1653cabf..69dacdaa4 100644 --- a/docs/source/user_guide/model/sequential/gru4reckg.rst +++ b/docs/source/user_guide/model/sequential/gru4reckg.rst @@ -46,9 +46,7 @@ And then: kg: [head_id, relation_id, tail_id] link: [item_id, entity_id] ent_feature: [ent_id, ent_vec] - fields_in_same_space: [ - [ent_id, entity_id] - ] + alias_of_entity_id: [ent_id] preload_weight: ent_id: ent_vec additional_feat_suffix: [ent_feature] diff --git a/docs/source/user_guide/model/sequential/ksr.rst b/docs/source/user_guide/model/sequential/ksr.rst index 8085e71b4..ff24e2184 100644 --- a/docs/source/user_guide/model/sequential/ksr.rst +++ b/docs/source/user_guide/model/sequential/ksr.rst @@ -58,10 +58,8 @@ And then: link: [item_id, entity_id] ent_feature: [ent_id, ent_vec] rel_feature: [rel_id, rel_vec] - fields_in_same_space: [ - [ent_id, entity_id] - [rel_id, relation_id] - ] + alias_of_entity_id: [ent_id] + alias_of_relation_id: [rel_id] preload_weight: ent_id: ent_vec rel_id: rel_vec diff --git a/docs/source/user_guide/usage/load_pretrained_embedding.rst b/docs/source/user_guide/usage/load_pretrained_embedding.rst index c8f3c010a..2e50a31ed 100644 --- a/docs/source/user_guide/usage/load_pretrained_embedding.rst +++ b/docs/source/user_guide/usage/load_pretrained_embedding.rst @@ -22,9 +22,9 @@ Secondly, update the args as (suppose that ``USER_ID_FIELD: user_id``): 
load_col: # inter/user/item/...: As usual useremb: [uid, user_emb] - fields_in_same_space: [[uid, user_id]] + alias_of_user_id: [uid] preload_weight: - uid: user_emb + uid: user_emb Then, this additional embedding feature file will be loaded into the :class:`Dataset` object. These new features can be accessed as following: @@ -39,6 +39,6 @@ In your model, user embedding matrix can be initialized by your pre-trained embe class YourModel(GeneralRecommender): def __init__(self, config, dataset): - pretrained_user_emb = dataset.get_preload_weight('uid') - self.user_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_user_emb)) + pretrained_user_emb = dataset.get_preload_weight('uid') + self.user_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_user_emb)) diff --git a/recbole/data/dataset/dataset.py b/recbole/data/dataset/dataset.py index 4a8ab981e..d5ea9f6d6 100644 --- a/recbole/data/dataset/dataset.py +++ b/recbole/data/dataset/dataset.py @@ -105,6 +105,7 @@ def _from_scratch(self): self._get_preset() self._get_field_from_config() self._load_data(self.dataset_name, self.dataset_path) + self._init_alias() self._data_processing() def _get_preset(self): @@ -117,6 +118,7 @@ def _get_preset(self): self.field2id_token = {} self.field2token_id = {} self.field2seqlen = self.config['seq_len'] or {} + self.alias = {} self._preloaded_weight = {} self.benchmark_filename_list = self.config['benchmark_filename'] @@ -441,6 +443,34 @@ def _load_feat(self, filepath, source): self.field2seqlen[field] = max(map(len, df[field].values)) return df + def _set_alias(self, alias_name, default_value): + alias = self.config[f'alias_of_{alias_name}'] or [] + alias = np.array(default_value + alias) + _, idx = np.unique(alias, return_index=True) + self.alias[alias_name] = alias[np.sort(idx)] + + def _init_alias(self): + """Set :attr:`alias_of_user_id` and :attr:`alias_of_item_id`. And set :attr:`_rest_fields`. 
+ """ + self._set_alias('user_id', [self.uid_field]) + self._set_alias('item_id', [self.iid_field]) + + for alias_name_1, alias_1 in self.alias.items(): + for alias_name_2, alias_2 in self.alias.items(): + if alias_name_1 != alias_name_2: + intersect = np.intersect1d(alias_1, alias_2, assume_unique=True) + if len(intersect) > 0: + raise ValueError(f'`alias_of_{alias_name_1}` and `alias_of_{alias_name_2}` ' + f'should not have the same field {list(intersect)}.') + + self._rest_fields = self.token_like_fields + for alias_name, alias in self.alias.items(): + isin = np.isin(alias, self._rest_fields, assume_unique=True) + if isin.all() is False: + raise ValueError(f'`alias_of_{alias_name}` should not contain ' + f'non-token-like field {list(alias[~isin])}.') + self._rest_fields = np.setdiff1d(self._rest_fields, alias, assume_unique=True) + def _user_item_feat_preparation(self): """Sort :attr:`user_feat` and :attr:`item_feat` by ``user_id`` or ``item_id``. Missing values will be filled later. @@ -860,41 +890,7 @@ def _set_label_by_threshold(self): raise ValueError(f'Field [{field}] not in inter_feat.') self._del_col(self.inter_feat, field) - def _get_fields_in_same_space(self): - """Parsing ``config['fields_in_same_space']``. See :doc:`../user_guide/data/data_args` for detail arg setting. - - Note: - - Each field can only exist ONCE in ``config['fields_in_same_space']``. - - user_id and item_id can not exist in ``config['fields_in_same_space']``. - - only token-like fields can exist in ``config['fields_in_same_space']``. 
- """ - fields_in_same_space = self.config['fields_in_same_space'] or [] - fields_in_same_space = [set(_) for _ in fields_in_same_space] - additional = [] - token_like_fields = self.token_like_fields - for field in token_like_fields: - count = 0 - for field_set in fields_in_same_space: - if field in field_set: - count += 1 - if count == 0: - additional.append({field}) - elif count == 1: - continue - else: - raise ValueError(f'Field [{field}] occurred in `fields_in_same_space` more than one time.') - - for field_set in fields_in_same_space: - if self.uid_field in field_set and self.iid_field in field_set: - raise ValueError('uid_field and iid_field can\'t in the same ID space') - for field in field_set: - if field not in token_like_fields: - raise ValueError(f'Field [{field}] is not a token-like field.') - - fields_in_same_space.extend(additional) - return fields_in_same_space - - def _get_remap_list(self, field_set): + def _get_remap_list(self, field_list): """Transfer set of fields in the same remapping space into remap list. If ``uid_field`` or ``iid_field`` in ``field_set``, @@ -902,7 +898,7 @@ def _get_remap_list(self, field_set): then field in :attr:`user_feat` or :attr:`item_feat` will be remapped next, finally others. Args: - field_set (set): Set of fields in the same remapping space + field_list (numpy.ndarray): List of fields in the same remapping space. Returns: list: @@ -912,29 +908,23 @@ def _get_remap_list(self, field_set): They will be concatenated in order, and remapped together. 
""" + remap_list = [] - for field, feat in zip([self.uid_field, self.iid_field], [self.user_feat, self.item_feat]): - if field in field_set: - field_set.remove(field) - remap_list.append((self.inter_feat, field, FeatureType.TOKEN)) - if feat is not None: - remap_list.append((feat, field, FeatureType.TOKEN)) - for field in field_set: - source = self.field2source[field] - if isinstance(source, FeatureSource): - source = source.value - feat = getattr(self, f'{source}_feat') + for field in field_list: ftype = self.field2type[field] - remap_list.append((feat, field, ftype)) + for feat in self.field2feats(field): + remap_list.append((feat, field, ftype)) return remap_list def _remap_ID_all(self): - """Get ``config['fields_in_same_space']`` firstly, and remap each. + """Remap all token-like fields. """ - fields_in_same_space = self._get_fields_in_same_space() - self.logger.debug(set_color('fields_in_same_space', 'blue') + f': {fields_in_same_space}') - for field_set in fields_in_same_space: - remap_list = self._get_remap_list(field_set) + for alias in self.alias.values(): + remap_list = self._get_remap_list(alias) + self._remap(remap_list) + + for field in self._rest_fields: + remap_list = self._get_remap_list(np.array([field])) self._remap(remap_list) def _concat_remaped_tokens(self, remap_list): @@ -1087,6 +1077,24 @@ def copy_field_property(self, dest_field, source_field): self.field2source[dest_field] = self.field2source[source_field] self.field2seqlen[dest_field] = self.field2seqlen[source_field] + def field2feats(self, field): + if field not in self.field2source: + raise ValueError(f'Field [{field}] not defined in dataset.') + if field == self.uid_field: + feats = [self.inter_feat] + if self.user_feat is not None: + feats.append(self.user_feat) + elif field == self.iid_field: + feats = [self.inter_feat] + if self.item_feat is not None: + feats.append(self.item_feat) + else: + source = self.field2source[field] + if not isinstance(source, str): + source = source.value 
+ feats = [getattr(self, f'{source}_feat')] + return feats + def token2id(self, field, tokens): """Map external tokens to internal ids. diff --git a/recbole/data/dataset/kg_dataset.py b/recbole/data/dataset/kg_dataset.py index 5bf4ab1c9..762141dd7 100644 --- a/recbole/data/dataset/kg_dataset.py +++ b/recbole/data/dataset/kg_dataset.py @@ -59,7 +59,7 @@ class KnowledgeBasedDataset(Dataset): Note: :attr:`entity_field` doesn't exist exactly. It's only a symbol, - representing entity features. E.g. it can be written into ``config['fields_in_same_space']``. + representing entity features. ``[UI-Relation]`` is a special relation token. """ @@ -67,10 +67,6 @@ class KnowledgeBasedDataset(Dataset): def __init__(self, config): super().__init__(config) - def _get_preset(self): - super()._get_preset() - self.field2ent_level = {} - def _get_field_from_config(self): super()._get_field_from_config() @@ -84,10 +80,6 @@ def _get_field_from_config(self): self.logger.debug(set_color('relation_field', 'blue') + f': {self.relation_field}') self.logger.debug(set_color('entity_field', 'blue') + f': {self.entity_field}') - def _data_processing(self): - self._set_field2ent_level() - super()._data_processing() - def _data_filtering(self): super()._data_filtering() self._filter_link() @@ -157,9 +149,6 @@ def _build_feat_name_list(self): feat_name_list.append('kg_feat') return feat_name_list - def _restore_saved_dataset(self, saved_dataset): - raise NotImplementedError() - def save(self, filepath): raise NotImplementedError() @@ -197,192 +186,86 @@ def _check_link(self, link): assert self.entity_field in link, link_warn_message.format(self.entity_field) assert self.iid_field in link, link_warn_message.format(self.iid_field) - def _get_fields_in_same_space(self): - """Parsing ``config['fields_in_same_space']``. See :doc:`../user_guide/data/data_args` for detail arg setting. - - Note: - - Each field can only exist ONCE in ``config['fields_in_same_space']``. 
- - user_id and item_id can not exist in ``config['fields_in_same_space']``. - - only token-like fields can exist in ``config['fields_in_same_space']``. - - ``head_entity_id`` and ``target_entity_id`` should be remapped with ``item_id``. - """ - fields_in_same_space = super()._get_fields_in_same_space() - fields_in_same_space = [_ for _ in fields_in_same_space if not self._contain_ent_field(_)] - ent_fields = self._get_ent_fields_in_same_space() - for field_set in fields_in_same_space: - if self.iid_field in field_set: - field_set.update(ent_fields) - return fields_in_same_space - - def _contain_ent_field(self, field_set): - """Return True if ``field_set`` contains entity fields. - """ - flag = False - flag |= self.head_entity_field in field_set - flag |= self.tail_entity_field in field_set - flag |= self.entity_field in field_set - return flag - - def _get_ent_fields_in_same_space(self): - """Return ``field_set`` that should be remapped together with entities. - """ - fields_in_same_space = super()._get_fields_in_same_space() - - ent_fields = {self.head_entity_field, self.tail_entity_field} - for field_set in fields_in_same_space: - if self._contain_ent_field(field_set): - field_set = self._remove_ent_field(field_set) - ent_fields.update(field_set) - self.logger.debug(set_color('ent_fields', 'blue') + f': {fields_in_same_space}') - return ent_fields - - def _remove_ent_field(self, field_set): - """Delete entity fields from ``field_set``. - """ - for field in [self.head_entity_field, self.tail_entity_field, self.entity_field]: - if field in field_set: - field_set.remove(field) - return field_set - - def _set_field2ent_level(self): - """For fields that remapped together with ``item_id``, - set their levels as ``rec``, otherwise as ``ent``. 
- """ - fields_in_same_space = self._get_fields_in_same_space() - for field_set in fields_in_same_space: - if self.iid_field in field_set: - for field in field_set: - self.field2ent_level[field] = 'rec' - ent_fields = self._get_ent_fields_in_same_space() - for field in ent_fields: - self.field2ent_level[field] = 'ent' - - def _fields_by_ent_level(self, ent_level): - """Given ``ent_level``, return all the field name of this level. - """ - ret = [] - for field in self.field2ent_level: - if self.field2ent_level[field] == ent_level: - ret.append(field) - return ret - - @property - def rec_level_ent_fields(self): - """Get entity fields remapped together with ``item_id``. - - Returns: - list: List of field names. + def _init_alias(self): + """Add :attr:`alias_of_entity_id`, :attr:`alias_of_relation_id` and update :attr:`_rest_fields`. """ - return self._fields_by_ent_level('rec') + self._set_alias('entity_id', [self.head_entity_field, self.tail_entity_field]) + self._set_alias('relation_id', [self.relation_field]) - @property - def ent_level_ent_fields(self): - """Get entity fields remapped together with ``entity_id``. + super()._init_alias() - Returns: - list: List of field names. - """ - return self._fields_by_ent_level('ent') - - def _remap_entities_by_link(self): - """Map entity tokens from fields in ``ent`` level - to item tokens according to ``.link``. - """ - for ent_field in self.ent_level_ent_fields: - source = self.field2source[ent_field] - if not isinstance(source, str): - source = source.value - feat = getattr(self, f'{source}_feat') - entity_list = feat[ent_field].values - for i, entity_id in enumerate(entity_list): - if entity_id in self.entity2item: - entity_list[i] = self.entity2item[entity_id] - feat[ent_field] = entity_list + self._rest_fields = np.setdiff1d(self._rest_fields, [self.entity_field], assume_unique=True) def _get_rec_item_token(self): """Get set of entity tokens from fields in ``rec`` level. 
""" - field_set = set(self.rec_level_ent_fields) - remap_list = self._get_remap_list(field_set) + remap_list = self._get_remap_list(self.alias['item_id']) tokens, _ = self._concat_remaped_tokens(remap_list) return set(tokens) def _get_entity_token(self): """Get set of entity tokens from fields in ``ent`` level. """ - field_set = set(self.ent_level_ent_fields) - remap_list = self._get_remap_list(field_set) + remap_list = self._get_remap_list(self.alias['entity_id']) tokens, _ = self._concat_remaped_tokens(remap_list) return set(tokens) - def _reset_ent_remapID(self, field, new_id_token): - token2id = {} - for i, token in enumerate(new_id_token): - token2id[token] = i - idmap = {} - for i, token in enumerate(self.field2id_token[field]): - if token not in token2id: - continue - new_idx = token2id[token] - idmap[i] = new_idx - source = self.field2source[field] - if not isinstance(source, str): - source = source.value - if source == 'item_id': - feats = [self.inter_feat] - if self.item_feat is not None: - feats.append(self.item_feat) - else: - feats = [getattr(self, f'{source}_feat')] - for feat in feats: - old_idx = feat[field].values - new_idx = np.array([idmap[_] for _ in old_idx]) - feat[field] = new_idx - - def _sort_remaped_entities(self, item_tokens): - item2order = {} - for token in self.field2id_token[self.iid_field]: - if token == '[PAD]': - item2order[token] = 0 - elif token in item_tokens and token not in self.item2entity: - item2order[token] = 1 - elif token in self.item2entity or token in self.entity2item: - item2order[token] = 2 + def _reset_ent_remapID(self, field, idmap, id2token, token2id): + self.field2id_token[field] = id2token + self.field2token_id[field] = token2id + for feat in self.field2feats(field): + ftype = self.field2type[field] + if ftype == FeatureType.TOKEN: + old_idx = feat[field].values else: - item2order[token] = 3 - item_ent_token_list = list(self.field2id_token[self.iid_field]) - item_ent_token_list.sort(key=lambda t: item2order[t]) 
- item_ent_token_list = np.array(item_ent_token_list) - order_list = [item2order[_] for _ in item_ent_token_list] - order_cnt = Counter(order_list) - layered_num = [] - for i in range(4): - layered_num.append(order_cnt[i]) - layered_num = np.cumsum(np.array(layered_num)) - new_id_token = item_ent_token_list[:layered_num[-2]] - new_token_id = {t: i for i, t in enumerate(new_id_token)} - for field in self.rec_level_ent_fields: - self._reset_ent_remapID(field, new_id_token) - self.field2id_token[field] = new_id_token - self.field2token_id[field] = new_token_id - new_id_token = item_ent_token_list[:layered_num[-1]] - new_id_token = [self.item2entity[_] if _ in self.item2entity else _ for _ in new_id_token] - new_token_id = {t: i for i, t in enumerate(new_id_token)} - for field in self.ent_level_ent_fields: - self._reset_ent_remapID(field, item_ent_token_list[:layered_num[-1]]) - self.field2id_token[field] = new_id_token - self.field2token_id[field] = new_token_id - self.field2id_token[self.entity_field] = new_id_token - self.field2token_id[self.entity_field] = new_token_id + old_idx = feat[field].agg(np.concatenate) - def _remap_ID_all(self): - """Firstly, remap entities and items all together. Then sort entity tokens, - then three kinds of entities can be apart away from each other. + new_idx = idmap[old_idx] + + if ftype == FeatureType.TOKEN: + feat[field] = new_idx + else: + split_point = np.cumsum(feat[field].agg(len))[:-1] + feat[field] = np.split(new_idx, split_point) + + def _merge_item_and_entity(self): + """Merge item-id and entity-id into the same id-space. 
""" - self._remap_entities_by_link() - item_tokens = self._get_rec_item_token() + item_token = self.field2id_token[self.iid_field] + entity_token = self.field2id_token[self.head_entity_field] + item_num = len(item_token) + link_num = len(self.item2entity) + entity_num = len(entity_token) + + # reset item id + item_priority = np.array([token in self.item2entity for token in item_token]) + item_order = np.argsort(item_priority, kind='stable') + item_id_map = np.zeros_like(item_order) + item_id_map[item_order] = np.arange(item_num) + new_item_id2token = item_token[item_order] + new_item_token2id = {t: i for i, t in enumerate(new_item_id2token)} + for field in self.alias['item_id']: + self._reset_ent_remapID(field, item_id_map, new_item_id2token, new_item_token2id) + + # reset entity id + entity_priority = np.array([token != '[PAD]' and token not in self.entity2item for token in entity_token]) + entity_order = np.argsort(entity_priority, kind='stable') + entity_id_map = np.zeros_like(entity_order) + for i in entity_order[1:link_num + 1]: + entity_id_map[i] = new_item_token2id[self.entity2item[entity_token[i]]] + entity_id_map[entity_order[link_num + 1:]] = np.arange(item_num, item_num + entity_num - link_num - 1) + new_entity_id2token = np.concatenate([new_item_id2token, entity_token[entity_order[link_num + 1:]]]) + for i in range(item_num - link_num, item_num): + new_entity_id2token[i] = self.item2entity[new_entity_id2token[i]] + new_entity_token2id = {t: i for i, t in enumerate(new_entity_id2token)} + for field in self.alias['entity_id']: + self._reset_ent_remapID(field, entity_id_map, new_entity_id2token, new_entity_token2id) + self.field2id_token[self.entity_field] = new_entity_id2token + self.field2token_id[self.entity_field] = new_entity_token2id + + def _remap_ID_all(self): super()._remap_ID_all() - self._sort_remaped_entities(item_tokens) + self._merge_item_and_entity() self.field2token_id[self.relation_field]['[UI-Relation]'] = 
len(self.field2id_token[self.relation_field]) self.field2id_token[self.relation_field] = np.append(self.field2id_token[self.relation_field], '[UI-Relation]') diff --git a/recbole/properties/dataset/ml-100k.yaml b/recbole/properties/dataset/ml-100k.yaml index d6a99b5f7..7c229f814 100644 --- a/recbole/properties/dataset/ml-100k.yaml +++ b/recbole/properties/dataset/ml-100k.yaml @@ -37,7 +37,10 @@ user_inter_num_interval: ~ item_inter_num_interval: ~ # Preprocessing -fields_in_same_space: ~ +alias_of_user_id: ~ +alias_of_item_id: ~ +alias_of_entity_id: ~ +alias_of_relation_id: ~ preload_weight: ~ normalize_field: ~ normalize_all: True diff --git a/recbole/properties/dataset/sample.yaml b/recbole/properties/dataset/sample.yaml index d9869e5e2..181d014d1 100644 --- a/recbole/properties/dataset/sample.yaml +++ b/recbole/properties/dataset/sample.yaml @@ -37,7 +37,10 @@ max_item_inter_num: ~ min_item_inter_num: 0 # Preprocessing -fields_in_same_space: ~ +alias_of_user_id: ~ +alias_of_item_id: ~ +alias_of_entity_id: ~ +alias_of_relation_id: ~ preload_weight: ~ normalize_field: ~ normalize_all: ~ diff --git a/recbole/utils/argument_list.py b/recbole/utils/argument_list.py index 632f442bd..fc5f1ae1f 100644 --- a/recbole/utils/argument_list.py +++ b/recbole/utils/argument_list.py @@ -47,7 +47,7 @@ 'load_col', 'unload_col', 'unused_col', 'additional_feat_suffix', 'max_user_inter_num', 'min_user_inter_num', 'max_item_inter_num', 'min_item_inter_num', 'lowest_val', 'highest_val', 'equal_val', 'not_equal_val', - 'fields_in_same_space', + 'alias_of_user_id', 'alias_of_item_id', 'alias_of_entity_id', 'alias_of_relation_id', 'preload_weight', 'normalize_field', 'normalize_all' ] diff --git a/tests/data/kg_remap_id/kg_remap_id.inter b/tests/data/kg_remap_id/kg_remap_id.inter new file mode 100644 index 000000000..f8ab8a690 --- /dev/null +++ b/tests/data/kg_remap_id/kg_remap_id.inter @@ -0,0 +1,5 @@ +user_id:token item_id:token +ua ia +ub ib +uc ic +ud id \ No newline at end of file 
diff --git a/tests/data/kg_remap_id/kg_remap_id.kg b/tests/data/kg_remap_id/kg_remap_id.kg new file mode 100644 index 000000000..95f6d4189 --- /dev/null +++ b/tests/data/kg_remap_id/kg_remap_id.kg @@ -0,0 +1,5 @@ +head_id:token tail_id:token relation_id:token +eb ea ra +ec eb rb +ed ec rc +ee ed rd \ No newline at end of file diff --git a/tests/data/kg_remap_id/kg_remap_id.link b/tests/data/kg_remap_id/kg_remap_id.link new file mode 100644 index 000000000..3328eccb5 --- /dev/null +++ b/tests/data/kg_remap_id/kg_remap_id.link @@ -0,0 +1,5 @@ +item_id:token entity_id:token +ib eb +ic ec +id ed +ie ee \ No newline at end of file diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 2844c97e5..a745f36bf 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -342,7 +342,6 @@ def test_remap_id(self): 'dataset': 'remap_id', 'data_path': current_path, 'load_col': None, - 'fields_in_same_space': None, } dataset = new_dataset(config_dict=config_dict) user_list = dataset.token2id('user_id', ['ua', 'ub', 'uc', 'ud']) @@ -358,16 +357,14 @@ def test_remap_id(self): assert (dataset.inter_feat['user_list'][2] == [3, 4, 1]).all() assert (dataset.inter_feat['user_list'][3] == [5]).all() - def test_remap_id_with_fields_in_same_space(self): + def test_remap_id_with_alias(self): config_dict = { 'model': 'BPR', 'dataset': 'remap_id', 'data_path': current_path, 'load_col': None, - 'fields_in_same_space': [ - ['user_id', 'add_user', 'user_list'], - ['item_id', 'add_item'], - ], + 'alias_of_user_id': ['add_user', 'user_list'], + 'alias_of_item_id': ['add_item'], } dataset = new_dataset(config_dict=config_dict) user_list = dataset.token2id('user_id', ['ua', 'ub', 'uc', 'ud', 'ue', 'uf']) @@ -603,5 +600,25 @@ def test_seq_leave_one_out(self): ]).all() +class TestKGDataset: + def test_kg_remap_id(self): + config_dict = { + 'model': 'KGAT', + 'dataset': 'kg_remap_id', + 'data_path': current_path, + 'load_col': None, + } + dataset = 
new_dataset(config_dict=config_dict) + print(dataset.field2id_token['entity_id']) + item_list = dataset.token2id('item_id', ['ia', 'ib', 'ic', 'id']) + entity_list = dataset.token2id('entity_id', ['eb', 'ec', 'ed', 'ee', 'ea']) + assert (item_list == [1, 2, 3, 4]).all() + assert (entity_list == [2, 3, 4, 5, 6]).all() + assert (dataset.inter_feat['user_id'] == [1, 2, 3, 4]).all() + assert (dataset.inter_feat['item_id'] == [1, 2, 3, 4]).all() + assert (dataset.kg_feat['head_id'] == [2, 3, 4, 5]).all() + assert (dataset.kg_feat['tail_id'] == [6, 2, 3, 4]).all() + + if __name__ == "__main__": pytest.main() diff --git a/tests/model/test_model.yaml b/tests/model/test_model.yaml index 4604e8a65..a32220742 100644 --- a/tests/model/test_model.yaml +++ b/tests/model/test_model.yaml @@ -49,7 +49,10 @@ equal_val: ~ not_equal_val: ~ # Preprocessing -fields_in_same_space: ~ +alias_of_user_id: ~ +alias_of_item_id: ~ +alias_of_entity_id: ~ +alias_of_relation_id: ~ preload_weight: ~ normalize_field: ~ normalize_all: True