diff --git a/recbole/model/sequential_recommender/bert4rec.py b/recbole/model/sequential_recommender/bert4rec.py
index 0a759d6ff..936e57d77 100644
--- a/recbole/model/sequential_recommender/bert4rec.py
+++ b/recbole/model/sequential_recommender/bert4rec.py
@@ -3,6 +3,11 @@
 # @Author : Hui Wang
 # @Email  : hui.wang@ruc.edu.cn
 
+# UPDATE
+# @Time   : 2023/9/4
+# @Author : Enze Liu
+# @Email  : enzeeliu@foxmail.com
+
 r"""
 BERT4Rec
 ################################################
@@ -75,6 +80,10 @@ def __init__(self, config, dataset):
 
         self.LayerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
         self.dropout = nn.Dropout(self.hidden_dropout_prob)
+        self.output_ffn = nn.Linear(self.hidden_size, self.hidden_size)
+        self.output_gelu = nn.GELU()
+        self.output_ln = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
+        self.output_bias = nn.Parameter(torch.zeros(self.n_items))
 
         # we only need compute the loss at the masked position
         try:
@@ -124,7 +133,9 @@ def forward(self, item_seq):
         trm_output = self.trm_encoder(
             input_emb, extended_attention_mask, output_all_encoded_layers=True
         )
-        output = trm_output[-1]
+        ffn_output = self.output_ffn(trm_output[-1])
+        ffn_output = self.output_gelu(ffn_output)
+        output = self.output_ln(ffn_output)
         return output  # [B L H]
 
     def multi_hot_embed(self, masked_index, max_length):
@@ -172,8 +183,14 @@ def calculate_loss(self, interaction):
         if self.loss_type == "BPR":
             pos_items_emb = self.item_embedding(pos_items)  # [B mask_len H]
             neg_items_emb = self.item_embedding(neg_items)  # [B mask_len H]
-            pos_score = torch.sum(seq_output * pos_items_emb, dim=-1)  # [B mask_len]
-            neg_score = torch.sum(seq_output * neg_items_emb, dim=-1)  # [B mask_len]
+            pos_score = (
+                torch.sum(seq_output * pos_items_emb, dim=-1)
+                + self.output_bias[pos_items]
+            )  # [B mask_len]
+            neg_score = (
+                torch.sum(seq_output * neg_items_emb, dim=-1)
+                + self.output_bias[neg_items]
+            )  # [B mask_len]
             targets = (masked_index > 0).float()
             loss = -torch.sum(
                 torch.log(1e-14 + torch.sigmoid(pos_score - neg_score)) * targets
@@ -183,8 +200,9 @@ def calculate_loss(self, interaction):
         elif self.loss_type == "CE":
             loss_fct = nn.CrossEntropyLoss(reduction="none")
             test_item_emb = self.item_embedding.weight[: self.n_items]  # [item_num H]
-            logits = torch.matmul(
-                seq_output, test_item_emb.transpose(0, 1)
+            logits = (
+                torch.matmul(seq_output, test_item_emb.transpose(0, 1))
+                + self.output_bias
             )  # [B mask_len item_num]
 
             targets = (masked_index > 0).float().view(-1)  # [B*mask_len]
@@ -204,7 +222,9 @@ def predict(self, interaction):
         seq_output = self.forward(item_seq)
         seq_output = self.gather_indexes(seq_output, item_seq_len - 1)  # [B H]
         test_item_emb = self.item_embedding(test_item)
-        scores = torch.mul(seq_output, test_item_emb).sum(dim=1)  # [B]
+        scores = (torch.mul(seq_output, test_item_emb)).sum(dim=1) + self.output_bias[
+            test_item
+        ]  # [B]
         return scores
 
     def full_sort_predict(self, interaction):
@@ -216,7 +236,7 @@ def full_sort_predict(self, interaction):
         test_items_emb = self.item_embedding.weight[
             : self.n_items
         ]  # delete masked token
-        scores = torch.matmul(
-            seq_output, test_items_emb.transpose(0, 1)
+        scores = (
+            torch.matmul(seq_output, test_items_emb.transpose(0, 1)) + self.output_bias
        )  # [B, item_num]
         return scores
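
For context, a minimal standalone sketch of the output head this patch introduces: the encoder output is passed through Linear -> GELU -> LayerNorm, and item scores receive a learned per-item bias, in the spirit of the prediction layer described in the BERT4Rec paper. This is not RecBole code; the tensor sizes and the 1e-12 LayerNorm eps below are illustrative assumptions.

import torch
import torch.nn as nn

hidden_size, n_items, batch, seq_len = 64, 1000, 2, 5

# Stand-in for the Transformer encoder output of a masked sequence, [B L H]
trm_output = torch.randn(batch, seq_len, hidden_size)

# Output head added by the patch: feed-forward + GELU + LayerNorm
output_ffn = nn.Linear(hidden_size, hidden_size)
output_gelu = nn.GELU()
output_ln = nn.LayerNorm(hidden_size, eps=1e-12)
output_bias = nn.Parameter(torch.zeros(n_items))  # per-item bias added to the logits

# Item table with one extra row for the mask token, as in the model
item_embedding = nn.Embedding(n_items + 1, hidden_size, padding_idx=0)

seq_output = output_ln(output_gelu(output_ffn(trm_output)))  # [B L H]

# Full-sort scores: dot product with the item embeddings plus the item bias
test_items_emb = item_embedding.weight[:n_items]                    # [item_num H]
scores = torch.matmul(seq_output, test_items_emb.T) + output_bias   # [B L item_num]
print(scores.shape)  # torch.Size([2, 5, 1000])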