diff --git a/recbole/model/sequential_recommender/__init__.py b/recbole/model/sequential_recommender/__init__.py index 81788524d..e6787cb60 100644 --- a/recbole/model/sequential_recommender/__init__.py +++ b/recbole/model/sequential_recommender/__init__.py @@ -8,6 +8,7 @@ from recbole.model.sequential_recommender.fpmc import FPMC from recbole.model.sequential_recommender.gcsan import GCSAN from recbole.model.sequential_recommender.gru4rec import GRU4Rec +from recbole.model.sequential_recommender.gru4reccpr import GRU4RecCPR from recbole.model.sequential_recommender.gru4recf import GRU4RecF from recbole.model.sequential_recommender.gru4reckg import GRU4RecKG from recbole.model.sequential_recommender.hgn import HGN @@ -20,6 +21,7 @@ from recbole.model.sequential_recommender.repeatnet import RepeatNet from recbole.model.sequential_recommender.s3rec import S3Rec from recbole.model.sequential_recommender.sasrec import SASRec +from recbole.model.sequential_recommender.sasreccpr import SASRecCPR from recbole.model.sequential_recommender.sasrecf import SASRecF from recbole.model.sequential_recommender.shan import SHAN from recbole.model.sequential_recommender.sine import SINE diff --git a/recbole/model/sequential_recommender/gru4reccpr.py b/recbole/model/sequential_recommender/gru4reccpr.py new file mode 100644 index 000000000..0e94cb634 --- /dev/null +++ b/recbole/model/sequential_recommender/gru4reccpr.py @@ -0,0 +1,339 @@ +# -*- coding: utf-8 -*- +# @Time : 2020/8/17 19:38 +# @Author : Yujie Lu +# @Email : yujielu1998@gmail.com + +# UPDATE: +# @Time : 2020/8/19, 2020/10/2 +# @Author : Yupeng Hou, Yujie Lu +# @Email : houyupeng@ruc.edu.cn, yujielu1998@gmail.com + +# UPDATE: +# @Time : 2023/11/24 +# @Author : Haw-Shiuan Chang +# @Email : ken77921@gmail.com + +r""" +GRU4Rec + Softmax-CPR +################################################ + +Reference: + Yong Kiam Tan et al. "Improved Recurrent Neural Networks for Session-based Recommendations." in DLRS 2016. + Haw-Shiuan Chang, Nikhil Agarwal, and Andrew McCallum "To Copy, or not to Copy; That is a Critical Issue of the Output Softmax Layer in Neural Sequential Recommenders" in WSDM 2024 + + +""" + +import torch +from torch import nn +from torch.nn.init import xavier_uniform_, xavier_normal_ + +import torch.nn.functional as F +import math +from recbole.model.abstract_recommender import SequentialRecommender +from recbole.model.loss import BPRLoss +import sys + +def gelu(x): + return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + +class GRU4RecCPR(SequentialRecommender): + r"""GRU4Rec is a model that incorporate RNN for recommendation. + + Note: + + Regarding the innovation of this article,we can only achieve the data augmentation mentioned + in the paper and directly output the embedding of the item, + in order that the generation method we used is common to other sequential models. 
+ """ + + def __init__(self, config, dataset): + super(GRU4RecCPR, self).__init__(config, dataset) + + # load parameters info + self.hidden_size = config['hidden_size'] + self.embedding_size = config['embedding_size'] + self.loss_type = config['loss_type'] + self.num_layers = config['num_layers'] + self.dropout_prob = config['dropout_prob'] + + self.n_facet_all = config['n_facet_all'] #added for mfs + self.n_facet = config['n_facet'] #added for mfs + self.n_facet_window = config['n_facet_window'] #added for mfs + self.n_facet_hidden = min(config['n_facet_hidden'], config['num_layers']) #config['n_facet_hidden'] #added for mfs + self.n_facet_MLP = config['n_facet_MLP'] #added for mfs + self.n_facet_context = config['n_facet_context'] #added for dynamic partioning + self.n_facet_reranker = config['n_facet_reranker'] #added for dynamic partioning + self.n_facet_emb = config['n_facet_emb'] #added for dynamic partioning + assert self.n_facet_MLP <= 0 #-1 or 0 + assert self.n_facet_window <= 0 + self.n_facet_window = - self.n_facet_window + self.n_facet_MLP = - self.n_facet_MLP + self.softmax_nonlinear='None' #added for mfs + self.use_out_emb = config['use_out_emb'] #added for mfs + self.only_compute_loss = True #added for mfs + + self.dense = nn.Linear(self.hidden_size, self.embedding_size) + out_size = self.embedding_size + + self.n_embd = out_size + + self.use_proj_bias = config['use_proj_bias'] #added for mfs + self.weight_mode = config['weight_mode'] #added for mfs + self.context_norm = config['context_norm'] #added for mfs + self.post_remove_context = config['post_remove_context'] #added for mfs + self.reranker_merging_mode = config['reranker_merging_mode'] #added for mfs + self.partition_merging_mode = config['partition_merging_mode'] #added for mfs + self.reranker_CAN_NUM = [int(x) for x in str(config['reranker_CAN_NUM']).split(',')] + assert self.use_proj_bias is not None + self.candidates_from_previous_reranker = True + if self.weight_mode == 'max_logits': + self.n_facet_effective = 1 + else: + self.n_facet_effective = self.n_facet + + assert self.n_facet + self.n_facet_context + self.n_facet_reranker*len(self.reranker_CAN_NUM) + self.n_facet_emb == self.n_facet_all + assert self.n_facet_emb == 0 or self.n_facet_emb == 2 + + + hidden_state_input_ratio = 1 + self.n_facet_MLP #1 + 1 + self.MLP_linear = nn.Linear(self.n_embd * (self.n_facet_hidden * (self.n_facet_window+1) ), self.n_embd * self.n_facet_MLP) # (hid_dim*2) -> (hid_dim) + total_lin_dim = self.n_embd * hidden_state_input_ratio + self.project_arr = nn.ModuleList([nn.Linear(total_lin_dim, self.n_embd, bias=self.use_proj_bias) for i in range(self.n_facet_all)]) + + self.project_emb = nn.Linear(self.n_embd, self.n_embd, bias=self.use_proj_bias) + if len(self.weight_mode) > 0: + self.weight_facet_decoder = nn.Linear(self.n_embd * hidden_state_input_ratio, self.n_facet_effective) + self.weight_global = nn.Parameter( torch.ones(self.n_facet_effective) ) + + self.c = 123 + + # define layers and loss + self.emb_dropout = nn.Dropout(self.dropout_prob) + self.gru_layers = nn.GRU( + input_size=self.embedding_size, + hidden_size=self.hidden_size, + num_layers=self.num_layers, + bias=False, + batch_first=True, + ) + self.item_embedding = nn.Embedding(self.n_items, self.embedding_size, padding_idx=0) + + + if self.use_out_emb: + self.out_item_embedding = nn.Linear(out_size, self.n_items, bias = False) + else: + self.out_item_embedding = self.item_embedding + self.out_item_embedding.bias = None + + if self.loss_type == 'BPR': + 
print("current softmax-cpr code does not support BPR loss") + sys.exit(0) + elif self.loss_type == 'CE': + self.loss_fct = nn.CrossEntropyLoss() + else: + raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!") + + # parameters initialization + self.apply(self._init_weights) + + def _init_weights(self, module): + if isinstance(module, nn.Embedding): + xavier_normal_(module.weight) + elif isinstance(module, nn.GRU): + xavier_uniform_(module.weight_hh_l0) + xavier_uniform_(module.weight_ih_l0) + + def forward(self, item_seq, item_seq_len): + item_seq_emb = self.item_embedding(item_seq) + item_seq_emb_dropout = self.emb_dropout(item_seq_emb) + gru_output, _ = self.gru_layers(item_seq_emb_dropout) + gru_output = self.dense(gru_output) + return gru_output + + def get_facet_emb(self,input_emb, i): + return self.project_arr[i](input_emb) + + def calculate_loss_prob(self, interaction, only_compute_prob=False): + item_seq = interaction[self.ITEM_SEQ] + item_seq_len = interaction[self.ITEM_SEQ_LEN] + last_layer_hs = self.forward(item_seq, item_seq_len) + all_hidden_states = [last_layer_hs] + if self.loss_type != 'CE': + print("current softmax-cpr code does not support BPR or the losses other than cross entropy") + sys.exit(0) + else: # self.loss_type = 'CE' + test_item_emb = self.out_item_embedding.weight + test_item_bias = self.out_item_embedding.bias + + '''mfs code starts''' + device = all_hidden_states[0].device + + hidden_emb_arr = [] + # h_facet_hidden -> H, n_face_window -> W, here 1 and 0 + for i in range(self.n_facet_hidden): + #print('all_hidden_states length is {}. i is {}'.format(len(all_hidden_states), i)) + hidden_states = all_hidden_states[-(i+1)] #i-th hidden-state embedding from the top + device = hidden_states.device + hidden_emb_arr.append(hidden_states) + for j in range(self.n_facet_window): + bsz, seq_len, hidden_size = hidden_states.size() #bsz -> , seq_len -> , hidden_size -> 768 in GPT-small? 
+ if j+1 < hidden_states.size(1): + shifted_hidden = torch.cat( (torch.zeros( (bsz, (j+1), hidden_size), device = device), hidden_states[:,:-(j+1),:]), dim = 1) + else: + shifted_hidden = torch.zeros( (bsz, hidden_states.size(1), hidden_size), device = device) + hidden_emb_arr.append(shifted_hidden) + #hidden_emb_arr -> (W*H, bsz, seq_len, hidden_size) + + + #n_facet_MLP -> 1 + if self.n_facet_MLP > 0: + stacked_hidden_emb_raw_arr = torch.cat(hidden_emb_arr, dim=-1) #(bsz, seq_len, W*H*hidden_size) + hidden_emb_MLP = self.MLP_linear(stacked_hidden_emb_raw_arr) #bsz, seq_len, hidden_size + stacked_hidden_emb_arr_raw = torch.cat([hidden_emb_arr[0], gelu(hidden_emb_MLP)], dim=-1) #bsz, seq_len, 2*hidden_size + else: + stacked_hidden_emb_arr_raw = hidden_emb_arr[0] + + #Only use the hidden state corresponding to the last word + stacked_hidden_emb_arr = stacked_hidden_emb_arr_raw[:,-1,:].unsqueeze(dim=1) + + #list of linear projects per facet + projected_emb_arr = [] + #list of final logits per facet + facet_lm_logits_arr = [] + + rereanker_candidate_token_ids_arr = [] + for i in range(self.n_facet): + projected_emb = self.get_facet_emb(stacked_hidden_emb_arr, i) #(bsz, seq_len, hidden_dim) + projected_emb_arr.append(projected_emb) + lm_logits = F.linear(projected_emb, test_item_emb, test_item_bias) + facet_lm_logits_arr.append(lm_logits) + if i < self.n_facet_reranker and not self.candidates_from_previous_reranker: + candidate_token_ids = [] + for j in range(len(self.reranker_CAN_NUM)): + _, candidate_token_ids_ = torch.topk(lm_logits, self.reranker_CAN_NUM[j]) + candidate_token_ids.append(candidate_token_ids_) + rereanker_candidate_token_ids_arr.append(candidate_token_ids) + + for i in range(self.n_facet_reranker): + for j in range(len(self.reranker_CAN_NUM)): + projected_emb = self.get_facet_emb(stacked_hidden_emb_arr, self.n_facet+i*len(self.reranker_CAN_NUM) + j) #(bsz, seq_len, hidden_dim) + projected_emb_arr.append(projected_emb) + + for i in range(self.n_facet_context): + projected_emb = self.get_facet_emb(stacked_hidden_emb_arr, self.n_facet+self.n_facet_reranker*len(self.reranker_CAN_NUM)+i) #(bsz, seq_len, hidden_dim) + projected_emb_arr.append(projected_emb) + + #to generate context-based embeddings for words in input + for i in range(self.n_facet_emb): + projected_emb = self.get_facet_emb(stacked_hidden_emb_arr_raw, self.n_facet + self.n_facet_context + self.n_facet_reranker*len(self.reranker_CAN_NUM) + i) #(bsz, seq_len, hidden_dim) + projected_emb_arr.append(projected_emb) + + for i in range(self.n_facet_reranker): + bsz, seq_len, hidden_size = projected_emb_arr[i].size() + for j in range(len(self.reranker_CAN_NUM)): + if self.candidates_from_previous_reranker: + _, candidate_token_ids = torch.topk(facet_lm_logits_arr[i], self.reranker_CAN_NUM[j]) #(bsz, seq_len, topk) + else: + candidate_token_ids = rereanker_candidate_token_ids_arr[i][j] + logit_hidden_reranker_topn = (projected_emb_arr[self.n_facet + i*len(self.reranker_CAN_NUM) + j].unsqueeze(dim=2).expand(bsz, seq_len, self.reranker_CAN_NUM[j], hidden_size) * test_item_emb[candidate_token_ids, :] ).sum(dim=-1) #(bsz, seq_len, emb_size) x (bsz, seq_len, topk, emb_size) -> (bsz, seq_len, topk) + if test_item_bias is not None: + logit_hidden_reranker_topn +=test_item_bias[candidate_token_ids] + if self.reranker_merging_mode == 'add': + #print("inside reranker") + facet_lm_logits_arr[i].scatter_add_(2, candidate_token_ids, logit_hidden_reranker_topn) #(bsz, seq_len, vocab_size) <- (bsz, seq_len, topk) x (bsz, seq_len, topk) + 
else: + facet_lm_logits_arr[i].scatter_(2, candidate_token_ids, logit_hidden_reranker_topn) #(bsz, seq_len, vocab_size) <- (bsz, seq_len, topk) x (bsz, seq_len, topk) + + for i in range(self.n_facet_context): + bsz, seq_len_1, hidden_size = projected_emb_arr[i].size() + bsz, seq_len_2 = item_seq.size() + logit_hidden_context = (projected_emb_arr[self.n_facet + self.n_facet_reranker*len(self.reranker_CAN_NUM) + i].unsqueeze(dim=2).expand(-1,-1,seq_len_2,-1) * test_item_emb[item_seq, :].unsqueeze(dim=1).expand(-1,seq_len_1,-1,-1) ).sum(dim=-1) + if test_item_bias is not None: + logit_hidden_context += test_item_bias[item_seq].unsqueeze(dim=1).expand(-1,seq_len_1,-1) + logit_hidden_pointer = 0 + if self.n_facet_emb == 2: + logit_hidden_pointer = ( projected_emb_arr[-2][:,-1,:].unsqueeze(dim=1).unsqueeze(dim=1).expand(-1,seq_len_1,seq_len_2,-1) * projected_emb_arr[-1].unsqueeze(dim=1).expand(-1,seq_len_1,-1,-1) ).sum(dim=-1) + + item_seq_expand = item_seq.unsqueeze(dim=1).expand(-1,seq_len_1,-1) + only_new_logits = torch.zeros_like(facet_lm_logits_arr[i]) + if self.context_norm: + only_new_logits.scatter_add_(dim=2, index=item_seq_expand, src=logit_hidden_context+logit_hidden_pointer) + item_count = torch.zeros_like(only_new_logits) + 1e-15 + item_count.scatter_add_(dim=2, index=item_seq_expand,src=torch.ones_like(item_seq_expand).to(dtype=item_count.dtype)) + only_new_logits = only_new_logits / item_count + else: + only_new_logits.scatter_add_(dim=2, index=item_seq_expand, src=logit_hidden_context) + item_count = torch.zeros_like(only_new_logits) + 1e-15 + item_count.scatter_add_(dim=2, index=item_seq_expand,src=torch.ones_like(item_seq_expand).to(dtype=item_count.dtype)) + only_new_logits = only_new_logits / item_count + only_new_logits.scatter_add_(dim=2, index=item_seq_expand, src=logit_hidden_pointer) + + if self.partition_merging_mode == 'replace': + facet_lm_logits_arr[i].scatter_(dim=2, index=item_seq_expand, src=torch.zeros_like(item_seq_expand).to(dtype=facet_lm_logits_arr[i].dtype) ) + facet_lm_logits_arr[i] = facet_lm_logits_arr[i] + only_new_logits + elif self.partition_merging_mode == 'add': + facet_lm_logits_arr[i] = facet_lm_logits_arr[i] + only_new_logits + elif self.partition_merging_mode == 'half': + item_in_context = torch.ones_like(only_new_logits) + item_in_context.scatter_(dim=2, index=item_seq_expand,src= 2 * torch.ones_like(item_seq_expand).to(dtype=item_count.dtype)) + facet_lm_logits_arr[i] = facet_lm_logits_arr[i] / item_in_context + only_new_logits + + + weight = None + if self.weight_mode == 'dynamic': + weight = self.weight_facet_decoder(stacked_hidden_emb_arr).softmax(dim=-1) #hidden_dim*hidden_input_state_ration -> n_facet_effective + elif self.weight_mode == 'static': + weight = self.weight_global.softmax(dim=-1) #torch.ones(n_facet_effective) + elif self.weight_mode == 'max_logits': + stacked_facet_lm_logits = torch.stack(facet_lm_logits_arr, dim=0) + facet_lm_logits_arr = [stacked_facet_lm_logits.amax(dim=0)] + + prediction_prob = 0 + + for i in range(self.n_facet_effective): + facet_lm_logits = facet_lm_logits_arr[i] + if self.softmax_nonlinear == 'sigsoftmax': #'None' here + facet_lm_logits_sig = torch.exp(facet_lm_logits - facet_lm_logits.max(dim=-1,keepdim=True)[0]) * (1e-20 + torch.sigmoid(facet_lm_logits)) + facet_lm_logits_softmax = facet_lm_logits_sig / facet_lm_logits_sig.sum(dim=-1,keepdim=True) + elif self.softmax_nonlinear == 'None': + facet_lm_logits_softmax = facet_lm_logits.softmax(dim=-1) #softmax over final logits + if self.weight_mode == 
'dynamic': + prediction_prob += facet_lm_logits_softmax * weight[:,:,i].unsqueeze(-1) + elif self.weight_mode == 'static': + prediction_prob += facet_lm_logits_softmax * weight[i] + else: + prediction_prob += facet_lm_logits_softmax / self.n_facet_effective #softmax over final logits/1 + if not only_compute_prob: + inp = torch.log(prediction_prob.view(-1, self.n_items)+1e-8) + pos_items = interaction[self.POS_ITEM_ID] + loss_raw = self.loss_fct(inp, pos_items.view(-1)) + loss = loss_raw.mean() + else: + loss = None + return loss, prediction_prob.squeeze(dim=1) + + def calculate_loss(self, interaction): + loss, prediction_prob = self.calculate_loss_prob(interaction) + return loss + + def predict(self, interaction): + print("Current softmax cpr code does not support negative sampling in an efficient way just like RepeatNet.", file=sys.stderr) + assert False #If you can accept slow running time, uncomment this line. + loss, prediction_prob = self.calculate_loss_prob(interaction, only_compute_prob=True) + if self.post_remove_context: + item_seq = interaction[self.ITEM_SEQ] + prediction_prob.scatter_(1, item_seq, 0) + test_item = interaction[self.ITEM_ID] + prediction_prob = prediction_prob.unsqueeze(-1) + # batch_size * num_items * 1 + scores = self.gather_indexes(prediction_prob, test_item).squeeze(-1) + return scores + + def full_sort_predict(self, interaction): + loss, prediction_prob = self.calculate_loss_prob(interaction) + item_seq = interaction[self.ITEM_SEQ] + if self.post_remove_context: + prediction_prob.scatter_(1, item_seq, 0) + return prediction_prob + diff --git a/recbole/model/sequential_recommender/sasreccpr.py b/recbole/model/sequential_recommender/sasreccpr.py new file mode 100644 index 000000000..4c0223c5a --- /dev/null +++ b/recbole/model/sequential_recommender/sasreccpr.py @@ -0,0 +1,335 @@ +# -*- coding: utf-8 -*- +# @Time : 2020/9/18 11:33 +# @Author : Hui Wang +# @Email : hui.wang@ruc.edu.cn + +# UPDATE: +# @Time : 2023/11/24 +# @Author : Haw-Shiuan Chang +# @Email : ken77921@gmail.com + +""" +SASRec + Softmax-CPR +################################################ + +Reference: + Wang-Cheng Kang et al. "Self-Attentive Sequential Recommendation." in ICDM 2018. + Haw-Shiuan Chang, Nikhil Agarwal, and Andrew McCallum "To Copy, or not to Copy; That is a Critical Issue of the Output Softmax Layer in Neural Sequential Recommenders" in WSDM 2024 + +Reference: + https://github.com/kang205/SASRec + https://arxiv.org/pdf/2310.14079.pdf + +""" + +import sys +import torch +from torch import nn +import torch.nn.functional as F +from recbole.model.abstract_recommender import SequentialRecommender +from recbole.model.layers import TransformerEncoder +#from recbole.model.loss import BPRLoss +import math + +def gelu(x): + return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + +class SASRecCPR(SequentialRecommender): + r""" + SASRec is the first sequential recommender based on self-attentive mechanism. + + NOTE: + In the author's implementation, the Point-Wise Feed-Forward Network (PFFN) is implemented + by CNN with 1x1 kernel. In this implementation, we follows the original BERT implementation + using Fully Connected Layer to implement the PFFN. 
+ """ + + def __init__(self, config, dataset): + super(SASRecCPR, self).__init__(config, dataset) + + # load parameters info + self.n_layers = config['n_layers'] + self.n_heads = config['n_heads'] + self.hidden_size = config['hidden_size'] # same as embedding_size + self.inner_size = config['inner_size'] # the dimensionality in feed-forward layer + self.hidden_dropout_prob = config['hidden_dropout_prob'] + self.attn_dropout_prob = config['attn_dropout_prob'] + self.hidden_act = config['hidden_act'] + self.layer_norm_eps = config['layer_norm_eps'] + self.initializer_range = config['initializer_range'] + self.loss_type = config['loss_type'] + self.n_facet_all = config['n_facet_all'] #added for mfs + self.n_facet = config['n_facet'] #added for mfs + self.n_facet_window = config['n_facet_window'] #added for mfs + self.n_facet_hidden = min(config['n_facet_hidden'], config['n_layers']) #added for mfs + self.n_facet_MLP = config['n_facet_MLP'] #added for mfs + self.n_facet_context = config['n_facet_context'] #added for dynamic partioning + self.n_facet_reranker = config['n_facet_reranker'] #added for dynamic partioning + self.n_facet_emb = config['n_facet_emb'] #added for dynamic partioning + self.weight_mode = config['weight_mode'] #added for mfs + self.context_norm = config['context_norm'] #added for mfs + self.post_remove_context = config['post_remove_context'] #added for mfs + self.partition_merging_mode = config['partition_merging_mode'] #added for mfs + self.reranker_merging_mode = config['reranker_merging_mode'] #added for mfs + self.reranker_CAN_NUM = [int(x) for x in str(config['reranker_CAN_NUM']).split(',')] + self.candidates_from_previous_reranker = True + if self.weight_mode == 'max_logits': + self.n_facet_effective = 1 + else: + self.n_facet_effective = self.n_facet + + assert self.n_facet + self.n_facet_context + self.n_facet_reranker*len(self.reranker_CAN_NUM) + self.n_facet_emb == self.n_facet_all + assert self.n_facet_emb == 0 or self.n_facet_emb == 2 + assert self.n_facet_MLP <= 0 #-1 or 0 + assert self.n_facet_window <= 0 + self.n_facet_window = - self.n_facet_window + self.n_facet_MLP = - self.n_facet_MLP + self.softmax_nonlinear='None' #added for mfs + self.use_proj_bias = config['use_proj_bias'] #added for mfs + hidden_state_input_ratio = 1 + self.n_facet_MLP #1 + 1 + self.MLP_linear = nn.Linear(self.hidden_size * (self.n_facet_hidden * (self.n_facet_window+1) ), self.hidden_size * self.n_facet_MLP) # (hid_dim*2) -> (hid_dim) + total_lin_dim = self.hidden_size * hidden_state_input_ratio + self.project_arr = nn.ModuleList([nn.Linear(total_lin_dim, self.hidden_size, bias=self.use_proj_bias) for i in range(self.n_facet_all)]) + + self.project_emb = nn.Linear(self.hidden_size, self.hidden_size, bias=self.use_proj_bias) + if len(self.weight_mode) > 0: + self.weight_facet_decoder = nn.Linear(self.hidden_size * hidden_state_input_ratio, self.n_facet_effective) + self.weight_global = nn.Parameter( torch.ones(self.n_facet_effective) ) + self.output_probs = True + self.item_embedding = nn.Embedding(self.n_items, self.hidden_size, padding_idx=0) + self.position_embedding = nn.Embedding(self.max_seq_length, self.hidden_size) + self.trm_encoder = TransformerEncoder( + n_layers=self.n_layers, + n_heads=self.n_heads, + hidden_size=self.hidden_size, + inner_size=self.inner_size, + hidden_dropout_prob=self.hidden_dropout_prob, + attn_dropout_prob=self.attn_dropout_prob, + hidden_act=self.hidden_act, + layer_norm_eps=self.layer_norm_eps + ) + + self.LayerNorm = nn.LayerNorm(self.hidden_size, 
eps=self.layer_norm_eps) + self.dropout = nn.Dropout(self.hidden_dropout_prob) + + if self.loss_type == 'BPR': + print("current softmax-cpr code does not support BPR loss") + sys.exit(0) + elif self.loss_type == 'CE': + self.loss_fct = nn.NLLLoss(reduction='none', ignore_index=0) #modified for mfs + else: + raise NotImplementedError("Make sure 'loss_type' in ['BPR', 'CE']!") + + # parameters initialization + self.apply(self._init_weights) + + small_value = 0.0001 + + def get_facet_emb(self,input_emb, i): + return self.project_arr[i](input_emb) + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def forward(self, item_seq, item_seq_len): + position_ids = torch.arange(item_seq.size(1), dtype=torch.long, device=item_seq.device) + position_ids = position_ids.unsqueeze(0).expand_as(item_seq) + position_embedding = self.position_embedding(position_ids) + + item_emb = self.item_embedding(item_seq) + input_emb = item_emb + position_embedding + input_emb = self.LayerNorm(input_emb) + input_emb = self.dropout(input_emb) + + extended_attention_mask = self.get_attention_mask(item_seq) + + trm_output = self.trm_encoder(input_emb, extended_attention_mask, output_all_encoded_layers=True) + return trm_output + + + def calculate_loss_prob(self, interaction, only_compute_prob=False): + item_seq = interaction[self.ITEM_SEQ] + item_seq_len = interaction[self.ITEM_SEQ_LEN] + all_hidden_states = self.forward(item_seq, item_seq_len) + if self.loss_type != 'CE': + print("current softmax-cpr code does not support BPR or the losses other than cross entropy") + sys.exit(0) + else: # self.loss_type = 'CE' + test_item_emb = self.item_embedding.weight + '''mfs code starts''' + device = all_hidden_states[0].device + #check seq_len from hidden size + + ## Multi-input hidden states: generate q_ct from hidden states + #list of hidden state embeddings taken as input + hidden_emb_arr = [] + # h_facet_hidden -> H, n_face_window -> W, here 1 and 0 + for i in range(self.n_facet_hidden): + hidden_states = all_hidden_states[-(i+1)] #i-th hidden-state embedding from the top + device = hidden_states.device + hidden_emb_arr.append(hidden_states) + for j in range(self.n_facet_window): + bsz, seq_len, hidden_size = hidden_states.size() #bsz -> , seq_len -> , hidden_size -> 768 in GPT-small? + if j+1 < hidden_states.size(1): + shifted_hidden = torch.cat( (torch.zeros( (bsz, (j+1), hidden_size), device = device), hidden_states[:,:-(j+1),:]), dim = 1) + else: + shifted_hidden = torch.zeros( (bsz, hidden_states.size(1), hidden_size), device = device) + hidden_emb_arr.append(shifted_hidden) + #hidden_emb_arr -> (W*H, bsz, seq_len, hidden_size) + + + #n_facet_MLP -> 1 + if self.n_facet_MLP > 0: + stacked_hidden_emb_raw_arr = torch.cat(hidden_emb_arr, dim=-1) #(bsz, seq_len, W*H*hidden_size) + # self.MLP_linear = nn.Linear(config.hidden_size * (n_facet_hidden * (n_facet_window+1) ), config.hidden_size * n_facet_MLP) -> why +1? 
+ hidden_emb_MLP = self.MLP_linear(stacked_hidden_emb_raw_arr) #bsz, seq_len, hidden_size + stacked_hidden_emb_arr_raw = torch.cat([hidden_emb_arr[0], gelu(hidden_emb_MLP)], dim=-1) #bsz, seq_len, 2*hidden_size + else: + stacked_hidden_emb_arr_raw = hidden_emb_arr[0] + + #Only use the hidden state corresponding to the last item + #The seq_len = 1 in the following code + stacked_hidden_emb_arr = stacked_hidden_emb_arr_raw[:,-1,:].unsqueeze(dim=1) + + #list of linear projects per facet + projected_emb_arr = [] + #list of final logits per facet + facet_lm_logits_arr = [] + facet_lm_logits_real_arr = [] + + #logits for orig facets + rereanker_candidate_token_ids_arr = [] + for i in range(self.n_facet): + # #linear projection + projected_emb = self.get_facet_emb(stacked_hidden_emb_arr, i) #(bsz, seq_len, hidden_dim) + projected_emb_arr.append(projected_emb) + #logits for all tokens in vocab + lm_logits = F.linear(projected_emb, self.item_embedding.weight, None) + facet_lm_logits_arr.append(lm_logits) + if i < self.n_facet_reranker and not self.candidates_from_previous_reranker: + candidate_token_ids = [] + for j in range(len(self.reranker_CAN_NUM)): + _, candidate_token_ids_ = torch.topk(lm_logits, self.reranker_CAN_NUM[j]) + candidate_token_ids.append(candidate_token_ids_) + rereanker_candidate_token_ids_arr.append(candidate_token_ids) + + for i in range(self.n_facet_reranker): + for j in range(len(self.reranker_CAN_NUM)): + projected_emb = self.get_facet_emb(stacked_hidden_emb_arr, self.n_facet+i*len(self.reranker_CAN_NUM) + j) #(bsz, seq_len, hidden_dim) + projected_emb_arr.append(projected_emb) + + for i in range(self.n_facet_context): + projected_emb = self.get_facet_emb(stacked_hidden_emb_arr, self.n_facet+self.n_facet_reranker*len(self.reranker_CAN_NUM)+i) #(bsz, seq_len, hidden_dim) + projected_emb_arr.append(projected_emb) + + #to generate context-based embeddings for words in input + for i in range(self.n_facet_emb): + projected_emb = self.get_facet_emb(stacked_hidden_emb_arr_raw, self.n_facet + self.n_facet_context + self.n_facet_reranker*len(self.reranker_CAN_NUM) + i) #(bsz, seq_len, hidden_dim) + projected_emb_arr.append(projected_emb) + + for i in range(self.n_facet_reranker): + bsz, seq_len, hidden_size = projected_emb_arr[i].size() + for j in range(len(self.reranker_CAN_NUM)): + if self.candidates_from_previous_reranker: + _, candidate_token_ids = torch.topk(facet_lm_logits_arr[i], self.reranker_CAN_NUM[j]) #(bsz, seq_len, topk) + else: + candidate_token_ids = rereanker_candidate_token_ids_arr[i][j] + logit_hidden_reranker_topn = (projected_emb_arr[self.n_facet + i*len(self.reranker_CAN_NUM) + j].unsqueeze(dim=2).expand(bsz, seq_len, self.reranker_CAN_NUM[j], hidden_size) * self.item_embedding.weight[candidate_token_ids, :] ).sum(dim=-1) #(bsz, seq_len, emb_size) x (bsz, seq_len, topk, emb_size) -> (bsz, seq_len, topk) + if self.reranker_merging_mode == 'add': + facet_lm_logits_arr[i].scatter_add_(2, candidate_token_ids, logit_hidden_reranker_topn) #(bsz, seq_len, vocab_size) <- (bsz, seq_len, topk) x (bsz, seq_len, topk) + else: + facet_lm_logits_arr[i].scatter_(2, candidate_token_ids, logit_hidden_reranker_topn) #(bsz, seq_len, vocab_size) <- (bsz, seq_len, topk) x (bsz, seq_len, topk) + + for i in range(self.n_facet_context): + bsz, seq_len_1, hidden_size = projected_emb_arr[i].size() + bsz, seq_len_2 = item_seq.size() + logit_hidden_context = (projected_emb_arr[self.n_facet + self.n_facet_reranker*len(self.reranker_CAN_NUM) + i].unsqueeze(dim=2).expand(-1,-1,seq_len_2,-1) * 
self.item_embedding.weight[item_seq, :].unsqueeze(dim=1).expand(-1,seq_len_1,-1,-1) ).sum(dim=-1) + logit_hidden_pointer = 0 + if self.n_facet_emb == 2: + logit_hidden_pointer = ( projected_emb_arr[-2][:,-1,:].unsqueeze(dim=1).unsqueeze(dim=1).expand(-1,seq_len_1,seq_len_2,-1) * projected_emb_arr[-1].unsqueeze(dim=1).expand(-1,seq_len_1,-1,-1) ).sum(dim=-1) + + item_seq_expand = item_seq.unsqueeze(dim=1).expand(-1,seq_len_1,-1) + only_new_logits = torch.zeros_like(facet_lm_logits_arr[i]) + if self.context_norm: + only_new_logits.scatter_add_(dim=2, index=item_seq_expand, src=logit_hidden_context + logit_hidden_pointer) + item_count = torch.zeros_like(only_new_logits) + 1e-15 + item_count.scatter_add_(dim=2, index=item_seq_expand,src=torch.ones_like(item_seq_expand).to(dtype=item_count.dtype)) + only_new_logits = only_new_logits / item_count + else: + only_new_logits.scatter_add_(dim=2, index=item_seq_expand, src=logit_hidden_context) + item_count = torch.zeros_like(only_new_logits) + 1e-15 + item_count.scatter_add_(dim=2, index=item_seq_expand,src=torch.ones_like(item_seq_expand).to(dtype=item_count.dtype)) + only_new_logits = only_new_logits / item_count + only_new_logits.scatter_add_(dim=2, index=item_seq_expand, src=logit_hidden_pointer) + + if self.partition_merging_mode == 'replace': + facet_lm_logits_arr[i].scatter_(dim=2, index=item_seq_expand, src=torch.zeros_like(item_seq_expand).to(dtype=facet_lm_logits_arr[i].dtype) ) + facet_lm_logits_arr[i] = facet_lm_logits_arr[i] + only_new_logits + + weight = None + if self.weight_mode == 'dynamic': + weight = self.weight_facet_decoder(stacked_hidden_emb_arr).softmax(dim=-1) #hidden_dim*hidden_input_state_ration -> n_facet_effective + elif self.weight_mode == 'static': + weight = self.weight_global.softmax(dim=-1) #torch.ones(n_facet_effective) + elif self.weight_mode == 'max_logits': + stacked_facet_lm_logits = torch.stack(facet_lm_logits_arr, dim=0) + facet_lm_logits_arr = [stacked_facet_lm_logits.amax(dim=0)] + + prediction_prob = 0 + + for i in range(self.n_facet_effective): + facet_lm_logits = facet_lm_logits_arr[i] + if self.softmax_nonlinear == 'sigsoftmax': #'None' here + facet_lm_logits_sig = torch.exp(facet_lm_logits - facet_lm_logits.max(dim=-1,keepdim=True)[0]) * (1e-20 + torch.sigmoid(facet_lm_logits)) + facet_lm_logits_softmax = facet_lm_logits_sig / facet_lm_logits_sig.sum(dim=-1,keepdim=True) + elif self.softmax_nonlinear == 'None': + facet_lm_logits_softmax = facet_lm_logits.softmax(dim=-1) #softmax over final logits + if self.weight_mode == 'dynamic': + prediction_prob += facet_lm_logits_softmax * weight[:,:,i].unsqueeze(-1) + elif self.weight_mode == 'static': + prediction_prob += facet_lm_logits_softmax * weight[i] + else: + prediction_prob += facet_lm_logits_softmax / self.n_facet_effective #softmax over final logits/1 + if not only_compute_prob: + inp = torch.log(prediction_prob.view(-1, self.n_items)+1e-8) + pos_items = interaction[self.POS_ITEM_ID] + loss_raw = self.loss_fct(inp, pos_items.view(-1)) + loss = loss_raw.mean() + else: + loss = None + #return loss, prediction_prob.squeeze() + return loss, prediction_prob.squeeze(dim=1) + + def calculate_loss(self, interaction): + loss, prediction_prob = self.calculate_loss_prob(interaction) + return loss + + def predict(self, interaction): + print("Current softmax cpr code does not support negative sampling in an efficient way just like RepeatNet.", file=sys.stderr) + assert False #If you can accept slow running time, comment this line + loss, prediction_prob = 
self.calculate_loss_prob(interaction, only_compute_prob=True) + if self.post_remove_context: + item_seq = interaction[self.ITEM_SEQ] + prediction_prob.scatter_(1, item_seq, 0) + test_item = interaction[self.ITEM_ID] + prediction_prob = prediction_prob.unsqueeze(-1) + # batch_size * num_items * 1 + scores = self.gather_indexes(prediction_prob, test_item).squeeze(-1) + + return scores + + def full_sort_predict(self, interaction): + loss, prediction_prob = self.calculate_loss_prob(interaction) + if self.post_remove_context: + item_seq = interaction[self.ITEM_SEQ] + prediction_prob.scatter_(1, item_seq, 0) + return prediction_prob diff --git a/recbole/properties/model/GRU4RecCPR.yaml b/recbole/properties/model/GRU4RecCPR.yaml new file mode 100644 index 000000000..5d6a3af3d --- /dev/null +++ b/recbole/properties/model/GRU4RecCPR.yaml @@ -0,0 +1,28 @@ +embedding_size: 64 # (int) The embedding size of items. +hidden_size: 128 # (int) The number of features in the hidden state. +num_layers: 1 # (int) The number of layers in GRU. +dropout_prob: 0 # (float) The dropout rate. +loss_type: 'CE' # (str) The type of loss function. This value can only be 'CE'. + +#Please see https://github.com/iesl/softmax_CPR_recommend/blob/master/run_hyper_loop.sh and [1] to see some common configuration of the following hyperparameters +use_out_emb: False # (bool) If False, we share the output item embedding and input item embedding ([2] shows that the sharing can encourage the item repetition) +n_facet_all: 5 # (int) Number of linear layers for context partition, reranker partition, pointer network, and most items in the vocabulary. Notice that n_facet_all = n_facet + n_facet_context + n_facet_reranker*len(reranker_CAN_NUM_arr) + n_facet_emb +n_facet: 1 # (int) Number of the output hidden states for most items in the vocabulary. If n_facet > 1, we will use mixture of softmax (MoS) +n_facet_context: 1 # (int) Number of the output hidden states for the context partition. This number should be either 0, 1 or n_facet (If you use MoS). +n_facet_reranker: 1 # (int) Number of the output hidden states for a single reranker partition. This number should be either 0, 1 or n_facet (If you use MoS). +reranker_CAN_NUM: 100 # (str) The size of reranker partitions. If you want to use 3 reranker partitions with size 500, 100, and 20, set "500,100,20". Notice that the number should have a descent order (e.g., setting it to 20,100,500 is incorrect). +n_facet_emb: 2 # (int) Number of the output hidden states for pointer network. This number should be either 0 or 2. +n_facet_hidden: 1 # (int) min(n_facet_hidden, num_layers) = H hyperparameter in multiple input hidden states (Mi) [3]. If not using Mi, set this number to 1. +n_facet_window: -2 # (int) -n_facet_window + 1 is the W hyperparameter in multiple input hidden states [3]. If not using Mi, set this number to 0. +n_facet_MLP: -1 # (int) The dimension of q_ct in [3] is (-n_facet_MLP + 1)*embedding_size. If not using Mi, set this number to 0. +weight_mode: '' # (str) The method of merging probability distribution in MoS. The value could be "dynamic" [4], "static", and "max_logits" [1]. +context_norm: 1 # (int) If setting 0, we remove the denominator in Equation (5) of [1]. +partition_merging_mode: 'replace' # (str) If "replace", the logit from context partition and pointer network would overwrite the logit from reranker partition and original softmax. Otherwise, the logit would be added. 
+reranker_merging_mode: 'replace' # (str) If "add", the logit from reranker partition would be added with the original softmax. Otherwise, the softmax logit would be replaced by the logit from reranker partition. +use_proj_bias: 1 # (bool) In linear layers for all output hidden states, if we want to use the bias term. +post_remove_context: 0 # (int) Setting the probability of all the items in the history to be 0 [2]. + +#[1] Haw-Shiuan Chang, Nikhil Agarwal, and Andrew McCallum. "To Copy, or not to Copy; That is a Critical Issue of the Output Softmax Layer in Neural Sequential Recommenders." In Proceedings of The 17th ACM Inernational Conference on Web Search and Data Mining (WSDM 24) +#[2] Ming Li, Ali Vardasbi, Andrew Yates, and Maarten de Rijke. 2023. Repetition and Exploration in Sequential Recommendation. In SIGIR 2023: 46th international ACM SIGIR Conference on Research and Development in Information Retrieval. ACM, 2532–2541. +#[3] Haw-Shiuan Chang and Andrew McCallum. 2022. Softmax bottleneck makes language models unable to represent multi-mode word distributions. In Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 8048–8073 +#[4] Zhilin Yang, Zihang Dai, Ruslan Salakhutdinov, and William W. Cohen. "Breaking the Softmax Bottleneck: A High-Rank RNN Language Model." In International Conference on Learning Representations. 2018. diff --git a/recbole/properties/model/SASRecCPR.yaml b/recbole/properties/model/SASRecCPR.yaml new file mode 100644 index 000000000..9d249e0bc --- /dev/null +++ b/recbole/properties/model/SASRecCPR.yaml @@ -0,0 +1,32 @@ +n_layers: 2 # (int) The number of transformer layers in transformer encoder. +n_heads: 2 # (int) The number of attention heads for multi-head attention layer. +hidden_size: 64 # (int) The number of features in the hidden state. +inner_size: 256 # (int) The inner hidden size in feed-forward layer. +hidden_dropout_prob: 0 # (float) The probability of an element to be zeroed. +attn_dropout_prob: 0.1 # (float) The probability of an attention score to be zeroed. +hidden_act: 'gelu' # (str) The activation function in feed-forward layer. +layer_norm_eps: 1e-12 # (float) A value added to the denominator for numerical stability. +initializer_range: 0.02 # (float) The standard deviation for normal initialization. +loss_type: 'CE' # (str) The type of loss function. This value can only be 'CE'. + +#Please see https://github.com/iesl/softmax_CPR_recommend/blob/master/run_hyper_loop.sh and [1] to see some common configuration of the following hyperparameters +n_facet_all: 5 # (int) Number of linear layers for context partition, reranker partition, pointer network, and most items in the vocabulary. Notice that n_facet_all = n_facet + n_facet_context + n_facet_reranker*len(reranker_CAN_NUM_arr) + n_facet_emb +n_facet: 1 # (int) Number of the output hidden states for most items in the vocabulary. If n_facet > 1, we will use mixture of softmax (MoS) +n_facet_context: 1 # (int) Number of the output hidden states for the context partition. This number should be either 0, 1 or n_facet (If you use MoS). +n_facet_reranker: 1 # (int) Number of the output hidden states for a single reranker partition. This number should be either 0, 1 or n_facet (If you use MoS). +reranker_CAN_NUM: 100 # (str) The size of reranker partitions. If you want to use 3 reranker partitions with size 500, 100, and 20, set "500,100,20". 
Notice that the number should have a descent order (e.g., setting it to 20,100,500 is incorrect). +n_facet_emb: 2 # (int) Number of the output hidden states for pointer network. This number should be either 0 or 2. +n_facet_hidden: 2 # (int) min(n_facet_hidden, n_layers) = H hyperparameter in multiple input hidden states (Mi) [3]. If not using Mi, set this number to 1. +n_facet_window: -2 # (int) -n_facet_window + 1 is the W hyperparameter in multiple input hidden states [3]. If not using Mi, set this number to 0. +n_facet_MLP: -1 # (int) The dimension of q_ct in [3] is (-n_facet_MLP + 1)*embedding_size. If not using Mi, set this number to 0. +weight_mode: '' # (str) The method of merging probability distribution in MoS. The value could be "dynamic" [4], "static", and "max_logits" [1]. +context_norm: 1 # (int) If setting 0, we remove the denominator in Equation (5) of [1]. +partition_merging_mode: 'replace' # (str) If "replace", the logit from context partition and pointer network would overwrite the logit from reranker partition and original softmax. Otherwise, the logit would be added. +reranker_merging_mode: 'replace' # (str) If "add", the logit from reranker partition would be added with the original softmax. Otherwise, the softmax logit would be replaced by the logit from reranker partition. +use_proj_bias: 1 # (bool) In linear layers for all output hidden states, if we want to use the bias term. +post_remove_context: 0 # (int) Setting the probability of all the items in the history to be 0 [2]. + +#[1] Haw-Shiuan Chang, Nikhil Agarwal, and Andrew McCallum. "To Copy, or not to Copy; That is a Critical Issue of the Output Softmax Layer in Neural Sequential Recommenders." In Proceedings of The 17th ACM Inernational Conference on Web Search and Data Mining (WSDM 24) +#[2] Ming Li, Ali Vardasbi, Andrew Yates, and Maarten de Rijke. 2023. Repetition and Exploration in Sequential Recommendation. In SIGIR 2023: 46th international ACM SIGIR Conference on Research and Development in Information Retrieval. ACM, 2532–2541. +#[3] Haw-Shiuan Chang and Andrew McCallum. 2022. Softmax bottleneck makes language models unable to represent multi-mode word distributions. In Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). 8048–8073 +#[4] Zhilin Yang, Zihang Dai, Ruslan Salakhutdinov, and William W. Cohen. "Breaking the Softmax Bottleneck: A High-Rank RNN Language Model." In International Conference on Learning Representations. 2018. diff --git a/tests/model/test_model_auto.py b/tests/model/test_model_auto.py index 310f7c158..0462e881a 100644 --- a/tests/model/test_model_auto.py +++ b/tests/model/test_model_auto.py @@ -458,6 +458,10 @@ def test_fpmc(self): def test_gru4rec(self): config_dict = {"model": "GRU4Rec", "train_neg_sample_args": None} quick_test(config_dict) + + def test_gru4reccpr(self): + config_dict = {"model": "GRU4RecCPR", "train_neg_sample_args": None} + quick_test(config_dict) def test_gru4rec_with_BPR_loss(self): config_dict = { @@ -531,6 +535,10 @@ def test_transrec(self): def test_sasrec(self): config_dict = {"model": "SASRec", "train_neg_sample_args": None} quick_test(config_dict) + + def test_sasreccpr(self): + config_dict = {"model": "SASRecCPR", "train_neg_sample_args": None} + quick_test(config_dict) def test_sasrec_with_BPR_loss_and_relu(self): config_dict = {"model": "SASRec", "loss_type": "BPR", "hidden_act": "relu"}
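
Usage note (illustrative, not part of the diff): the two new models plug into RecBole's standard quick-start flow like any other sequential recommender. The sketch below assumes the recbole.quick_start.run_recbole entry point and the ml-100k benchmark dataset; the overridden hyperparameter values are examples drawn from the YAML comments above, not tuned or recommended settings.

# Minimal sketch: train and evaluate one of the new Softmax-CPR models end to end.
# Assumptions: the standard recbole.quick_start.run_recbole API and the ml-100k
# dataset are available; all override values below are illustrative only.
from recbole.quick_start import run_recbole

run_recbole(
    model="SASRecCPR",   # or "GRU4RecCPR"
    dataset="ml-100k",
    config_dict={
        # Only the 'CE' loss is supported, so negative sampling must be disabled,
        # matching the test configs added in tests/model/test_model_auto.py.
        "train_neg_sample_args": None,
        # Example of three reranker partitions with descending sizes. When changing
        # this, n_facet_all must satisfy
        #   n_facet_all = n_facet + n_facet_context
        #                 + n_facet_reranker * len(reranker_CAN_NUM) + n_facet_emb,
        # i.e. 1 + 1 + 1*3 + 2 = 7 with the other defaults.
        "reranker_CAN_NUM": "500,100,20",
        "n_facet_all": 7,
        # Zero out the probability of items already in the history at prediction time.
        "post_remove_context": 1,
    },
)

Because predict() deliberately aborts (negative-sampling evaluation is not supported efficiently), evaluation is expected to use full ranking, which routes through full_sort_predict().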