Skip to content

Commit

Permalink
Merge pull request #1025 from chenyushuo/master
Browse files Browse the repository at this point in the history
FIX: bug fix in neg sampling and `squeeze` function.
  • Loading branch information
2017pxy authored Nov 1, 2021
2 parents d976662 + f04084b commit 4b6c6fe
Show file tree
Hide file tree
Showing 25 changed files with 36 additions and 44 deletions.
4 changes: 2 additions & 2 deletions recbole/data/dataloader/abstract_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def _set_neg_sample_args(self, config, dataset, dl_format, neg_sample_args):

def _neg_sampling(self, inter_feat):
if self.neg_sample_args['strategy'] == 'by':
user_ids = inter_feat[self.uid_field]
item_ids = inter_feat[self.iid_field]
user_ids = inter_feat[self.uid_field].numpy()
item_ids = inter_feat[self.iid_field].numpy()
neg_item_ids = self.sampler.sample_by_user_ids(user_ids, item_ids, self.neg_sample_num)
return self.sampling_func(inter_feat, neg_item_ids)
else:
Expand Down
2 changes: 1 addition & 1 deletion recbole/data/dataloader/knowledge_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _shuffle(self):

def _next_batch_data(self):
cur_data = self.dataset.kg_feat[self.pr:self.pr + self.step]
head_ids = cur_data[self.hid_field]
head_ids = cur_data[self.hid_field].numpy()
neg_tail_ids = self.sampler.sample_by_entity_ids(head_ids, self.neg_sample_num)
cur_data.update(Interaction({self.neg_tid_field: neg_tail_ids}))
self.pr += self.step
Expand Down
4 changes: 2 additions & 2 deletions recbole/evaluator/base_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def used_info(self, dataobject):
"""
rec_mat = dataobject.get('rec.topk')
topk_idx, pos_len_list = torch.split(rec_mat, [max(self.topk), 1], dim=1)
return rec_mat.to(torch.bool).numpy(), pos_len_list.squeeze().numpy()
return rec_mat.to(torch.bool).numpy(), pos_len_list.squeeze(-1).numpy()

def topk_result(self, metric, value):
"""Match the metric value to the `k` and put them in `dictionary` form.
Expand Down Expand Up @@ -111,7 +111,7 @@ def used_info(self, dataobject):
preds = dataobject.get('rec.score')
trues = dataobject.get('data.label')

return preds.squeeze().numpy(), trues.squeeze().numpy()
return preds.squeeze(-1).numpy(), trues.squeeze(-1).numpy()

def output_metric(self, metric, dataobject):
preds, trues = self.used_info(dataobject)
Expand Down
2 changes: 1 addition & 1 deletion recbole/evaluator/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def __init__(self, config):
def calculate_metric(self, dataobject):
mean_rank = dataobject.get('rec.meanrank').numpy()
pos_rank_sum, user_len_list, pos_len_list = np.split(mean_rank, 3, axis=1)
user_len_list, pos_len_list = user_len_list.squeeze(), pos_len_list.squeeze()
user_len_list, pos_len_list = user_len_list.squeeze(-1), pos_len_list.squeeze(-1)
result = self.metric_info(pos_rank_sum, user_len_list, pos_len_list)
return {'gauc': round(result, self.decimal_place)}

Expand Down
2 changes: 1 addition & 1 deletion recbole/model/context_aware_recommender/afm.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def forward(self, interaction):
afm_all_embeddings = self.concat_embed_input_fields(interaction) # [batch_size, num_field, embed_dim]

output = self.sigmoid(self.first_order_linear(interaction) + self.afm_layer(afm_all_embeddings))
return output.squeeze()
return output.squeeze(-1)

def calculate_loss(self, interaction):
label = interaction[self.LABEL]
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/context_aware_recommender/deepfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def forward(self, interaction):

y_deep = self.deep_predict_layer(self.mlp_layers(deepfm_all_embeddings.view(batch_size, -1)))
y = self.sigmoid(y_fm + y_deep)
return y.squeeze()
return y.squeeze(-1)

def calculate_loss(self, interaction):
label = interaction[self.LABEL]
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/context_aware_recommender/dssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def forward(self, interaction):
score = torch.cosine_similarity(user_dnn_out, item_dnn_out, dim=1)

sig_score = self.sigmoid(score)
return sig_score.squeeze()
return sig_score.squeeze(-1)

