From 974f27ecacfe7abeb33b13f57283ff8336497859 Mon Sep 17 00:00:00 2001 From: zhengbw0324 <18735382001@163.com> Date: Sat, 20 Aug 2022 17:27:59 +0800 Subject: [PATCH 1/6] Modified the interaction history matrix generation function to support setting the maximum number of history interaction records. --- recbole/data/dataset/dataset.py | 39 ++++++++++++++------ recbole/model/general_recommender/simplex.py | 30 ++++++++------- 2 files changed, 44 insertions(+), 25 deletions(-) diff --git a/recbole/data/dataset/dataset.py b/recbole/data/dataset/dataset.py index 6bf88ce17..580a14c20 100644 --- a/recbole/data/dataset/dataset.py +++ b/recbole/data/dataset/dataset.py @@ -1954,7 +1954,7 @@ def inter_matrix(self, form="coo", value_field=None): self.inter_feat, self.uid_field, self.iid_field, form, value_field ) - def _history_matrix(self, row, value_field=None): + def _history_matrix(self, row, value_field=None ,max_history_len=None): """Get dense matrix describe user/item's history interaction records. ``history_matrix[i]`` represents ``i``'s history interacted item_id. @@ -1970,6 +1970,8 @@ def _history_matrix(self, row, value_field=None): row (str): ``user`` or ``item``. value_field (str, optional): Data of matrix, which should exist in ``self.inter_feat``. Defaults to ``None``. + max_history_len (int): The maximum number of history interaction records. + Defaults to ``None``. Returns: tuple: @@ -1979,18 +1981,20 @@ def _history_matrix(self, row, value_field=None): """ self._check_field("uid_field", "iid_field") + inter_feat= copy.deepcopy(self.inter_feat) + inter_feat.shuffle() user_ids, item_ids = ( - self.inter_feat[self.uid_field].numpy(), - self.inter_feat[self.iid_field].numpy(), + inter_feat[self.uid_field].numpy(), + inter_feat[self.iid_field].numpy(), ) if value_field is None: - values = np.ones(len(self.inter_feat)) + values = np.ones(len(inter_feat)) else: - if value_field not in self.inter_feat: + if value_field not in inter_feat: raise ValueError( f"Value_field [{value_field}] should be one of `inter_feat`'s features." ) - values = self.inter_feat[value_field].numpy() + values = inter_feat[value_field].numpy() if row == "user": row_num, max_col_num = self.user_num, self.item_num @@ -2003,7 +2007,12 @@ def _history_matrix(self, row, value_field=None): for row_id in row_ids: history_len[row_id] += 1 - col_num = np.max(history_len) + max_inter_num=np.max(history_len) + if max_history_len is not None: + col_num= min(max_history_len, max_inter_num) + else: + col_num = max_inter_num + if col_num > max_col_num * 0.2: self.logger.warning( f"Max value of {row}'s history interaction records has reached " @@ -2014,6 +2023,8 @@ def _history_matrix(self, row, value_field=None): history_value = np.zeros((row_num, col_num)) history_len[:] = 0 for row_id, value, col_id in zip(row_ids, values, col_ids): + if history_len[row_id] >= col_num: + continue history_matrix[row_id, history_len[row_id]] = col_id history_value[row_id, history_len[row_id]] = value history_len[row_id] += 1 @@ -2024,7 +2035,7 @@ def _history_matrix(self, row, value_field=None): torch.LongTensor(history_len), ) - def history_item_matrix(self, value_field=None): + def history_item_matrix(self, value_field=None, max_history_len=None): """Get dense matrix describe user's history interaction records. ``history_matrix[i]`` represents user ``i``'s history interacted item_id. @@ -2040,15 +2051,18 @@ def history_item_matrix(self, value_field=None): value_field (str, optional): Data of matrix, which should exist in ``self.inter_feat``. Defaults to ``None``. + max_history_len (int): The maximum number of user's history interaction records. + Defaults to ``None``. + Returns: tuple: - History matrix (torch.Tensor): ``history_matrix`` described above. - History values matrix (torch.Tensor): ``history_value`` described above. - History length matrix (torch.Tensor): ``history_len`` described above. """ - return self._history_matrix(row="user", value_field=value_field) + return self._history_matrix(row="user", value_field=value_field, max_history_len=max_history_len) - def history_user_matrix(self, value_field=None): + def history_user_matrix(self, value_field=None, max_history_len=None): """Get dense matrix describe item's history interaction records. ``history_matrix[i]`` represents item ``i``'s history interacted item_id. @@ -2064,13 +2078,16 @@ def history_user_matrix(self, value_field=None): value_field (str, optional): Data of matrix, which should exist in ``self.inter_feat``. Defaults to ``None``. + max_history_len (int): The maximum number of item's history interaction records. + Defaults to ``None``. + Returns: tuple: - History matrix (torch.Tensor): ``history_matrix`` described above. - History values matrix (torch.Tensor): ``history_value`` described above. - History length matrix (torch.Tensor): ``history_len`` described above. """ - return self._history_matrix(row="item", value_field=value_field) + return self._history_matrix(row="item", value_field=value_field, max_history_len=max_history_len) def get_preload_weight(self, field): """Get preloaded weight matrix, whose rows are sorted by token ids. diff --git a/recbole/model/general_recommender/simplex.py b/recbole/model/general_recommender/simplex.py index 10927911e..7e399e21e 100644 --- a/recbole/model/general_recommender/simplex.py +++ b/recbole/model/general_recommender/simplex.py @@ -40,10 +40,11 @@ def __init__(self, config, dataset): super(SimpleX, self).__init__(config, dataset) # Get user transaction history - self.history_item_id, _, self.history_item_len = dataset.history_item_matrix() + self.history_item_id, _, self.history_item_len = dataset.history_item_matrix(max_history_len=config["history_len"]) self.history_item_id = self.history_item_id.to(self.device) self.history_item_len = self.history_item_len.to(self.device) + # load parameters info self.embedding_size = config["embedding_size"] self.margin = config["margin"] @@ -56,7 +57,7 @@ def __init__(self, config, dataset): raise ValueError( "aggregator must be mean, user_attention or self_attention" ) - self.history_len = min(config["history_len"], self.history_item_len.shape[0]) + self.history_len = torch.max(self.history_item_len,dim=0) # user embedding matrix self.user_emb = nn.Embedding(self.n_users, self.embedding_size) @@ -206,14 +207,15 @@ def calculate_loss(self, interaction): user = user[0:user_number] # historical transaction record history_item = self.history_item_id[user] - history_item = history_item[:, : self.history_len] + # history_item = history_item[:, : self.history_len] + # positive item's id pos_item = pos_item[0:user_number] # history_len history_len = self.history_item_len[user] - history_len = torch.minimum( - history_len, torch.zeros(1, device=self.device) + self.history_len - ) + # history_len = torch.minimum( + # history_len, torch.zeros(1, device=self.device) + self.history_len + # ) loss = self.forward(user, pos_item, history_item, history_len, neg_item_seq) return loss @@ -221,11 +223,11 @@ def calculate_loss(self, interaction): def predict(self, interaction): user = interaction[self.USER_ID] history_item = self.history_item_id[user] - history_item = history_item[:, : self.history_len] + # history_item = history_item[:, : self.history_len] history_len = self.history_item_len[user] - history_len = torch.minimum( - history_len, torch.zeros(1, device=self.device) + self.history_len - ) + # history_len = torch.minimum( + # history_len, torch.zeros(1, device=self.device) + self.history_len + # ) test_item = interaction[self.ITEM_ID] # [user_num, embedding_size] @@ -244,11 +246,11 @@ def predict(self, interaction): def full_sort_predict(self, interaction): user = interaction[self.USER_ID] history_item = self.history_item_id[user] - history_item = history_item[:, : self.history_len] + # history_item = history_item[:, : self.history_len] history_len = self.history_item_len[user] - history_len = torch.minimum( - history_len, torch.zeros(1, device=self.device) + self.history_len - ) + # history_len = torch.minimum( + # history_len, torch.zeros(1, device=self.device) + self.history_len + # ) # [user_num, embedding_size] user_e = self.user_emb(user) From 93a7f56573ba3a59bf1dba92f045f618b33c8bd2 Mon Sep 17 00:00:00 2001 From: zhengbw0324 <18735382001@163.com> Date: Sat, 20 Aug 2022 17:32:56 +0800 Subject: [PATCH 2/6] The latest interaction history matrix function is used. --- recbole/model/general_recommender/simplex.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/recbole/model/general_recommender/simplex.py b/recbole/model/general_recommender/simplex.py index 7e399e21e..667a497cd 100644 --- a/recbole/model/general_recommender/simplex.py +++ b/recbole/model/general_recommender/simplex.py @@ -207,15 +207,10 @@ def calculate_loss(self, interaction): user = user[0:user_number] # historical transaction record history_item = self.history_item_id[user] - # history_item = history_item[:, : self.history_len] - # positive item's id pos_item = pos_item[0:user_number] # history_len history_len = self.history_item_len[user] - # history_len = torch.minimum( - # history_len, torch.zeros(1, device=self.device) + self.history_len - # ) loss = self.forward(user, pos_item, history_item, history_len, neg_item_seq) return loss @@ -223,11 +218,7 @@ def calculate_loss(self, interaction): def predict(self, interaction): user = interaction[self.USER_ID] history_item = self.history_item_id[user] - # history_item = history_item[:, : self.history_len] history_len = self.history_item_len[user] - # history_len = torch.minimum( - # history_len, torch.zeros(1, device=self.device) + self.history_len - # ) test_item = interaction[self.ITEM_ID] # [user_num, embedding_size] @@ -246,11 +237,7 @@ def predict(self, interaction): def full_sort_predict(self, interaction): user = interaction[self.USER_ID] history_item = self.history_item_id[user] - # history_item = history_item[:, : self.history_len] history_len = self.history_item_len[user] - # history_len = torch.minimum( - # history_len, torch.zeros(1, device=self.device) + self.history_len - # ) # [user_num, embedding_size] user_e = self.user_emb(user) From 3186ee241b72bee345a8f77f59b5c083415d78be Mon Sep 17 00:00:00 2001 From: zhengbw0324 <18735382001@163.com> Date: Sat, 20 Aug 2022 18:30:15 +0800 Subject: [PATCH 3/6] The latest interaction history matrix function is used. --- recbole/model/general_recommender/simplex.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recbole/model/general_recommender/simplex.py b/recbole/model/general_recommender/simplex.py index 667a497cd..5f17f48ef 100644 --- a/recbole/model/general_recommender/simplex.py +++ b/recbole/model/general_recommender/simplex.py @@ -39,7 +39,7 @@ class SimpleX(GeneralRecommender): def __init__(self, config, dataset): super(SimpleX, self).__init__(config, dataset) - # Get user transaction history + # Get user history interacted items self.history_item_id, _, self.history_item_len = dataset.history_item_matrix(max_history_len=config["history_len"]) self.history_item_id = self.history_item_id.to(self.device) self.history_item_len = self.history_item_len.to(self.device) From dc36d4a15834082df678255a9ac0362d871f27e7 Mon Sep 17 00:00:00 2001 From: zhengbw0324 <18735382001@163.com> Date: Sat, 20 Aug 2022 18:36:14 +0800 Subject: [PATCH 4/6] Format Python code according to PEP8 --- recbole/trainer/hyper_tuning.py | 1 + 1 file changed, 1 insertion(+) diff --git a/recbole/trainer/hyper_tuning.py b/recbole/trainer/hyper_tuning.py index bfe981761..76552be81 100644 --- a/recbole/trainer/hyper_tuning.py +++ b/recbole/trainer/hyper_tuning.py @@ -114,6 +114,7 @@ def exhaustive_search(new_ids, domain, trials, seed, nbMaxSucessiveFailures=1000 ) rng = np.random.RandomState(seed) + # rng = np.random.default_rng(seed) rval = [] for _, new_id in enumerate(new_ids): newSample = False From 1e02255c08268b90971c21073d3e34de0602884a Mon Sep 17 00:00:00 2001 From: zhengbw0324 <18735382001@163.com> Date: Sat, 20 Aug 2022 18:42:25 +0800 Subject: [PATCH 5/6] . --- recbole/trainer/hyper_tuning.py | 1 - 1 file changed, 1 deletion(-) diff --git a/recbole/trainer/hyper_tuning.py b/recbole/trainer/hyper_tuning.py index 76552be81..bfe981761 100644 --- a/recbole/trainer/hyper_tuning.py +++ b/recbole/trainer/hyper_tuning.py @@ -114,7 +114,6 @@ def exhaustive_search(new_ids, domain, trials, seed, nbMaxSucessiveFailures=1000 ) rng = np.random.RandomState(seed) - # rng = np.random.default_rng(seed) rval = [] for _, new_id in enumerate(new_ids): newSample = False From 5dedee0b44f011c1c9bce22735c1914b2efda9ac Mon Sep 17 00:00:00 2001 From: zhengbw0324 <18735382001@163.com> Date: Sat, 20 Aug 2022 18:51:24 +0800 Subject: [PATCH 6/6] Format Python code according to PEP8 --- run_hyper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/run_hyper.py b/run_hyper.py index ccd45cca1..71db743eb 100644 --- a/run_hyper.py +++ b/run_hyper.py @@ -95,6 +95,7 @@ def ray_tune(args): ) parser.add_argument("--tool", type=str, default="Hyperopt", help="tuning tool") args, _ = parser.parse_known_args() + if args.tool == "Hyperopt": hyperopt_tune(args) elif args.tool == "Ray":