diff --git a/recbole/data/dataset/dataset.py b/recbole/data/dataset/dataset.py index 6bf88ce17..580a14c20 100644 --- a/recbole/data/dataset/dataset.py +++ b/recbole/data/dataset/dataset.py @@ -1954,7 +1954,7 @@ def inter_matrix(self, form="coo", value_field=None): self.inter_feat, self.uid_field, self.iid_field, form, value_field ) - def _history_matrix(self, row, value_field=None): + def _history_matrix(self, row, value_field=None ,max_history_len=None): """Get dense matrix describe user/item's history interaction records. ``history_matrix[i]`` represents ``i``'s history interacted item_id. @@ -1970,6 +1970,8 @@ def _history_matrix(self, row, value_field=None): row (str): ``user`` or ``item``. value_field (str, optional): Data of matrix, which should exist in ``self.inter_feat``. Defaults to ``None``. + max_history_len (int): The maximum number of history interaction records. + Defaults to ``None``. Returns: tuple: @@ -1979,18 +1981,20 @@ def _history_matrix(self, row, value_field=None): """ self._check_field("uid_field", "iid_field") + inter_feat= copy.deepcopy(self.inter_feat) + inter_feat.shuffle() user_ids, item_ids = ( - self.inter_feat[self.uid_field].numpy(), - self.inter_feat[self.iid_field].numpy(), + inter_feat[self.uid_field].numpy(), + inter_feat[self.iid_field].numpy(), ) if value_field is None: - values = np.ones(len(self.inter_feat)) + values = np.ones(len(inter_feat)) else: - if value_field not in self.inter_feat: + if value_field not in inter_feat: raise ValueError( f"Value_field [{value_field}] should be one of `inter_feat`'s features." ) - values = self.inter_feat[value_field].numpy() + values = inter_feat[value_field].numpy() if row == "user": row_num, max_col_num = self.user_num, self.item_num @@ -2003,7 +2007,12 @@ def _history_matrix(self, row, value_field=None): for row_id in row_ids: history_len[row_id] += 1 - col_num = np.max(history_len) + max_inter_num=np.max(history_len) + if max_history_len is not None: + col_num= min(max_history_len, max_inter_num) + else: + col_num = max_inter_num + if col_num > max_col_num * 0.2: self.logger.warning( f"Max value of {row}'s history interaction records has reached " @@ -2014,6 +2023,8 @@ def _history_matrix(self, row, value_field=None): history_value = np.zeros((row_num, col_num)) history_len[:] = 0 for row_id, value, col_id in zip(row_ids, values, col_ids): + if history_len[row_id] >= col_num: + continue history_matrix[row_id, history_len[row_id]] = col_id history_value[row_id, history_len[row_id]] = value history_len[row_id] += 1 @@ -2024,7 +2035,7 @@ def _history_matrix(self, row, value_field=None): torch.LongTensor(history_len), ) - def history_item_matrix(self, value_field=None): + def history_item_matrix(self, value_field=None, max_history_len=None): """Get dense matrix describe user's history interaction records. ``history_matrix[i]`` represents user ``i``'s history interacted item_id. @@ -2040,15 +2051,18 @@ def history_item_matrix(self, value_field=None): value_field (str, optional): Data of matrix, which should exist in ``self.inter_feat``. Defaults to ``None``. + max_history_len (int): The maximum number of user's history interaction records. + Defaults to ``None``. + Returns: tuple: - History matrix (torch.Tensor): ``history_matrix`` described above. - History values matrix (torch.Tensor): ``history_value`` described above. - History length matrix (torch.Tensor): ``history_len`` described above. """ - return self._history_matrix(row="user", value_field=value_field) + return self._history_matrix(row="user", value_field=value_field, max_history_len=max_history_len) - def history_user_matrix(self, value_field=None): + def history_user_matrix(self, value_field=None, max_history_len=None): """Get dense matrix describe item's history interaction records. ``history_matrix[i]`` represents item ``i``'s history interacted item_id. @@ -2064,13 +2078,16 @@ def history_user_matrix(self, value_field=None): value_field (str, optional): Data of matrix, which should exist in ``self.inter_feat``. Defaults to ``None``. + max_history_len (int): The maximum number of item's history interaction records. + Defaults to ``None``. + Returns: tuple: - History matrix (torch.Tensor): ``history_matrix`` described above. - History values matrix (torch.Tensor): ``history_value`` described above. - History length matrix (torch.Tensor): ``history_len`` described above. """ - return self._history_matrix(row="item", value_field=value_field) + return self._history_matrix(row="item", value_field=value_field, max_history_len=max_history_len) def get_preload_weight(self, field): """Get preloaded weight matrix, whose rows are sorted by token ids. diff --git a/recbole/model/general_recommender/simplex.py b/recbole/model/general_recommender/simplex.py index 10927911e..5f17f48ef 100644 --- a/recbole/model/general_recommender/simplex.py +++ b/recbole/model/general_recommender/simplex.py @@ -39,11 +39,12 @@ class SimpleX(GeneralRecommender): def __init__(self, config, dataset): super(SimpleX, self).__init__(config, dataset) - # Get user transaction history - self.history_item_id, _, self.history_item_len = dataset.history_item_matrix() + # Get user history interacted items + self.history_item_id, _, self.history_item_len = dataset.history_item_matrix(max_history_len=config["history_len"]) self.history_item_id = self.history_item_id.to(self.device) self.history_item_len = self.history_item_len.to(self.device) + # load parameters info self.embedding_size = config["embedding_size"] self.margin = config["margin"] @@ -56,7 +57,7 @@ def __init__(self, config, dataset): raise ValueError( "aggregator must be mean, user_attention or self_attention" ) - self.history_len = min(config["history_len"], self.history_item_len.shape[0]) + self.history_len = torch.max(self.history_item_len,dim=0) # user embedding matrix self.user_emb = nn.Embedding(self.n_users, self.embedding_size) @@ -206,14 +207,10 @@ def calculate_loss(self, interaction): user = user[0:user_number] # historical transaction record history_item = self.history_item_id[user] - history_item = history_item[:, : self.history_len] # positive item's id pos_item = pos_item[0:user_number] # history_len history_len = self.history_item_len[user] - history_len = torch.minimum( - history_len, torch.zeros(1, device=self.device) + self.history_len - ) loss = self.forward(user, pos_item, history_item, history_len, neg_item_seq) return loss @@ -221,11 +218,7 @@ def calculate_loss(self, interaction): def predict(self, interaction): user = interaction[self.USER_ID] history_item = self.history_item_id[user] - history_item = history_item[:, : self.history_len] history_len = self.history_item_len[user] - history_len = torch.minimum( - history_len, torch.zeros(1, device=self.device) + self.history_len - ) test_item = interaction[self.ITEM_ID] # [user_num, embedding_size] @@ -244,11 +237,7 @@ def predict(self, interaction): def full_sort_predict(self, interaction): user = interaction[self.USER_ID] history_item = self.history_item_id[user] - history_item = history_item[:, : self.history_len] history_len = self.history_item_len[user] - history_len = torch.minimum( - history_len, torch.zeros(1, device=self.device) + self.history_len - ) # [user_num, embedding_size] user_e = self.user_emb(user) diff --git a/run_hyper.py b/run_hyper.py index ccd45cca1..71db743eb 100644 --- a/run_hyper.py +++ b/run_hyper.py @@ -95,6 +95,7 @@ def ray_tune(args): ) parser.add_argument("--tool", type=str, default="Hyperopt", help="tuning tool") args, _ = parser.parse_known_args() + if args.tool == "Hyperopt": hyperopt_tune(args) elif args.tool == "Ray":