def calculate_loss(self, interaction):
label = interaction[self.LABEL]
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/context_aware_recommender/ffm.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def forward(self, interaction):
ffm_output = torch.sum(torch.sum(self.ffm(ffm_input), dim=1), dim=1, keepdim=True)
output = self.sigmoid(self.first_order_linear(interaction) + ffm_output)

return output.squeeze()
return output.squeeze(-1)

def calculate_loss(self, interaction):
label = interaction[self.LABEL]
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/context_aware_recommender/fm.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _init_weights(self, module):
def forward(self, interaction):
fm_all_embeddings = self.concat_embed_input_fields(interaction) # [batch_size, num_field, embed_dim]
y = self.sigmoid(self.first_order_linear(interaction) + self.fm(fm_all_embeddings))
return y.squeeze()
return y.squeeze(-1)

def calculate_loss(self, interaction):
label = interaction[self.LABEL]
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/context_aware_recommender/fnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def forward(self, interaction):

output = self.predict_layer(self.mlp_layers(fnn_all_embeddings.view(batch_size, -1)))
output = self.sigmoid(output)
return output.squeeze()
return output.squeeze(-1)

def calculate_loss(self, interaction):
label = interaction[self.LABEL]
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/context_aware_recommender/fwfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def forward(self, interaction):

output = self.sigmoid(self.first_order_linear(interaction) + self.fwfm_layer(fwfm_all_embeddings))

return output.squeeze()
return output.squeeze(-1)

def calculate_loss(self, interaction):
label = interaction[self.LABEL]
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/context_aware_recommender/lr.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _init_weights(self, module):

def forward(self, interaction):
output = self.sigmoid(self.first_order_linear(interaction))
return output.squeeze()
return output.squeeze(-1)

def calculate_loss(self, interaction):
label = interaction[self.LABEL]
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/context_aware_recommender/nfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def forward(self, interaction):

output = self.predict_layer(self.mlp_layers(bn_nfm_all_embeddings)) + self.first_order_linear(interaction)
output = self.sigmoid(output)
return output.squeeze()
return output.squeeze(-1)

def calculate_loss(self, interaction):
label = interaction[self.LABEL]
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/context_aware_recommender/pnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def forward(self, interaction):

output = self.predict_layer(self.mlp_layers(output)) # [batch_size,1]
output = self.sigmoid(output)
return output.squeeze()
return output.squeeze(-1)

def calculate_loss(self, interaction):
label = interaction[self.LABEL]
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/context_aware_recommender/widedeep.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def forward(self, interaction):

deep_output = self.deep_predict_layer(self.mlp_layers(widedeep_all_embeddings.view(batch_size, -1)))
output = self.sigmoid(fm_output + deep_output)
return output.squeeze()
return output.squeeze(-1)

def calculate_loss(self, interaction):
label = interaction[self.LABEL]
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/general_recommender/convncf.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def forward(self, user, item):
cnn_output = cnn_output.sum(axis=(2, 3))

prediction = self.predict_layers(cnn_output)
prediction = prediction.squeeze()
prediction = prediction.squeeze(-1)

return prediction

Expand Down
2 changes: 1 addition & 1 deletion recbole/model/general_recommender/enmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def forward(self, user):
item_embedding = self.item_embedding(user_inter) # shape: [B, max_len, embedding_size]
score = torch.mul(user_embedding.unsqueeze(1), item_embedding) # shape: [B, max_len, embedding_size]
score = self.H_i(score) # shape: [B,max_len,1]
score = score.squeeze() # shape:[B,max_len]
score = score.squeeze(-1) # shape:[B,max_len]

return score

Expand Down
11 changes: 2 additions & 9 deletions recbole/model/general_recommender/itemknn.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,7 @@ def compute_similarity(self, method, block_size=100):
data = self.dataMatrix[start_block:end_block, :]
else:
data = self.dataMatrix[:, start_block:end_block]
data = data.toarray().squeeze()

if data.ndim == 1:
data = np.expand_dims(data, axis=1)
data = data.toarray()

# Compute similarities

Expand All @@ -105,11 +102,7 @@ def compute_similarity(self, method, block_size=100):
this_block_weights = self.dataMatrix.T.dot(data)

for index_in_block in range(this_block_size):

if this_block_size == 1:
this_line_weights = this_block_weights.squeeze()
else:
this_line_weights = this_block_weights[:, index_in_block]
this_line_weights = this_block_weights[:, index_in_block]

Index = index_in_block + start_block
this_line_weights[Index] = 0.0
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/general_recommender/neumf.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def forward(self, user, item):
output = self.sigmoid(self.predict_layer(mlp_output))
else:
raise RuntimeError('mf_train and mlp_train can not be False at the same time')
return output.squeeze()
return output.squeeze(-1)

