From b799291b961ba7861512d847a283f2becd8f5e2b Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Fri, 7 May 2021 18:12:48 +0800 Subject: [PATCH] FIX: fix the bug of len(kg_dataset.field2id_token['item_id']) != len(kg_dataset.field2id_token['item_id']). (issue #823) --- recbole/data/dataset/kg_dataset.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/recbole/data/dataset/kg_dataset.py b/recbole/data/dataset/kg_dataset.py index 401768f0d..38b30387c 100644 --- a/recbole/data/dataset/kg_dataset.py +++ b/recbole/data/dataset/kg_dataset.py @@ -336,15 +336,20 @@ def _sort_remaped_entities(self, item_tokens): layered_num.append(order_cnt[i]) layered_num = np.cumsum(np.array(layered_num)) new_id_token = item_ent_token_list[:layered_num[-2]] + new_token_id = {t: i for i, t in enumerate(new_id_token)} for field in self.rec_level_ent_fields: self._reset_ent_remapID(field, new_id_token) self.field2id_token[field] = new_id_token + self.field2token_id[field] = new_token_id new_id_token = item_ent_token_list[:layered_num[-1]] new_id_token = [self.item2entity[_] if _ in self.item2entity else _ for _ in new_id_token] + new_token_id = {t: i for i, t in enumerate(new_id_token)} for field in self.ent_level_ent_fields: self._reset_ent_remapID(field, item_ent_token_list[:layered_num[-1]]) self.field2id_token[field] = new_id_token - self.field2id_token[self.entity_field] = item_ent_token_list[:layered_num[-1]] + self.field2token_id[field] = new_token_id + self.field2id_token[self.entity_field] = new_id_token + self.field2token_id[self.entity_field] = new_token_id def _remap_ID_all(self): """Firstly, remap entities and items all together. Then sort entity tokens, @@ -354,6 +359,7 @@ def _remap_ID_all(self): item_tokens = self._get_rec_item_token() super()._remap_ID_all() self._sort_remaped_entities(item_tokens) + self.field2token_id[self.relation_field]['[UI-Relation]'] = len(self.field2id_token[self.relation_field]) self.field2id_token[self.relation_field] = np.append(self.field2id_token[self.relation_field], '[UI-Relation]') @property