diff --git a/recbole/model/context_aware_recommender/ffm.py b/recbole/model/context_aware_recommender/ffm.py
index c16e8e90e..63511959c 100644
--- a/recbole/model/context_aware_recommender/ffm.py
+++ b/recbole/model/context_aware_recommender/ffm.py
@@ -165,7 +165,7 @@ def __init__(self, feature_names, feature_dims, feature2id, feature2field, num_f
         # init float field-aware embeddings if there is float type of features.
         if len(self.float_feature_names) > 0:
             self.num_float_features = len(self.float_feature_names)
-            self.float_embeddings = nn.Embedding(np.sum(self.token_feature_dims, dtype=np.int32), self.embed_dim)
+            # self.float_embeddings = nn.Embedding(np.sum(self.token_feature_dims, dtype=np.int32), self.embed_dim)
             self.float_embeddings = torch.nn.ModuleList([
                 nn.Embedding(self.num_float_features, self.embed_dim) for _ in range(self.num_fields)
             ])
diff --git a/recbole/model/context_aware_recommender/fwfm.py b/recbole/model/context_aware_recommender/fwfm.py
index deccb63df..6822950ed 100644
--- a/recbole/model/context_aware_recommender/fwfm.py
+++ b/recbole/model/context_aware_recommender/fwfm.py
@@ -51,6 +51,8 @@ def __init__(self, config, dataset):
         self.num_fields = len(set(self.feature2field.values())) # the number of fields
         self.num_pair = self.num_fields * self.num_fields
 
+        self.weight = torch.randn(self.num_fields,self.num_fields,1,requires_grad=True,device=self.device)
+
         self.loss = nn.BCELoss()
 
         # parameters initialization
@@ -100,17 +102,13 @@ def fwfm_layer(self, infeature):
         """
         # get r(Fi, Fj)
         batch_size = infeature.shape[0]
-        para = torch.randn(self.num_fields * self.num_fields * self.embedding_size).\
-            expand(batch_size, self.num_fields * self.num_fields * self.embedding_size).\
-            to(self.device) # [batch_size*num_pairs*emb_dim]
-        para = para.reshape(batch_size, self.num_fields, self.num_fields, self.embedding_size)
-        r = nn.Parameter(para, requires_grad=True) # [batch_size, num_fields, num_fields, emb_dim]
+        weight = self.weight.expand(batch_size,-1,-1,-1)
 
         fwfm_inter = list() # [batch_size, num_fields, emb_dim]
         for i in range(self.num_features - 1):
             for j in range(i + 1, self.num_features):
                 Fi, Fj = self.feature2field[i], self.feature2field[j]
-                fwfm_inter.append(infeature[:, i] * infeature[:, j] * r[:, Fi, Fj])
+                fwfm_inter.append(infeature[:, i] * infeature[:, j] * weight[:, Fi, Fj])
         fwfm_inter = torch.stack(fwfm_inter, dim=1)
         fwfm_inter = torch.sum(fwfm_inter, dim=1) # [batch_size, emb_dim]
         fwfm_inter = self.dropout_layer(fwfm_inter)
diff --git a/recbole/model/layers.py b/recbole/model/layers.py
index f9fe94a29..24ecfde75 100644
--- a/recbole/model/layers.py
+++ b/recbole/model/layers.py
@@ -269,7 +269,7 @@ def __init__(
         self.softmax_stag = softmax_stag
         self.return_seq_weight = return_seq_weight
         self.mask_mat = mask_mat
-        self.att_mlp_layers = MLPLayers(self.att_hidden_size, activation='Sigmoid', bn=False)
+        self.att_mlp_layers = MLPLayers(self.att_hidden_size, activation=self.activation, bn=False)
         self.dense = nn.Linear(self.att_hidden_size[-1], 1)
 
     def forward(self, queries, keys, keys_length):
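
A minimal standalone sketch of the behaviour the fwfm.py hunks aim for, not part of the patch itself: the field-pair weight r(Fi, Fj) is sampled once at construction time and only broadcast over the batch in each forward pass, instead of being re-drawn from `torch.randn` on every call. The names `num_fields`, `weight`, and `pair_weights` are illustrative, not RecBole identifiers.

```python
import torch

num_fields, batch_size = 3, 2

# Created once, analogous to self.weight in __init__; every forward pass
# sees the same values, and gradients can flow back to this single tensor.
weight = torch.randn(num_fields, num_fields, 1, requires_grad=True)

def pair_weights(batch):
    # expand() adds a leading batch dimension without copying or
    # re-initializing the underlying field-pair weights.
    return weight.expand(batch, -1, -1, -1)

w1 = pair_weights(batch_size)
w2 = pair_weights(batch_size)
assert torch.equal(w1, w2)  # same r(Fi, Fj) across calls, unlike per-call randn
```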