Merge pull request #1414 from TangJiakai/master
FwFMs Bug fix
Sherry-XLL authored Aug 26, 2022
2 parents 07d80fe + 3c9bf41 · commit 4625878
Showing 3 changed files with 6 additions and 8 deletions.
recbole/model/context_aware_recommender/ffm.py (2 changes: 1 addition & 1 deletion)

@@ -165,7 +165,7 @@ def __init__(self, feature_names, feature_dims, feature2id, feature2field, num_f
         # init float field-aware embeddings if there is float type of features.
         if len(self.float_feature_names) > 0:
             self.num_float_features = len(self.float_feature_names)
-            self.float_embeddings = nn.Embedding(np.sum(self.token_feature_dims, dtype=np.int32), self.embed_dim)
+            # self.float_embeddings = nn.Embedding(np.sum(self.token_feature_dims, dtype=np.int32), self.embed_dim)
             self.float_embeddings = torch.nn.ModuleList([
                 nn.Embedding(self.num_float_features, self.embed_dim) for _ in range(self.num_fields)
             ])
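Why the fix: the removed line built one flat nn.Embedding sized by the token feature dims, which does not describe float features at all; the replacement keeps a separate embedding table per field, as field-aware models require. A minimal sketch of how such per-field float embeddings can be looked up and scaled by the raw feature values (the shapes and the value-scaling step are illustrative assumptions, not RecBole's exact forward pass):

    import torch
    import torch.nn as nn

    num_float_features, num_fields, embed_dim = 3, 4, 8
    # one embedding table per field, as in the fixed __init__ above
    float_embeddings = nn.ModuleList(
        [nn.Embedding(num_float_features, embed_dim) for _ in range(num_fields)]
    )

    batch_size = 2
    float_values = torch.rand(batch_size, num_float_features)  # raw float features
    index = torch.arange(num_float_features)                   # one row per float feature

    # look up each field's table and scale its rows by the feature values
    per_field = [emb(index) * float_values.unsqueeze(-1) for emb in float_embeddings]
    emb_tensor = torch.stack(per_field, dim=1)
    print(emb_tensor.shape)  # [batch_size, num_fields, num_float_features, embed_dim]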
recbole/model/context_aware_recommender/fwfm.py (10 changes: 4 additions & 6 deletions)

@@ -51,6 +51,8 @@ def __init__(self, config, dataset):
         self.num_fields = len(set(self.feature2field.values()))  # the number of fields
         self.num_pair = self.num_fields * self.num_fields

+        self.weight = torch.randn(self.num_fields, self.num_fields, 1, requires_grad=True, device=self.device)
+
         self.loss = nn.BCELoss()

         # parameters initialization

@@ -100,17 +102,13 @@ def fwfm_layer(self, infeature):
"""
# get r(Fi, Fj)
batch_size = infeature.shape[0]
para = torch.randn(self.num_fields * self.num_fields * self.embedding_size).\
expand(batch_size, self.num_fields * self.num_fields * self.embedding_size).\
to(self.device) # [batch_size*num_pairs*emb_dim]
para = para.reshape(batch_size, self.num_fields, self.num_fields, self.embedding_size)
r = nn.Parameter(para, requires_grad=True) # [batch_size, num_fields, num_fields, emb_dim]
weight = self.weight.expand(batch_size,-1,-1,-1)

fwfm_inter = list() # [batch_size, num_fields, emb_dim]
for i in range(self.num_features - 1):
for j in range(i + 1, self.num_features):
Fi, Fj = self.feature2field[i], self.feature2field[j]
fwfm_inter.append(infeature[:, i] * infeature[:, j] * r[:, Fi, Fj])
fwfm_inter.append(infeature[:, i] * infeature[:, j] * weight[:, Fi, Fj])
fwfm_inter = torch.stack(fwfm_inter, dim=1)
fwfm_inter = torch.sum(fwfm_inter, dim=1) # [batch_size, emb_dim]
fwfm_inter = self.dropout_layer(fwfm_inter)
Expand Down
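This hunk is the heart of the bug fix: the old fwfm_layer drew a fresh torch.randn and wrapped it in nn.Parameter on every forward call, so the field-pair weights r(Fi, Fj) were unregistered random noise that could never train; the fixed code allocates self.weight once in __init__ and only expands a view of it per batch. A self-contained sketch of the same idea (FwFMInteraction is a hypothetical stand-in, not RecBole's FwFM class; it stores the weight as an nn.Parameter, a small deviation from the commit noted below):

    import torch
    import torch.nn as nn

    class FwFMInteraction(nn.Module):
        def __init__(self, num_fields, embed_dim):
            super().__init__()
            self.num_fields = num_fields
            # one learnable scalar weight per ordered field pair, created once
            self.weight = nn.Parameter(torch.randn(num_fields, num_fields, 1))

        def forward(self, emb, feature2field):
            # emb: [batch_size, num_features, embed_dim]
            num_features = emb.shape[1]
            inter = []
            for i in range(num_features - 1):
                for j in range(i + 1, num_features):
                    fi, fj = feature2field[i], feature2field[j]
                    # pairwise interaction scaled by the learned r(Fi, Fj)
                    inter.append(emb[:, i] * emb[:, j] * self.weight[fi, fj])
            return torch.stack(inter, dim=1).sum(dim=1)  # [batch_size, embed_dim]

    # usage: two features mapped to two fields
    layer = FwFMInteraction(num_fields=2, embed_dim=4)
    out = layer(torch.randn(3, 2, 4), {0: 0, 1: 1})
    print(out.shape)  # torch.Size([3, 4])

One design nuance: the commit keeps self.weight as a plain tensor with requires_grad=True, which trains only if it is passed to the optimizer explicitly and is not included in state_dict; registering it as an nn.Parameter, as in the sketch, would make both automatic.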
recbole/model/layers.py (2 changes: 1 addition & 1 deletion)

@@ -269,7 +269,7 @@ def __init__(
         self.softmax_stag = softmax_stag
         self.return_seq_weight = return_seq_weight
         self.mask_mat = mask_mat
-        self.att_mlp_layers = MLPLayers(self.att_hidden_size, activation='Sigmoid', bn=False)
+        self.att_mlp_layers = MLPLayers(self.att_hidden_size, activation=self.activation, bn=False)
         self.dense = nn.Linear(self.att_hidden_size[-1], 1)

     def forward(self, queries, keys, keys_length):
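This one-line change threads the activation configured on the attention layer into its MLP instead of always hard-coding 'Sigmoid'. A rough stand-in showing the pattern (mlp_layers and the ACTIVATIONS table are illustrative, not RecBole's MLPLayers API):

    import torch.nn as nn

    ACTIVATIONS = {'sigmoid': nn.Sigmoid, 'relu': nn.ReLU, 'tanh': nn.Tanh}

    def mlp_layers(sizes, activation='relu'):
        # build Linear + activation blocks, honoring the configured activation
        blocks = []
        for d_in, d_out in zip(sizes[:-1], sizes[1:]):
            blocks += [nn.Linear(d_in, d_out), ACTIVATIONS[activation.lower()]()]
        return nn.Sequential(*blocks)

    att = mlp_layers([80, 40], activation='sigmoid')  # configured, not hard-coded
    print(att)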