def calculate_loss(self, interaction):
user = interaction[self.USER_ID]
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/general_recommender/pop.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def calculate_loss(self, interaction):
def predict(self, interaction):
item = interaction[self.ITEM_ID]
result = torch.true_divide(self.item_cnt[item, :], self.max_cnt)
return result.squeeze()
return result.squeeze(-1)

def full_sort_predict(self, interaction):
batch_user_num = interaction[self.USER_ID].shape[0]
Expand Down
6 changes: 3 additions & 3 deletions recbole/model/knowledge_aware_recommender/cke.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ def _get_kg_embedding(self, h, r, pos_t, neg_t):
r_e = self.relation_embedding(r)
r_trans_w = self.trans_w(r).view(r.size(0), self.embedding_size, self.kg_embedding_size)

h_e = torch.bmm(h_e, r_trans_w).squeeze()
pos_t_e = torch.bmm(pos_t_e, r_trans_w).squeeze()
neg_t_e = torch.bmm(neg_t_e, r_trans_w).squeeze()
h_e = torch.bmm(h_e, r_trans_w).squeeze(1)
pos_t_e = torch.bmm(pos_t_e, r_trans_w).squeeze(1)
neg_t_e = torch.bmm(neg_t_e, r_trans_w).squeeze(1)

r_e = F.normalize(r_e, p=2, dim=1)
h_e = F.normalize(h_e, p=2, dim=1)
Expand Down
6 changes: 3 additions & 3 deletions recbole/model/knowledge_aware_recommender/kgat.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ def _get_kg_embedding(self, h, r, pos_t, neg_t):
r_e = self.relation_embedding(r)
r_trans_w = self.trans_w(r).view(r.size(0), self.embedding_size, self.kg_embedding_size)

h_e = torch.bmm(h_e, r_trans_w).squeeze()
pos_t_e = torch.bmm(pos_t_e, r_trans_w).squeeze()
neg_t_e = torch.bmm(neg_t_e, r_trans_w).squeeze()
h_e = torch.bmm(h_e, r_trans_w).squeeze(1)
pos_t_e = torch.bmm(pos_t_e, r_trans_w).squeeze(1)
neg_t_e = torch.bmm(neg_t_e, r_trans_w).squeeze(1)

return h_e, r_e, pos_t_e, neg_t_e

Expand Down
9 changes: 4 additions & 5 deletions recbole/model/sequential_recommender/dien.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,10 @@ def forward(self, user, item_seq, neg_item_seq, item_seq_len, next_items):
feature_table[type] = feature_table[type].view(table_shape[:-2] + (feat_num * embedding_size,))

user_feat_list = feature_table['user']
item_feat_list, neg_item_feat_list, target_item_feat_emb = feature_table['item'].split([
max_length, max_length, 1
],
dim=1)
target_item_feat_emb = target_item_feat_emb.squeeze()
item_feat_list, neg_item_feat_list, target_item_feat_emb = feature_table['item'].split(
[max_length, max_length, 1], dim=1
)
target_item_feat_emb = target_item_feat_emb.squeeze(1)

# interest
interest, aux_loss = self.interset_extractor(item_feat_list, item_seq_len, neg_item_feat_list)
Expand Down
4 changes: 2 additions & 2 deletions recbole/model/sequential_recommender/din.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,11 @@ def forward(self, user, item_seq, item_seq_len, next_items):

user_feat_list = feature_table['user']
item_feat_list, target_item_feat_emb = feature_table['item'].split([max_length, 1], dim=1)
target_item_feat_emb = target_item_feat_emb.squeeze()
target_item_feat_emb = target_item_feat_emb.squeeze(1)

# attention
user_emb = self.attention(target_item_feat_emb, item_feat_list, item_seq_len)
user_emb = user_emb.squeeze()
user_emb = user_emb.squeeze(1)

# input the DNN to get the prediction score
din_in = torch.cat([user_emb, target_item_feat_emb, user_emb * target_item_feat_emb], dim=-1)
Expand Down
2 changes: 1 addition & 1 deletion recbole/model/sequential_recommender/repeatnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def predict(self, interaction):
# batch_size * num_items
seq_output = seq_output.unsqueeze(-1)
# batch_size * num_items * 1
scores = self.gather_indexes(seq_output, test_item).squeeze()
scores = self.gather_indexes(seq_output, test_item).squeeze(-1)

return scores

Expand Down

0 comments on commit 4b6c6fe

Please sign in to comment.