From 9b303407d3e0e4245a277434572fd4367c902a78 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Mon, 5 Jul 2021 10:16:43 +0800 Subject: [PATCH 1/3] REFACTOR: refactor `remap_all` --- recbole/data/dataset/dataset.py | 113 ++++++----- recbole/data/dataset/kg_dataset.py | 248 ++++++++---------------- recbole/properties/dataset/ml-100k.yaml | 4 +- recbole/properties/dataset/sample.yaml | 4 +- recbole/utils/argument_list.py | 2 +- tests/data/test_dataset.py | 9 +- tests/model/test_model.yaml | 4 +- 7 files changed, 161 insertions(+), 223 deletions(-) diff --git a/recbole/data/dataset/dataset.py b/recbole/data/dataset/dataset.py index d749de7e8..ffe4f7fac 100644 --- a/recbole/data/dataset/dataset.py +++ b/recbole/data/dataset/dataset.py @@ -105,6 +105,7 @@ def _from_scratch(self): self._get_preset() self._get_field_from_config() self._load_data(self.dataset_name, self.dataset_path) + self._get_alias() self._data_processing() def _get_preset(self): @@ -441,6 +442,28 @@ def _load_feat(self, filepath, source): self.field2seqlen[field] = max(map(len, df[field].values)) return df + def _get_alias(self): + self.alias_of_user_id = self.config['alias_of_user_id'] or [] + self.alias_of_user_id.append(self.uid_field) + self.alias_of_user_id = set(self.alias_of_user_id) + self.alias_of_item_id = self.config['alias_of_item_id'] or [] + self.alias_of_item_id.append(self.iid_field) + self.alias_of_item_id = set(self.alias_of_item_id) + + if self.alias_of_user_id & self.alias_of_item_id: + raise ValueError(f'`alias_of_user_id` and `alias_of_item_id` ' + f'should not have the same field {list(self.alias_of_user_id & self.alias_of_item_id)}.') + + token_like_fields = set(self.token_like_fields) + if self.alias_of_user_id - token_like_fields: + raise ValueError(f'`alias_of_user_id` should not contain ' + f'non-token-like field {list(self.alias_of_user_id - token_like_fields)}.') + if self.alias_of_item_id - token_like_fields: + raise ValueError(f'`alias_of_item_id` should 
not contain ' + f'non-token-like field {list(self.alias_of_item_id - token_like_fields)}.') + + self._rest_fields = token_like_fields - self.alias_of_user_id - self.alias_of_item_id + def _user_item_feat_preparation(self): """Sort :attr:`user_feat` and :attr:`item_feat` by ``user_id`` or ``item_id``. Missing values will be filled later. @@ -860,40 +883,6 @@ def _set_label_by_threshold(self): raise ValueError(f'Field [{field}] not in inter_feat.') self._del_col(self.inter_feat, field) - def _get_fields_in_same_space(self): - """Parsing ``config['fields_in_same_space']``. See :doc:`../user_guide/data/data_args` for detail arg setting. - - Note: - - Each field can only exist ONCE in ``config['fields_in_same_space']``. - - user_id and item_id can not exist in ``config['fields_in_same_space']``. - - only token-like fields can exist in ``config['fields_in_same_space']``. - """ - fields_in_same_space = self.config['fields_in_same_space'] or [] - fields_in_same_space = [set(_) for _ in fields_in_same_space] - additional = [] - token_like_fields = self.token_like_fields - for field in token_like_fields: - count = 0 - for field_set in fields_in_same_space: - if field in field_set: - count += 1 - if count == 0: - additional.append({field}) - elif count == 1: - continue - else: - raise ValueError(f'Field [{field}] occurred in `fields_in_same_space` more than one time.') - - for field_set in fields_in_same_space: - if self.uid_field in field_set and self.iid_field in field_set: - raise ValueError('uid_field and iid_field can\'t in the same ID space') - for field in field_set: - if field not in token_like_fields: - raise ValueError(f'Field [{field}] is not a token-like field.') - - fields_in_same_space.extend(additional) - return fields_in_same_space - def _get_remap_list(self, field_set): """Transfer set of fields in the same remapping space into remap list. @@ -912,29 +901,33 @@ def _get_remap_list(self, field_set): They will be concatenated in order, and remapped together. 
""" + remap_list = [] - for field, feat in zip([self.uid_field, self.iid_field], [self.user_feat, self.item_feat]): - if field in field_set: - field_set.remove(field) - remap_list.append((self.inter_feat, field, FeatureType.TOKEN)) - if feat is not None: - remap_list.append((feat, field, FeatureType.TOKEN)) - for field in field_set: - source = self.field2source[field] - if isinstance(source, FeatureSource): - source = source.value - feat = getattr(self, f'{source}_feat') + field_list = sorted(field_set, key=lambda f: f != self.uid_field and f != self.iid_field) + for field in field_list: ftype = self.field2type[field] - remap_list.append((feat, field, ftype)) + for feat in self.field2feats(field): + remap_list.append((feat, field, ftype)) return remap_list def _remap_ID_all(self): - """Get ``config['fields_in_same_space']`` firstly, and remap each. """ - fields_in_same_space = self._get_fields_in_same_space() - self.logger.debug(set_color('fields_in_same_space', 'blue') + f': {fields_in_same_space}') - for field_set in fields_in_same_space: - remap_list = self._get_remap_list(field_set) + """ + self._remap_user() + self._remap_item() + self._remap_rest() + + def _remap_user(self): + user_remap_list = self._get_remap_list(self.alias_of_user_id) + self._remap(user_remap_list) + + def _remap_item(self): + item_remap_list = self._get_remap_list(self.alias_of_item_id) + self._remap(item_remap_list) + + def _remap_rest(self): + for field in self._rest_fields: + remap_list = self._get_remap_list({field}) self._remap(remap_list) def _concat_remaped_tokens(self, remap_list): @@ -1087,6 +1080,24 @@ def copy_field_property(self, dest_field, source_field): self.field2source[dest_field] = self.field2source[source_field] self.field2seqlen[dest_field] = self.field2seqlen[source_field] + def field2feats(self, field): + if field not in self.field2source: + raise ValueError(f'Field [{field}] not defined in dataset.') + if field == self.uid_field: + feats = [self.inter_feat] + if 
self.user_feat is not None: + feats.append(self.user_feat) + elif field == self.iid_field: + feats = [self.inter_feat] + if self.item_feat is not None: + feats.append(self.item_feat) + else: + source = self.field2source[field] + if not isinstance(source, str): + source = source.value + feats = [getattr(self, f'{source}_feat')] + return feats + def token2id(self, field, tokens): """Map external tokens to internal ids. diff --git a/recbole/data/dataset/kg_dataset.py b/recbole/data/dataset/kg_dataset.py index 5bf4ab1c9..e1ba28ab0 100644 --- a/recbole/data/dataset/kg_dataset.py +++ b/recbole/data/dataset/kg_dataset.py @@ -59,7 +59,7 @@ class KnowledgeBasedDataset(Dataset): Note: :attr:`entity_field` doesn't exist exactly. It's only a symbol, - representing entity features. E.g. it can be written into ``config['fields_in_same_space']``. + representing entity features. ``[UI-Relation]`` is a special relation token. """ @@ -67,10 +67,6 @@ class KnowledgeBasedDataset(Dataset): def __init__(self, config): super().__init__(config) - def _get_preset(self): - super()._get_preset() - self.field2ent_level = {} - def _get_field_from_config(self): super()._get_field_from_config() @@ -84,10 +80,6 @@ def _get_field_from_config(self): self.logger.debug(set_color('relation_field', 'blue') + f': {self.relation_field}') self.logger.debug(set_color('entity_field', 'blue') + f': {self.entity_field}') - def _data_processing(self): - self._set_field2ent_level() - super()._data_processing() - def _data_filtering(self): super()._data_filtering() self._filter_link() @@ -157,9 +149,6 @@ def _build_feat_name_list(self): feat_name_list.append('kg_feat') return feat_name_list - def _restore_saved_dataset(self, saved_dataset): - raise NotImplementedError() - def save(self, filepath): raise NotImplementedError() @@ -197,194 +186,129 @@ def _check_link(self, link): assert self.entity_field in link, link_warn_message.format(self.entity_field) assert self.iid_field in link, 
link_warn_message.format(self.iid_field) - def _get_fields_in_same_space(self): - """Parsing ``config['fields_in_same_space']``. See :doc:`../user_guide/data/data_args` for detail arg setting. - - Note: - - Each field can only exist ONCE in ``config['fields_in_same_space']``. - - user_id and item_id can not exist in ``config['fields_in_same_space']``. - - only token-like fields can exist in ``config['fields_in_same_space']``. - - ``head_entity_id`` and ``target_entity_id`` should be remapped with ``item_id``. - """ - fields_in_same_space = super()._get_fields_in_same_space() - fields_in_same_space = [_ for _ in fields_in_same_space if not self._contain_ent_field(_)] - ent_fields = self._get_ent_fields_in_same_space() - for field_set in fields_in_same_space: - if self.iid_field in field_set: - field_set.update(ent_fields) - return fields_in_same_space - - def _contain_ent_field(self, field_set): - """Return True if ``field_set`` contains entity fields. - """ - flag = False - flag |= self.head_entity_field in field_set - flag |= self.tail_entity_field in field_set - flag |= self.entity_field in field_set - return flag - - def _get_ent_fields_in_same_space(self): - """Return ``field_set`` that should be remapped together with entities. - """ - fields_in_same_space = super()._get_fields_in_same_space() - - ent_fields = {self.head_entity_field, self.tail_entity_field} - for field_set in fields_in_same_space: - if self._contain_ent_field(field_set): - field_set = self._remove_ent_field(field_set) - ent_fields.update(field_set) - self.logger.debug(set_color('ent_fields', 'blue') + f': {fields_in_same_space}') - return ent_fields - - def _remove_ent_field(self, field_set): - """Delete entity fields from ``field_set``. 
- """ - for field in [self.head_entity_field, self.tail_entity_field, self.entity_field]: - if field in field_set: - field_set.remove(field) - return field_set - - def _set_field2ent_level(self): - """For fields that remapped together with ``item_id``, - set their levels as ``rec``, otherwise as ``ent``. - """ - fields_in_same_space = self._get_fields_in_same_space() - for field_set in fields_in_same_space: - if self.iid_field in field_set: - for field in field_set: - self.field2ent_level[field] = 'rec' - ent_fields = self._get_ent_fields_in_same_space() - for field in ent_fields: - self.field2ent_level[field] = 'ent' - - def _fields_by_ent_level(self, ent_level): - """Given ``ent_level``, return all the field name of this level. - """ - ret = [] - for field in self.field2ent_level: - if self.field2ent_level[field] == ent_level: - ret.append(field) - return ret - - @property - def rec_level_ent_fields(self): - """Get entity fields remapped together with ``item_id``. - - Returns: - list: List of field names. + def _get_alias(self): + """Add :attr:`alias_of_entity_id` and update :attr:`_rest_fields` """ - return self._fields_by_ent_level('rec') + super()._get_alias() + self.alias_of_entity_id = self.config['alias_of_entity_id'] or [] + self.alias_of_entity_id.extend([self.head_entity_field, self.tail_entity_field]) + self.alias_of_entity_id = set(self.alias_of_entity_id) - @property - def ent_level_ent_fields(self): - """Get entity fields remapped together with ``entity_id``. + if self.alias_of_entity_id & self.alias_of_user_id: + raise ValueError(f'`alias_of_entity_id` and `alias_of_user_id` ' + f'should not have the same field {list(self.alias_of_entity_id & self.alias_of_user_id)}.') + if self.alias_of_entity_id & self.alias_of_item_id: + raise ValueError(f'`alias_of_entity_id` and `alias_of_item_id` ' + f'should not have the same field {list(self.alias_of_entity_id & self.alias_of_item_id)}.') - Returns: - list: List of field names. 
- """ - return self._fields_by_ent_level('ent') + token_like_fields = set(self.token_like_fields) + if self.alias_of_entity_id - token_like_fields: + raise ValueError(f'`alias_of_entity_id` should not contain ' + f'non-token-like field {list(self.alias_of_entity_id - token_like_fields)}.') - def _remap_entities_by_link(self): - """Map entity tokens from fields in ``ent`` level - to item tokens according to ``.link``. - """ - for ent_field in self.ent_level_ent_fields: - source = self.field2source[ent_field] - if not isinstance(source, str): - source = source.value - feat = getattr(self, f'{source}_feat') - entity_list = feat[ent_field].values - for i, entity_id in enumerate(entity_list): - if entity_id in self.entity2item: - entity_list[i] = self.entity2item[entity_id] - feat[ent_field] = entity_list + self._rest_fields = self._rest_fields - self.alias_of_entity_id - {self.entity_field} def _get_rec_item_token(self): """Get set of entity tokens from fields in ``rec`` level. """ - field_set = set(self.rec_level_ent_fields) - remap_list = self._get_remap_list(field_set) + remap_list = self._get_remap_list(self.alias_of_item_id) tokens, _ = self._concat_remaped_tokens(remap_list) return set(tokens) def _get_entity_token(self): """Get set of entity tokens from fields in ``ent`` level. 
""" - field_set = set(self.ent_level_ent_fields) - remap_list = self._get_remap_list(field_set) + remap_list = self._get_remap_list(self.alias_of_entity_id) tokens, _ = self._concat_remaped_tokens(remap_list) return set(tokens) - def _reset_ent_remapID(self, field, new_id_token): - token2id = {} - for i, token in enumerate(new_id_token): - token2id[token] = i - idmap = {} - for i, token in enumerate(self.field2id_token[field]): - if token not in token2id: - continue - new_idx = token2id[token] - idmap[i] = new_idx - source = self.field2source[field] - if not isinstance(source, str): - source = source.value - if source == 'item_id': - feats = [self.inter_feat] - if self.item_feat is not None: - feats.append(self.item_feat) - else: - feats = [getattr(self, f'{source}_feat')] - for feat in feats: - old_idx = feat[field].values - new_idx = np.array([idmap[_] for _ in old_idx]) - feat[field] = new_idx + def _remap_entities_by_link(self): + """Map entity tokens from fields in ``ent`` level + to item tokens according to ``.link``. 
+ """ + for ent_field in self.alias_of_entity_id: + feat = self.field2feats(ent_field)[0] + ftype = self.field2type[ent_field] + if ftype == FeatureType.TOKEN: + entity_list = feat[ent_field].values + else: + entity_list = feat[ent_field].agg(np.concatenate) + + for i, entity_id in enumerate(entity_list): + if entity_id in self.entity2item: + entity_list[i] = self.entity2item[entity_id] + + if ftype == FeatureType.TOKEN: + feat[ent_field] = entity_list + else: + split_point = np.cumsum(feat[ent_field].agg(len))[:-1] + feat[ent_field] = np.split(entity_list, split_point) + + def _reset_ent_remapID(self, field, idmap): + for feat in self.field2feats(field): + ftype = self.field2type[field] + if ftype == FeatureType.TOKEN: + old_idx = feat[field].values + else: + old_idx = feat[field].agg(np.concatenate) + + new_idx = idmap[old_idx] + + if ftype == FeatureType.TOKEN: + feat[field] = new_idx + else: + split_point = np.cumsum(feat[field].agg(len))[:-1] + feat[field] = np.split(new_idx, split_point) def _sort_remaped_entities(self, item_tokens): - item2order = {} + level_list = [] for token in self.field2id_token[self.iid_field]: if token == '[PAD]': - item2order[token] = 0 + level_list.append(0) elif token in item_tokens and token not in self.item2entity: - item2order[token] = 1 + level_list.append(1) elif token in self.item2entity or token in self.entity2item: - item2order[token] = 2 + level_list.append(2) else: - item2order[token] = 3 - item_ent_token_list = list(self.field2id_token[self.iid_field]) - item_ent_token_list.sort(key=lambda t: item2order[t]) - item_ent_token_list = np.array(item_ent_token_list) - order_list = [item2order[_] for _ in item_ent_token_list] - order_cnt = Counter(order_list) - layered_num = [] - for i in range(4): - layered_num.append(order_cnt[i]) - layered_num = np.cumsum(np.array(layered_num)) - new_id_token = item_ent_token_list[:layered_num[-2]] + level_list.append(3) + level_list = np.array(level_list) + order_list = 
np.argsort(level_list) + token_list = self.field2id_token[self.iid_field][order_list] + idmap = np.zeros_like(order_list) + idmap[order_list] = np.arange(len(order_list)) + item_num = np.sum(level_list < 3) + + new_id_token = token_list[:item_num] new_token_id = {t: i for i, t in enumerate(new_id_token)} - for field in self.rec_level_ent_fields: - self._reset_ent_remapID(field, new_id_token) + for field in self.alias_of_item_id: + self._reset_ent_remapID(field, idmap) self.field2id_token[field] = new_id_token self.field2token_id[field] = new_token_id - new_id_token = item_ent_token_list[:layered_num[-1]] - new_id_token = [self.item2entity[_] if _ in self.item2entity else _ for _ in new_id_token] + + new_id_token = np.array([self.item2entity[_] if _ in self.item2entity else _ for _ in token_list]) new_token_id = {t: i for i, t in enumerate(new_id_token)} - for field in self.ent_level_ent_fields: - self._reset_ent_remapID(field, item_ent_token_list[:layered_num[-1]]) + for field in self.alias_of_entity_id: + self._reset_ent_remapID(field, idmap) self.field2id_token[field] = new_id_token self.field2token_id[field] = new_token_id self.field2id_token[self.entity_field] = new_id_token self.field2token_id[self.entity_field] = new_token_id def _remap_ID_all(self): - """Firstly, remap entities and items all together. Then sort entity tokens, - then three kinds of entities can be apart away from each other. 
+ super()._remap_ID_all() + self.field2token_id[self.relation_field]['[UI-Relation]'] = len(self.field2id_token[self.relation_field]) + self.field2id_token[self.relation_field] = np.append(self.field2id_token[self.relation_field], '[UI-Relation]') + + def _remap_item(self): + """ """ self._remap_entities_by_link() item_tokens = self._get_rec_item_token() - super()._remap_ID_all() + + item_remap_list = self._get_remap_list(self.alias_of_item_id | self.alias_of_entity_id) + self._remap(item_remap_list) + self._sort_remaped_entities(item_tokens) - self.field2token_id[self.relation_field]['[UI-Relation]'] = len(self.field2id_token[self.relation_field]) - self.field2id_token[self.relation_field] = np.append(self.field2id_token[self.relation_field], '[UI-Relation]') @property def relation_num(self): diff --git a/recbole/properties/dataset/ml-100k.yaml b/recbole/properties/dataset/ml-100k.yaml index d6a99b5f7..796bb5e24 100644 --- a/recbole/properties/dataset/ml-100k.yaml +++ b/recbole/properties/dataset/ml-100k.yaml @@ -37,7 +37,9 @@ user_inter_num_interval: ~ item_inter_num_interval: ~ # Preprocessing -fields_in_same_space: ~ +alias_of_user_id: ~ +alias_of_item_id: ~ +alias_of_entity_id: ~ preload_weight: ~ normalize_field: ~ normalize_all: True diff --git a/recbole/properties/dataset/sample.yaml b/recbole/properties/dataset/sample.yaml index d9869e5e2..156942110 100644 --- a/recbole/properties/dataset/sample.yaml +++ b/recbole/properties/dataset/sample.yaml @@ -37,7 +37,9 @@ max_item_inter_num: ~ min_item_inter_num: 0 # Preprocessing -fields_in_same_space: ~ +alias_of_user_id: ~ +alias_of_item_id: ~ +alias_of_entity_id: ~ preload_weight: ~ normalize_field: ~ normalize_all: ~ diff --git a/recbole/utils/argument_list.py b/recbole/utils/argument_list.py index 632f442bd..43e8db305 100644 --- a/recbole/utils/argument_list.py +++ b/recbole/utils/argument_list.py @@ -47,7 +47,7 @@ 'load_col', 'unload_col', 'unused_col', 'additional_feat_suffix', 'max_user_inter_num', 
'min_user_inter_num', 'max_item_inter_num', 'min_item_inter_num', 'lowest_val', 'highest_val', 'equal_val', 'not_equal_val', - 'fields_in_same_space', + 'alias_of_user_id', 'alias_of_item_id', 'alias_of_entity_id', 'preload_weight', 'normalize_field', 'normalize_all' ] diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index fa4eb1af6..223519134 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -346,7 +346,6 @@ def test_remap_id(self): 'dataset': 'remap_id', 'data_path': current_path, 'load_col': None, - 'fields_in_same_space': None, } dataset = new_dataset(config_dict=config_dict) user_list = dataset.token2id('user_id', ['ua', 'ub', 'uc', 'ud']) @@ -362,16 +361,14 @@ def test_remap_id(self): assert (dataset.inter_feat['user_list'][2] == [3, 4, 1]).all() assert (dataset.inter_feat['user_list'][3] == [5]).all() - def test_remap_id_with_fields_in_same_space(self): + def test_remap_id_with_alias(self): config_dict = { 'model': 'BPR', 'dataset': 'remap_id', 'data_path': current_path, 'load_col': None, - 'fields_in_same_space': [ - ['user_id', 'add_user', 'user_list'], - ['item_id', 'add_item'], - ], + 'alias_of_user_id': ['add_user', 'user_list'], + 'alias_of_item_id': ['add_item'], } dataset = new_dataset(config_dict=config_dict) user_list = dataset.token2id('user_id', ['ua', 'ub', 'uc', 'ud', 'ue', 'uf']) diff --git a/tests/model/test_model.yaml b/tests/model/test_model.yaml index 4604e8a65..3f366fe2c 100644 --- a/tests/model/test_model.yaml +++ b/tests/model/test_model.yaml @@ -49,7 +49,9 @@ equal_val: ~ not_equal_val: ~ # Preprocessing -fields_in_same_space: ~ +alias_of_user_id: ~ +alias_of_item_id: ~ +alias_of_entity_id: ~ preload_weight: ~ normalize_field: ~ normalize_all: True From 289061accb8b04dc8274f6ca37f917d165e97eb3 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Mon, 5 Jul 2021 23:10:49 +0800 Subject: [PATCH 2/3] REFACTOR: add doc and unittest to `remap_all` --- recbole/data/dataset/dataset.py 
| 52 +++++---- recbole/data/dataset/kg_dataset.py | 128 ++++++++++------------- tests/data/kg_remap_id/kg_remap_id.inter | 5 + tests/data/kg_remap_id/kg_remap_id.kg | 5 + tests/data/kg_remap_id/kg_remap_id.link | 5 + tests/data/test_dataset.py | 20 ++++ 6 files changed, 122 insertions(+), 93 deletions(-) create mode 100644 tests/data/kg_remap_id/kg_remap_id.inter create mode 100644 tests/data/kg_remap_id/kg_remap_id.kg create mode 100644 tests/data/kg_remap_id/kg_remap_id.link diff --git a/recbole/data/dataset/dataset.py b/recbole/data/dataset/dataset.py index ffe4f7fac..edd057047 100644 --- a/recbole/data/dataset/dataset.py +++ b/recbole/data/dataset/dataset.py @@ -443,26 +443,35 @@ def _load_feat(self, filepath, source): return df def _get_alias(self): + """Set :attr:`alias_of_user_id` and :attr:`alias_of_item_id`. + """ self.alias_of_user_id = self.config['alias_of_user_id'] or [] - self.alias_of_user_id.append(self.uid_field) - self.alias_of_user_id = set(self.alias_of_user_id) + self.alias_of_user_id = np.array([self.uid_field] + self.alias_of_user_id) + _, idx = np.unique(self.alias_of_user_id, return_index=True) + self.alias_of_user_id = self.alias_of_user_id[np.sort(idx)] + self.alias_of_item_id = self.config['alias_of_item_id'] or [] - self.alias_of_item_id.append(self.iid_field) - self.alias_of_item_id = set(self.alias_of_item_id) + self.alias_of_item_id = np.array([self.iid_field] + self.alias_of_item_id) + _, idx = np.unique(self.alias_of_item_id, return_index=True) + self.alias_of_item_id = self.alias_of_item_id[np.sort(idx)] - if self.alias_of_user_id & self.alias_of_item_id: + intersect = np.intersect1d(self.alias_of_user_id, self.alias_of_item_id, assume_unique=True) + if len(intersect) > 0: raise ValueError(f'`alias_of_user_id` and `alias_of_item_id` ' - f'should not have the same field {list(self.alias_of_user_id & self.alias_of_item_id)}.') + f'should not have the same field {list(intersect)}.') - token_like_fields = set(self.token_like_fields) - 
if self.alias_of_user_id - token_like_fields: + token_like_fields = self.token_like_fields + user_isin = np.isin(self.alias_of_user_id, token_like_fields, assume_unique=True) + if user_isin.all() is False: raise ValueError(f'`alias_of_user_id` should not contain ' - f'non-token-like field {list(self.alias_of_user_id - token_like_fields)}.') - if self.alias_of_item_id - token_like_fields: + f'non-token-like field {list(self.alias_of_user_id[~user_isin])}.') + item_isin = np.isin(self.alias_of_item_id, token_like_fields, assume_unique=True) + if item_isin.all() is False: raise ValueError(f'`alias_of_item_id` should not contain ' - f'non-token-like field {list(self.alias_of_item_id - token_like_fields)}.') + f'non-token-like field {list(self.alias_of_item_id[~item_isin])}.') - self._rest_fields = token_like_fields - self.alias_of_user_id - self.alias_of_item_id + self._rest_fields = np.setdiff1d(token_like_fields, self.alias_of_user_id, assume_unique=True) + self._rest_fields = np.setdiff1d(self._rest_fields, self.alias_of_item_id, assume_unique=True) def _user_item_feat_preparation(self): """Sort :attr:`user_feat` and :attr:`item_feat` by ``user_id`` or ``item_id``. @@ -883,7 +892,7 @@ def _set_label_by_threshold(self): raise ValueError(f'Field [{field}] not in inter_feat.') self._del_col(self.inter_feat, field) - def _get_remap_list(self, field_set): + def _get_remap_list(self, field_list): """Transfer set of fields in the same remapping space into remap list. If ``uid_field`` or ``iid_field`` in ``field_set``, @@ -891,7 +900,7 @@ def _get_remap_list(self, field_set): then field in :attr:`user_feat` or :attr:`item_feat` will be remapped next, finally others. Args: - field_set (set): Set of fields in the same remapping space + field_list (numpy.ndarray): List of fields in the same remapping space. 
Returns: list: @@ -903,7 +912,6 @@ def _get_remap_list(self, field_set): """ remap_list = [] - field_list = sorted(field_set, key=lambda f: f != self.uid_field and f != self.iid_field) for field in field_list: ftype = self.field2type[field] for feat in self.field2feats(field): @@ -911,23 +919,25 @@ def _get_remap_list(self, field_set): return remap_list def _remap_ID_all(self): + """Remap all token-like fields. """ - """ - self._remap_user() - self._remap_item() + self._remap_alias() self._remap_rest() - def _remap_user(self): + def _remap_alias(self): + """Remap :attr:`alias_of_user_id` and :attr:`alias_of_item_id`. + """ user_remap_list = self._get_remap_list(self.alias_of_user_id) self._remap(user_remap_list) - def _remap_item(self): item_remap_list = self._get_remap_list(self.alias_of_item_id) self._remap(item_remap_list) def _remap_rest(self): + """Remap other token-like fields. + """ for field in self._rest_fields: - remap_list = self._get_remap_list({field}) + remap_list = self._get_remap_list(np.array([field])) self._remap(remap_list) def _concat_remaped_tokens(self, remap_list): diff --git a/recbole/data/dataset/kg_dataset.py b/recbole/data/dataset/kg_dataset.py index e1ba28ab0..5298dc608 100644 --- a/recbole/data/dataset/kg_dataset.py +++ b/recbole/data/dataset/kg_dataset.py @@ -191,22 +191,27 @@ def _get_alias(self): """ super()._get_alias() self.alias_of_entity_id = self.config['alias_of_entity_id'] or [] - self.alias_of_entity_id.extend([self.head_entity_field, self.tail_entity_field]) - self.alias_of_entity_id = set(self.alias_of_entity_id) + self.alias_of_entity_id = np.array([self.head_entity_field, self.tail_entity_field] + self.alias_of_entity_id) + _, idx = np.unique(self.alias_of_entity_id, return_index=True) + self.alias_of_entity_id = self.alias_of_entity_id[np.sort(idx)] - if self.alias_of_entity_id & self.alias_of_user_id: + intersect = np.intersect1d(self.alias_of_entity_id, self.alias_of_user_id, assume_unique=True) + if len(intersect) > 
0: raise ValueError(f'`alias_of_entity_id` and `alias_of_user_id` ' - f'should not have the same field {list(self.alias_of_entity_id & self.alias_of_user_id)}.') - if self.alias_of_entity_id & self.alias_of_item_id: + f'should not have the same field {list(intersect)}.') + intersect = np.intersect1d(self.alias_of_entity_id, self.alias_of_item_id, assume_unique=True) + if len(intersect) > 0: raise ValueError(f'`alias_of_entity_id` and `alias_of_item_id` ' - f'should not have the same field {list(self.alias_of_entity_id & self.alias_of_item_id)}.') + f'should not have the same field {list(intersect)}.') - token_like_fields = set(self.token_like_fields) - if self.alias_of_entity_id - token_like_fields: + token_like_fields = self.token_like_fields + entity_isin = np.isin(self.alias_of_entity_id, token_like_fields, assume_unique=True) + if entity_isin.all() is False: raise ValueError(f'`alias_of_entity_id` should not contain ' - f'non-token-like field {list(self.alias_of_entity_id - token_like_fields)}.') + f'non-token-like field {list(self.alias_of_entity_id[~entity_isin])}.') - self._rest_fields = self._rest_fields - self.alias_of_entity_id - {self.entity_field} + self._rest_fields = np.setdiff1d(self._rest_fields, self.alias_of_entity_id, assume_unique=True) + self._rest_fields = np.setdiff1d(self._rest_fields, [self.entity_field], assume_unique=True) def _get_rec_item_token(self): """Get set of entity tokens from fields in ``rec`` level. @@ -222,29 +227,9 @@ def _get_entity_token(self): tokens, _ = self._concat_remaped_tokens(remap_list) return set(tokens) - def _remap_entities_by_link(self): - """Map entity tokens from fields in ``ent`` level - to item tokens according to ``.link``. 
- """ - for ent_field in self.alias_of_entity_id: - feat = self.field2feats(ent_field)[0] - ftype = self.field2type[ent_field] - if ftype == FeatureType.TOKEN: - entity_list = feat[ent_field].values - else: - entity_list = feat[ent_field].agg(np.concatenate) - - for i, entity_id in enumerate(entity_list): - if entity_id in self.entity2item: - entity_list[i] = self.entity2item[entity_id] - - if ftype == FeatureType.TOKEN: - feat[ent_field] = entity_list - else: - split_point = np.cumsum(feat[ent_field].agg(len))[:-1] - feat[ent_field] = np.split(entity_list, split_point) - - def _reset_ent_remapID(self, field, idmap): + def _reset_ent_remapID(self, field, idmap, id2token, token2id): + self.field2id_token[field] = id2token + self.field2token_id[field] = token2id for feat in self.field2feats(field): ftype = self.field2type[field] if ftype == FeatureType.TOKEN: @@ -260,55 +245,54 @@ def _reset_ent_remapID(self, field, idmap): split_point = np.cumsum(feat[field].agg(len))[:-1] feat[field] = np.split(new_idx, split_point) - def _sort_remaped_entities(self, item_tokens): - level_list = [] - for token in self.field2id_token[self.iid_field]: - if token == '[PAD]': - level_list.append(0) - elif token in item_tokens and token not in self.item2entity: - level_list.append(1) - elif token in self.item2entity or token in self.entity2item: - level_list.append(2) - else: - level_list.append(3) - level_list = np.array(level_list) - order_list = np.argsort(level_list) - token_list = self.field2id_token[self.iid_field][order_list] - idmap = np.zeros_like(order_list) - idmap[order_list] = np.arange(len(order_list)) - item_num = np.sum(level_list < 3) - - new_id_token = token_list[:item_num] - new_token_id = {t: i for i, t in enumerate(new_id_token)} + def _merge_item_and_entity(self): + """Merge item-id and entity-id into the same id-space. 
+ """ + item_token = self.field2id_token[self.iid_field] + entity_token = self.field2id_token[self.head_entity_field] + item_num = len(item_token) + link_num = len(self.item2entity) + entity_num = len(entity_token) + + # reset item id + item_priority = np.array([token in self.item2entity for token in item_token]) + item_order = np.argsort(item_priority, kind='stable') + item_id_map = np.zeros_like(item_order) + item_id_map[item_order] = np.arange(item_num) + new_item_id2token = item_token[item_order] + new_item_token2id = {t: i for i, t in enumerate(new_item_id2token)} for field in self.alias_of_item_id: - self._reset_ent_remapID(field, idmap) - self.field2id_token[field] = new_id_token - self.field2token_id[field] = new_token_id - - new_id_token = np.array([self.item2entity[_] if _ in self.item2entity else _ for _ in token_list]) - new_token_id = {t: i for i, t in enumerate(new_id_token)} + self._reset_ent_remapID(field, item_id_map, new_item_id2token, new_item_token2id) + + # reset entity id + entity_priority = np.array([token != '[PAD]' and token not in self.entity2item for token in entity_token]) + entity_order = np.argsort(entity_priority, kind='stable') + entity_id_map = np.zeros_like(entity_order) + for i in entity_order[1:link_num + 1]: + entity_id_map[i] = new_item_token2id[self.entity2item[entity_token[i]]] + entity_id_map[entity_order[link_num + 1:]] = np.arange(item_num, item_num + entity_num - link_num - 1) + new_entity_id2token = np.concatenate([new_item_id2token, entity_token[entity_order[link_num + 1:]]]) + for i in range(item_num - link_num, item_num): + new_entity_id2token[i] = self.item2entity[new_entity_id2token[i]] + new_entity_token2id = {t: i for i, t in enumerate(new_entity_id2token)} for field in self.alias_of_entity_id: - self._reset_ent_remapID(field, idmap) - self.field2id_token[field] = new_id_token - self.field2token_id[field] = new_token_id - self.field2id_token[self.entity_field] = new_id_token - 
self.field2token_id[self.entity_field] = new_token_id + self._reset_ent_remapID(field, entity_id_map, new_entity_id2token, new_entity_token2id) + self.field2id_token[self.entity_field] = new_entity_id2token + self.field2token_id[self.entity_field] = new_entity_token2id def _remap_ID_all(self): super()._remap_ID_all() self.field2token_id[self.relation_field]['[UI-Relation]'] = len(self.field2id_token[self.relation_field]) self.field2id_token[self.relation_field] = np.append(self.field2id_token[self.relation_field], '[UI-Relation]') - def _remap_item(self): + def _remap_alias(self): + """Remap :attr:`alias_of_entity_id` additionally. """ - """ - self._remap_entities_by_link() - item_tokens = self._get_rec_item_token() - - item_remap_list = self._get_remap_list(self.alias_of_item_id | self.alias_of_entity_id) - self._remap(item_remap_list) + super()._remap_alias() + entity_remap_list = self._get_remap_list(self.alias_of_entity_id) + self._remap(entity_remap_list) - self._sort_remaped_entities(item_tokens) + self._merge_item_and_entity() @property def relation_num(self): diff --git a/tests/data/kg_remap_id/kg_remap_id.inter b/tests/data/kg_remap_id/kg_remap_id.inter new file mode 100644 index 000000000..f8ab8a690 --- /dev/null +++ b/tests/data/kg_remap_id/kg_remap_id.inter @@ -0,0 +1,5 @@ +user_id:token item_id:token +ua ia +ub ib +uc ic +ud id \ No newline at end of file diff --git a/tests/data/kg_remap_id/kg_remap_id.kg b/tests/data/kg_remap_id/kg_remap_id.kg new file mode 100644 index 000000000..95f6d4189 --- /dev/null +++ b/tests/data/kg_remap_id/kg_remap_id.kg @@ -0,0 +1,5 @@ +head_id:token tail_id:token relation_id:token +eb ea ra +ec eb rb +ed ec rc +ee ed rd \ No newline at end of file diff --git a/tests/data/kg_remap_id/kg_remap_id.link b/tests/data/kg_remap_id/kg_remap_id.link new file mode 100644 index 000000000..3328eccb5 --- /dev/null +++ b/tests/data/kg_remap_id/kg_remap_id.link @@ -0,0 +1,5 @@ +item_id:token entity_id:token +ib eb +ic ec +id ed +ie ee \ 
No newline at end of file diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 223519134..433ebf3a0 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -611,5 +611,25 @@ def test_seq_leave_one_out(self): ]).all() +class TestKGDataset: + def test_kg_remap_id(self): + config_dict = { + 'model': 'KGAT', + 'dataset': 'kg_remap_id', + 'data_path': current_path, + 'load_col': None, + } + dataset = new_dataset(config_dict=config_dict) + print(dataset.field2id_token['entity_id']) + item_list = dataset.token2id('item_id', ['ia', 'ib', 'ic', 'id']) + entity_list = dataset.token2id('entity_id', ['eb', 'ec', 'ed', 'ee', 'ea']) + assert (item_list == [1, 2, 3, 4]).all() + assert (entity_list == [2, 3, 4, 5, 6]).all() + assert (dataset.inter_feat['user_id'] == [1, 2, 3, 4]).all() + assert (dataset.inter_feat['item_id'] == [1, 2, 3, 4]).all() + assert (dataset.kg_feat['head_id'] == [2, 3, 4, 5]).all() + assert (dataset.kg_feat['tail_id'] == [6, 2, 3, 4]).all() + + if __name__ == "__main__": pytest.main() From 5fb7dbbb2e2b502291dcd3408dbe6a6c33965da6 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 6 Jul 2021 18:02:52 +0800 Subject: [PATCH 3/3] REFACTOR: add `alias_of_relation_id` and update doc. 
--- docs/source/user_guide/data/atomic_files.rst | 2 +- docs/source/user_guide/data/data_args.rst | 5 +- .../user_guide/model/sequential/gru4reckg.rst | 4 +- .../user_guide/model/sequential/ksr.rst | 6 +- .../usage/load_pretrained_embedding.rst | 8 +- recbole/data/dataset/dataset.py | 77 ++++++++----------- recbole/data/dataset/kg_dataset.py | 49 +++--------- recbole/properties/dataset/ml-100k.yaml | 1 + recbole/properties/dataset/sample.yaml | 1 + recbole/utils/argument_list.py | 2 +- tests/model/test_model.yaml | 1 + 11 files changed, 60 insertions(+), 96 deletions(-) diff --git a/docs/source/user_guide/data/atomic_files.rst b/docs/source/user_guide/data/atomic_files.rst index 9e74dd4d7..ebac34ce3 100644 --- a/docs/source/user_guide/data/atomic_files.rst +++ b/docs/source/user_guide/data/atomic_files.rst @@ -141,4 +141,4 @@ For example, if you want to map the tokens of ``ent_id`` into the same space of # inter/user/item/...: As usual ent: [ent_id, ent_emb] - fields_in_same_space: [[ent_id, entity_id]] + alias_of_entity_id: [ent_id] diff --git a/docs/source/user_guide/data/data_args.rst b/docs/source/user_guide/data/data_args.rst index 9c310f443..97ce58835 100644 --- a/docs/source/user_guide/data/data_args.rst +++ b/docs/source/user_guide/data/data_args.rst @@ -88,7 +88,10 @@ Filter by number of interactions Preprocessing ----------------- -- ``fields_in_same_space (list)`` : List of spaces. Space is a list of string similar to the fields' names. The fields in the same space will be remapped into the same index system. Note that if you want to make some fields remapped in the same space with entities, then just set ``fields_in_same_space = [entity_id, xxx, ...]``. (if ``ENTITY_ID_FIELD != 'entity_id'``, then change the ``'entity_id'`` in the above example.) Defaults to ``None``. +- ``alias_of_user_id (list)``: List of fields' names, which will be remapped into the same index system with ``USER_ID_FIELD``. Defaults to ``None``. 
+- ``alias_of_item_id (list)``: List of fields' names, which will be remapped into the same index system with ``ITEM_ID_FIELD``. Defaults to ``None``. +- ``alias_of_entity_id (list)``: List of fields' names, which will be remapped into the same index system with ``ENTITY_ID_FIELD``, ``HEAD_ENTITY_ID_FIELD`` and ``TAIL_ENTITY_ID_FIELD``. Defaults to ``None``. +- ``alias_of_relation_id (list)``: List of fields' names, which will be remapped into the same index system with ``RELATION_ID_FIELD``. Defaults to ``None``. - ``preload_weight (dict)`` : Has the format ``{k (str): v (float)}, ...``. ``k`` is a token field, representing the IDs of each row of preloaded weight matrix. ``v`` is a float-like field. Each pair of ``k`` and ``v`` should be from the same atomic file. This arg can be used to load pretrained vectors. Defaults to ``None``. - ``normalize_field (list)`` : List of field names to be normalized. Note that only float like fields can be normalized. Defaults to ``None``. - ``normalize_all (bool)`` : Normalize all the float like fields if ``True``. Defaults to ``True``. 
diff --git a/docs/source/user_guide/model/sequential/gru4reckg.rst b/docs/source/user_guide/model/sequential/gru4reckg.rst index f1653cabf..69dacdaa4 100644 --- a/docs/source/user_guide/model/sequential/gru4reckg.rst +++ b/docs/source/user_guide/model/sequential/gru4reckg.rst @@ -46,9 +46,7 @@ And then: kg: [head_id, relation_id, tail_id] link: [item_id, entity_id] ent_feature: [ent_id, ent_vec] - fields_in_same_space: [ - [ent_id, entity_id] - ] + alias_of_entity_id: [ent_id] preload_weight: ent_id: ent_vec additional_feat_suffix: [ent_feature] diff --git a/docs/source/user_guide/model/sequential/ksr.rst b/docs/source/user_guide/model/sequential/ksr.rst index 8085e71b4..ff24e2184 100644 --- a/docs/source/user_guide/model/sequential/ksr.rst +++ b/docs/source/user_guide/model/sequential/ksr.rst @@ -58,10 +58,8 @@ And then: link: [item_id, entity_id] ent_feature: [ent_id, ent_vec] rel_feature: [rel_id, rel_vec] - fields_in_same_space: [ - [ent_id, entity_id] - [rel_id, relation_id] - ] + alias_of_entity_id: [ent_id] + alias_of_relation_id: [rel_id] preload_weight: ent_id: ent_vec rel_id: rel_vec diff --git a/docs/source/user_guide/usage/load_pretrained_embedding.rst b/docs/source/user_guide/usage/load_pretrained_embedding.rst index c8f3c010a..2e50a31ed 100644 --- a/docs/source/user_guide/usage/load_pretrained_embedding.rst +++ b/docs/source/user_guide/usage/load_pretrained_embedding.rst @@ -22,9 +22,9 @@ Secondly, update the args as (suppose that ``USER_ID_FIELD: user_id``): load_col: # inter/user/item/...: As usual useremb: [uid, user_emb] - fields_in_same_space: [[uid, user_id]] + alias_of_user_id: [uid] preload_weight: - uid: user_emb + uid: user_emb Then, this additional embedding feature file will be loaded into the :class:`Dataset` object. 
These new features can be accessed as following: @@ -39,6 +39,6 @@ In your model, user embedding matrix can be initialized by your pre-trained embe class YourModel(GeneralRecommender): def __init__(self, config, dataset): - pretrained_user_emb = dataset.get_preload_weight('uid') - self.user_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_user_emb)) + pretrained_user_emb = dataset.get_preload_weight('uid') + self.user_embedding = nn.Embedding.from_pretrained(torch.from_numpy(pretrained_user_emb)) diff --git a/recbole/data/dataset/dataset.py b/recbole/data/dataset/dataset.py index fe2d99a35..d5ea9f6d6 100644 --- a/recbole/data/dataset/dataset.py +++ b/recbole/data/dataset/dataset.py @@ -105,7 +105,7 @@ def _from_scratch(self): self._get_preset() self._get_field_from_config() self._load_data(self.dataset_name, self.dataset_path) - self._get_alias() + self._init_alias() self._data_processing() def _get_preset(self): @@ -118,6 +118,7 @@ def _get_preset(self): self.field2id_token = {} self.field2token_id = {} self.field2seqlen = self.config['seq_len'] or {} + self.alias = {} self._preloaded_weight = {} self.benchmark_filename_list = self.config['benchmark_filename'] @@ -442,36 +443,33 @@ def _load_feat(self, filepath, source): self.field2seqlen[field] = max(map(len, df[field].values)) return df - def _get_alias(self): - """Set :attr:`alias_of_user_id` and :attr:`alias_of_item_id`. 
- """ - self.alias_of_user_id = self.config['alias_of_user_id'] or [] - self.alias_of_user_id = np.array([self.uid_field] + self.alias_of_user_id) - _, idx = np.unique(self.alias_of_user_id, return_index=True) - self.alias_of_user_id = self.alias_of_user_id[np.sort(idx)] - - self.alias_of_item_id = self.config['alias_of_item_id'] or [] - self.alias_of_item_id = np.array([self.iid_field] + self.alias_of_item_id) - _, idx = np.unique(self.alias_of_item_id, return_index=True) - self.alias_of_item_id = self.alias_of_item_id[np.sort(idx)] - - intersect = np.intersect1d(self.alias_of_user_id, self.alias_of_item_id, assume_unique=True) - if len(intersect) > 0: - raise ValueError(f'`alias_of_user_id` and `alias_of_item_id` ' - f'should not have the same field {list(intersect)}.') - - token_like_fields = self.token_like_fields - user_isin = np.isin(self.alias_of_user_id, token_like_fields, assume_unique=True) - if user_isin.all() is False: - raise ValueError(f'`alias_of_user_id` should not contain ' - f'non-token-like field {list(self.alias_of_user_id[~user_isin])}.') - item_isin = np.isin(self.alias_of_item_id, token_like_fields, assume_unique=True) - if item_isin.all() is False: - raise ValueError(f'`alias_of_item_id` should not contain ' - f'non-token-like field {list(self.alias_of_item_id[~item_isin])}.') - - self._rest_fields = np.setdiff1d(token_like_fields, self.alias_of_user_id, assume_unique=True) - self._rest_fields = np.setdiff1d(self._rest_fields, self.alias_of_item_id, assume_unique=True) + def _set_alias(self, alias_name, default_value): + alias = self.config[f'alias_of_{alias_name}'] or [] + alias = np.array(default_value + alias) + _, idx = np.unique(alias, return_index=True) + self.alias[alias_name] = alias[np.sort(idx)] + + def _init_alias(self): + """Set :attr:`alias_of_user_id` and :attr:`alias_of_item_id`. And set :attr:`_rest_fields`. 
+ """ + self._set_alias('user_id', [self.uid_field]) + self._set_alias('item_id', [self.iid_field]) + + for alias_name_1, alias_1 in self.alias.items(): + for alias_name_2, alias_2 in self.alias.items(): + if alias_name_1 != alias_name_2: + intersect = np.intersect1d(alias_1, alias_2, assume_unique=True) + if len(intersect) > 0: + raise ValueError(f'`alias_of_{alias_name_1}` and `alias_of_{alias_name_2}` ' + f'should not have the same field {list(intersect)}.') + + self._rest_fields = self.token_like_fields + for alias_name, alias in self.alias.items(): + isin = np.isin(alias, self._rest_fields, assume_unique=True) + if isin.all() is False: + raise ValueError(f'`alias_of_{alias_name}` should not contain ' + f'non-token-like field {list(alias[~isin])}.') + self._rest_fields = np.setdiff1d(self._rest_fields, alias, assume_unique=True) def _user_item_feat_preparation(self): """Sort :attr:`user_feat` and :attr:`item_feat` by ``user_id`` or ``item_id``. @@ -921,21 +919,10 @@ def _get_remap_list(self, field_list): def _remap_ID_all(self): """Remap all token-like fields. """ - self._remap_alias() - self._remap_rest() - - def _remap_alias(self): - """Remap :attr:`alias_of_user_id` and :attr:`alias_of_item_id`. - """ - user_remap_list = self._get_remap_list(self.alias_of_user_id) - self._remap(user_remap_list) - - item_remap_list = self._get_remap_list(self.alias_of_item_id) - self._remap(item_remap_list) + for alias in self.alias.values(): + remap_list = self._get_remap_list(alias) + self._remap(remap_list) - def _remap_rest(self): - """Remap other token-like fields. 
- """ for field in self._rest_fields: remap_list = self._get_remap_list(np.array([field])) self._remap(remap_list) diff --git a/recbole/data/dataset/kg_dataset.py b/recbole/data/dataset/kg_dataset.py index 5298dc608..762141dd7 100644 --- a/recbole/data/dataset/kg_dataset.py +++ b/recbole/data/dataset/kg_dataset.py @@ -186,44 +186,27 @@ def _check_link(self, link): assert self.entity_field in link, link_warn_message.format(self.entity_field) assert self.iid_field in link, link_warn_message.format(self.iid_field) - def _get_alias(self): - """Add :attr:`alias_of_entity_id` and update :attr:`_rest_fields` + def _init_alias(self): + """Add :attr:`alias_of_entity_id`, :attr:`alias_of_relation_id` and update :attr:`_rest_fields`. """ - super()._get_alias() - self.alias_of_entity_id = self.config['alias_of_entity_id'] or [] - self.alias_of_entity_id = np.array([self.head_entity_field, self.tail_entity_field] + self.alias_of_entity_id) - _, idx = np.unique(self.alias_of_entity_id, return_index=True) - self.alias_of_entity_id = self.alias_of_entity_id[np.sort(idx)] - - intersect = np.intersect1d(self.alias_of_entity_id, self.alias_of_user_id, assume_unique=True) - if len(intersect) > 0: - raise ValueError(f'`alias_of_entity_id` and `alias_of_user_id` ' - f'should not have the same field {list(intersect)}.') - intersect = np.intersect1d(self.alias_of_entity_id, self.alias_of_item_id, assume_unique=True) - if len(intersect) > 0: - raise ValueError(f'`alias_of_entity_id` and `alias_of_item_id` ' - f'should not have the same field {list(intersect)}.') - - token_like_fields = self.token_like_fields - entity_isin = np.isin(self.alias_of_entity_id, token_like_fields, assume_unique=True) - if entity_isin.all() is False: - raise ValueError(f'`alias_of_entity_id` should not contain ' - f'non-token-like field {list(self.alias_of_entity_id[~entity_isin])}.') - - self._rest_fields = np.setdiff1d(self._rest_fields, self.alias_of_entity_id, assume_unique=True) + 
self._set_alias('entity_id', [self.head_entity_field, self.tail_entity_field]) + self._set_alias('relation_id', [self.relation_field]) + + super()._init_alias() + self._rest_fields = np.setdiff1d(self._rest_fields, [self.entity_field], assume_unique=True) def _get_rec_item_token(self): """Get set of entity tokens from fields in ``rec`` level. """ - remap_list = self._get_remap_list(self.alias_of_item_id) + remap_list = self._get_remap_list(self.alias['item_id']) tokens, _ = self._concat_remaped_tokens(remap_list) return set(tokens) def _get_entity_token(self): """Get set of entity tokens from fields in ``ent`` level. """ - remap_list = self._get_remap_list(self.alias_of_entity_id) + remap_list = self._get_remap_list(self.alias['entity_id']) tokens, _ = self._concat_remaped_tokens(remap_list) return set(tokens) @@ -261,7 +244,7 @@ def _merge_item_and_entity(self): item_id_map[item_order] = np.arange(item_num) new_item_id2token = item_token[item_order] new_item_token2id = {t: i for i, t in enumerate(new_item_id2token)} - for field in self.alias_of_item_id: + for field in self.alias['item_id']: self._reset_ent_remapID(field, item_id_map, new_item_id2token, new_item_token2id) # reset entity id @@ -275,25 +258,17 @@ def _merge_item_and_entity(self): for i in range(item_num - link_num, item_num): new_entity_id2token[i] = self.item2entity[new_entity_id2token[i]] new_entity_token2id = {t: i for i, t in enumerate(new_entity_id2token)} - for field in self.alias_of_entity_id: + for field in self.alias['entity_id']: self._reset_ent_remapID(field, entity_id_map, new_entity_id2token, new_entity_token2id) self.field2id_token[self.entity_field] = new_entity_id2token self.field2token_id[self.entity_field] = new_entity_token2id def _remap_ID_all(self): super()._remap_ID_all() + self._merge_item_and_entity() self.field2token_id[self.relation_field]['[UI-Relation]'] = len(self.field2id_token[self.relation_field]) self.field2id_token[self.relation_field] = 
np.append(self.field2id_token[self.relation_field], '[UI-Relation]') - def _remap_alias(self): - """Remap :attr:`alias_of_entity_id` additionally. - """ - super()._remap_alias() - entity_remap_list = self._get_remap_list(self.alias_of_entity_id) - self._remap(entity_remap_list) - - self._merge_item_and_entity() - @property def relation_num(self): """Get the number of different tokens of ``self.relation_field``. diff --git a/recbole/properties/dataset/ml-100k.yaml b/recbole/properties/dataset/ml-100k.yaml index 796bb5e24..7c229f814 100644 --- a/recbole/properties/dataset/ml-100k.yaml +++ b/recbole/properties/dataset/ml-100k.yaml @@ -40,6 +40,7 @@ item_inter_num_interval: ~ alias_of_user_id: ~ alias_of_item_id: ~ alias_of_entity_id: ~ +alias_of_relation_id: ~ preload_weight: ~ normalize_field: ~ normalize_all: True diff --git a/recbole/properties/dataset/sample.yaml b/recbole/properties/dataset/sample.yaml index 156942110..181d014d1 100644 --- a/recbole/properties/dataset/sample.yaml +++ b/recbole/properties/dataset/sample.yaml @@ -40,6 +40,7 @@ min_item_inter_num: 0 alias_of_user_id: ~ alias_of_item_id: ~ alias_of_entity_id: ~ +alias_of_relation_id: ~ preload_weight: ~ normalize_field: ~ normalize_all: ~ diff --git a/recbole/utils/argument_list.py b/recbole/utils/argument_list.py index 43e8db305..fc5f1ae1f 100644 --- a/recbole/utils/argument_list.py +++ b/recbole/utils/argument_list.py @@ -47,7 +47,7 @@ 'load_col', 'unload_col', 'unused_col', 'additional_feat_suffix', 'max_user_inter_num', 'min_user_inter_num', 'max_item_inter_num', 'min_item_inter_num', 'lowest_val', 'highest_val', 'equal_val', 'not_equal_val', - 'alias_of_user_id', 'alias_of_item_id', 'alias_of_entity_id', + 'alias_of_user_id', 'alias_of_item_id', 'alias_of_entity_id', 'alias_of_relation_id', 'preload_weight', 'normalize_field', 'normalize_all' ] diff --git a/tests/model/test_model.yaml b/tests/model/test_model.yaml index 3f366fe2c..a32220742 100644 --- a/tests/model/test_model.yaml +++ 
b/tests/model/test_model.yaml @@ -52,6 +52,7 @@ not_equal_val: ~ alias_of_user_id: ~ alias_of_item_id: ~ alias_of_entity_id: ~ +alias_of_relation_id: ~ preload_weight: ~ normalize_field: ~ normalize_all: True