Skip to content

Commit

Permalink
Merge pull request #1405 from zhengbw0324/1.1.x
Browse files Browse the repository at this point in the history
Update the function of history_item_matrix
  • Loading branch information
Sherry-XLL authored Aug 20, 2022
2 parents f7b3d8a + 5dedee0 commit 6a54203
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 26 deletions.
39 changes: 28 additions & 11 deletions recbole/data/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1954,7 +1954,7 @@ def inter_matrix(self, form="coo", value_field=None):
self.inter_feat, self.uid_field, self.iid_field, form, value_field
)

def _history_matrix(self, row, value_field=None):
def _history_matrix(self, row, value_field=None ,max_history_len=None):
"""Get dense matrix describe user/item's history interaction records.
``history_matrix[i]`` represents ``i``'s history interacted item_id.
Expand All @@ -1970,6 +1970,8 @@ def _history_matrix(self, row, value_field=None):
row (str): ``user`` or ``item``.
value_field (str, optional): Data of matrix, which should exist in ``self.inter_feat``.
Defaults to ``None``.
max_history_len (int, optional): The maximum number of history interaction records.
Defaults to ``None``.
Returns:
tuple:
Expand All @@ -1979,18 +1981,20 @@ def _history_matrix(self, row, value_field=None):
"""
self._check_field("uid_field", "iid_field")

inter_feat= copy.deepcopy(self.inter_feat)
inter_feat.shuffle()
user_ids, item_ids = (
self.inter_feat[self.uid_field].numpy(),
self.inter_feat[self.iid_field].numpy(),
inter_feat[self.uid_field].numpy(),
inter_feat[self.iid_field].numpy(),
)
if value_field is None:
values = np.ones(len(self.inter_feat))
values = np.ones(len(inter_feat))
else:
if value_field not in self.inter_feat:
if value_field not in inter_feat:
raise ValueError(
f"Value_field [{value_field}] should be one of `inter_feat`'s features."
)
values = self.inter_feat[value_field].numpy()
values = inter_feat[value_field].numpy()

if row == "user":
row_num, max_col_num = self.user_num, self.item_num
Expand All @@ -2003,7 +2007,12 @@ def _history_matrix(self, row, value_field=None):
for row_id in row_ids:
history_len[row_id] += 1

col_num = np.max(history_len)
max_inter_num=np.max(history_len)
if max_history_len is not None:
col_num= min(max_history_len, max_inter_num)
else:
col_num = max_inter_num

if col_num > max_col_num * 0.2:
self.logger.warning(
f"Max value of {row}'s history interaction records has reached "
Expand All @@ -2014,6 +2023,8 @@ def _history_matrix(self, row, value_field=None):
history_value = np.zeros((row_num, col_num))
history_len[:] = 0
for row_id, value, col_id in zip(row_ids, values, col_ids):
if history_len[row_id] >= col_num:
continue
history_matrix[row_id, history_len[row_id]] = col_id
history_value[row_id, history_len[row_id]] = value
history_len[row_id] += 1
Expand All @@ -2024,7 +2035,7 @@ def _history_matrix(self, row, value_field=None):
torch.LongTensor(history_len),
)

def history_item_matrix(self, value_field=None):
def history_item_matrix(self, value_field=None, max_history_len=None):
"""Get dense matrix describe user's history interaction records.
``history_matrix[i]`` represents user ``i``'s history interacted item_id.
Expand All @@ -2040,15 +2051,18 @@ def history_item_matrix(self, value_field=None):
value_field (str, optional): Data of matrix, which should exist in ``self.inter_feat``.
Defaults to ``None``.
max_history_len (int, optional): The maximum number of user's history interaction records.
Defaults to ``None``.
Returns:
tuple:
- History matrix (torch.Tensor): ``history_matrix`` described above.
- History values matrix (torch.Tensor): ``history_value`` described above.
- History length matrix (torch.Tensor): ``history_len`` described above.
"""
return self._history_matrix(row="user", value_field=value_field)
return self._history_matrix(row="user", value_field=value_field, max_history_len=max_history_len)

def history_user_matrix(self, value_field=None):
def history_user_matrix(self, value_field=None, max_history_len=None):
"""Get dense matrix describe item's history interaction records.
``history_matrix[i]`` represents item ``i``'s history interacted user_id.
Expand All @@ -2064,13 +2078,16 @@ def history_user_matrix(self, value_field=None):
value_field (str, optional): Data of matrix, which should exist in ``self.inter_feat``.
Defaults to ``None``.
max_history_len (int, optional): The maximum number of item's history interaction records.
Defaults to ``None``.
Returns:
tuple:
- History matrix (torch.Tensor): ``history_matrix`` described above.
- History values matrix (torch.Tensor): ``history_value`` described above.
- History length matrix (torch.Tensor): ``history_len`` described above.
"""
return self._history_matrix(row="item", value_field=value_field)
return self._history_matrix(row="item", value_field=value_field, max_history_len=max_history_len)

def get_preload_weight(self, field):
"""Get preloaded weight matrix, whose rows are sorted by token ids.
Expand Down
19 changes: 4 additions & 15 deletions recbole/model/general_recommender/simplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ class SimpleX(GeneralRecommender):
def __init__(self, config, dataset):
super(SimpleX, self).__init__(config, dataset)

# Get user transaction history
self.history_item_id, _, self.history_item_len = dataset.history_item_matrix()
# Get user history interacted items
self.history_item_id, _, self.history_item_len = dataset.history_item_matrix(max_history_len=config["history_len"])
self.history_item_id = self.history_item_id.to(self.device)
self.history_item_len = self.history_item_len.to(self.device)


# load parameters info
self.embedding_size = config["embedding_size"]
self.margin = config["margin"]
Expand All @@ -56,7 +57,7 @@ def __init__(self, config, dataset):
raise ValueError(
"aggregator must be mean, user_attention or self_attention"
)
self.history_len = min(config["history_len"], self.history_item_len.shape[0])
self.history_len = torch.max(self.history_item_len,dim=0)

# user embedding matrix
self.user_emb = nn.Embedding(self.n_users, self.embedding_size)
Expand Down Expand Up @@ -206,26 +207,18 @@ def calculate_loss(self, interaction):
user = user[0:user_number]
# historical transaction record
history_item = self.history_item_id[user]
history_item = history_item[:, : self.history_len]
# positive item's id
pos_item = pos_item[0:user_number]
# history_len
history_len = self.history_item_len[user]
history_len = torch.minimum(
history_len, torch.zeros(1, device=self.device) + self.history_len
)

loss = self.forward(user, pos_item, history_item, history_len, neg_item_seq)
return loss

def predict(self, interaction):
user = interaction[self.USER_ID]
history_item = self.history_item_id[user]
history_item = history_item[:, : self.history_len]
history_len = self.history_item_len[user]
history_len = torch.minimum(
history_len, torch.zeros(1, device=self.device) + self.history_len
)
test_item = interaction[self.ITEM_ID]

# [user_num, embedding_size]
Expand All @@ -244,11 +237,7 @@ def predict(self, interaction):
def full_sort_predict(self, interaction):
user = interaction[self.USER_ID]
history_item = self.history_item_id[user]
history_item = history_item[:, : self.history_len]
history_len = self.history_item_len[user]
history_len = torch.minimum(
history_len, torch.zeros(1, device=self.device) + self.history_len
)

# [user_num, embedding_size]
user_e = self.user_emb(user)
Expand Down
1 change: 1 addition & 0 deletions run_hyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def ray_tune(args):
)
parser.add_argument("--tool", type=str, default="Hyperopt", help="tuning tool")
args, _ = parser.parse_known_args()

if args.tool == "Hyperopt":
hyperopt_tune(args)
elif args.tool == "Ray":
Expand Down

0 comments on commit 6a54203

Please sign in to comment.