Merge pull request #1509 from CIEEMio/FiGNN
FEA: Add FiGNN in Context Aware models
Showing 4 changed files with 163 additions and 0 deletions.
@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
# @Time   : 2022/10/27
# @Author : Yuyan Zhang
# @Email  : 2019308160102@cau.edu.cn
# @File   : fignn.py

r"""
FiGNN
################################################
Reference:
    Li, Zekun, et al. "Fi-GNN: Modeling Feature Interactions via Graph Neural Networks for CTR Prediction"
    in CIKM 2019.

Reference code:
    - https://github.com/CRIPAC-DIG/GraphCTR
    - https://github.com/xue-pai/FuxiCTR
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, xavier_normal_, constant_
from itertools import product

from recbole.utils import InputType
from recbole.model.abstract_recommender import ContextRecommender


class GraphLayer(nn.Module):
    """One round of message passing over the feature graph; the GRU state update is applied in FiGNN.

    The implementation of the GraphLayer part and the attentional edge weights
    is adapted from https://github.com/xue-pai/FuxiCTR.
    """

    def __init__(self, num_fields, embedding_size):
        super(GraphLayer, self).__init__()
        self.W_in = nn.Parameter(torch.Tensor(num_fields, embedding_size, embedding_size))
        self.W_out = nn.Parameter(torch.Tensor(num_fields, embedding_size, embedding_size))
        xavier_normal_(self.W_in)
        xavier_normal_(self.W_out)
        self.bias_p = nn.Parameter(torch.zeros(embedding_size))

    def forward(self, g, h):
        # project each node state before it is sent along the edges of the graph
        h_out = torch.matmul(self.W_out, h.unsqueeze(-1)).squeeze(-1)
        # aggregate neighbor states weighted by the attentional adjacency matrix g
        aggr = torch.bmm(g, h_out)
        # transform the aggregated message for each receiving node
        a = torch.matmul(self.W_in, aggr.unsqueeze(-1)).squeeze(-1) + self.bias_p
        return a


class FiGNN(ContextRecommender):
    """FiGNN is a CTR prediction model based on the gated graph neural network (GGNN),
    which models sophisticated interactions among feature fields on a graph-structured
    representation of the features.
    """
    # CTR prediction is trained point-wise on binary click labels
    input_type = InputType.POINTWISE

    def __init__(self, config, dataset):
        super(FiGNN, self).__init__(config, dataset)

        # load parameters info
        self.attention_size = config['attention_size']
        self.n_layers = config['n_layers']
        self.num_heads = config['num_heads']
        self.hidden_dropout_prob = config['hidden_dropout_prob']
        self.attn_dropout_prob = config['attn_dropout_prob']

        # define layers and loss
        self.dropout_layer = nn.Dropout(p=self.hidden_dropout_prob)
        self.att_embedding = nn.Linear(self.embedding_size, self.attention_size)
        # multi-head self-attention network
        self.self_attn = nn.MultiheadAttention(
            self.attention_size, self.num_heads, dropout=self.attn_dropout_prob, batch_first=True
        )
        self.v_res_embedding = torch.nn.Linear(self.embedding_size, self.attention_size)
        # FiGNN
        self.src_nodes, self.dst_nodes = zip(*list(product(range(self.num_feature_field), repeat=2)))
        self.gnn = nn.ModuleList([
            GraphLayer(self.num_feature_field, self.attention_size) for _ in range(self.n_layers - 1)
        ])
        self.leaky_relu = nn.LeakyReLU(negative_slope=0.01)
        self.W_attn = nn.Linear(self.attention_size * 2, 1, bias=False)
        self.gru_cell = nn.GRUCell(self.attention_size, self.attention_size)
        # Attentional Scoring Layer
        self.mlp1 = nn.Linear(self.attention_size, 1, bias=False)
        self.mlp2 = nn.Linear(self.num_feature_field * self.attention_size, self.num_feature_field, bias=False)

        self.sigmoid = nn.Sigmoid()
        self.loss = nn.BCEWithLogitsLoss()
        # parameters initialization
        self.apply(self._init_weights)

    def fignn_layer(self, in_feature):
        emb_feature = self.att_embedding(in_feature)
        emb_feature = self.dropout_layer(emb_feature)
        # multi-head self-attention network
        att_feature, _ = self.self_attn(emb_feature, emb_feature, emb_feature)  # [batch_size, num_field, att_dim]
        # residual connection
        v_res = self.v_res_embedding(in_feature)
        att_feature += v_res
        att_feature = F.relu(att_feature).contiguous()

        # build a fully connected feature graph with attentional edge weights (self-loops masked out)
        src_emb = att_feature[:, self.src_nodes, :]
        dst_emb = att_feature[:, self.dst_nodes, :]
        concat_emb = torch.cat([src_emb, dst_emb], dim=-1)
        alpha = self.leaky_relu(self.W_attn(concat_emb))
        alpha = alpha.view(-1, self.num_feature_field, self.num_feature_field)
        mask = torch.eye(self.num_feature_field).to(self.device)
        alpha = alpha.masked_fill(mask.bool(), float('-inf'))
        self.graph = F.softmax(alpha, dim=-1)
        # message passing: n_layers - 1 rounds of graph aggregation with a GRU state update
        if self.n_layers > 1:
            h = att_feature
            for i in range(self.n_layers - 1):
                a = self.gnn[i](self.graph, h)
                a = a.view(-1, self.attention_size)
                h = h.view(-1, self.attention_size)
                h = self.gru_cell(a, h)
                h = h.view(-1, self.num_feature_field, self.attention_size)
                h += att_feature
        else:
            h = att_feature
        # Attentional Scoring Layer
        score = self.mlp1(h).squeeze(-1)
        weight = self.mlp2(h.flatten(start_dim=1))
        logit = (weight * score).sum(dim=1).unsqueeze(-1)
        return logit

    def _init_weights(self, module):
        if isinstance(module, nn.Embedding):
            xavier_normal_(module.weight.data)
        elif isinstance(module, nn.Linear):
            xavier_normal_(module.weight.data)
            if module.bias is not None:
                constant_(module.bias.data, 0)
        elif isinstance(module, nn.GRUCell):
            # initialize the GRUCell that performs the node state update
            xavier_uniform_(module.weight_hh)
            xavier_uniform_(module.weight_ih)

    def forward(self, interaction):
        fignn_all_embeddings = self.concat_embed_input_fields(interaction)  # [batch_size, num_field, embed_dim]
        output = self.fignn_layer(fignn_all_embeddings)
        return output.squeeze(1)

    def calculate_loss(self, interaction):
        label = interaction[self.LABEL]
        output = self.forward(interaction)
        return self.loss(output, label)

    def predict(self, interaction):
        return self.sigmoid(self.forward(interaction))
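
For a quick sanity check of the new GraphLayer module in isolation, the sketch below feeds random tensors through one propagation step. The import path is an assumption (adjust it to wherever fignn.py sits in your checkout), and the tensor sizes are illustrative only.

import torch
# hypothetical import path; not part of this PR
from fignn import GraphLayer

batch_size, num_fields, att_dim = 4, 6, 16
layer = GraphLayer(num_fields=num_fields, embedding_size=att_dim)

# attentional adjacency matrix over the feature graph, and the node states
g = torch.softmax(torch.randn(batch_size, num_fields, num_fields), dim=-1)
h = torch.randn(batch_size, num_fields, att_dim)

out = layer(g, h)
print(out.shape)  # torch.Size([4, 6, 16]) -- one aggregated message per feature field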
@@ -0,0 +1,6 @@
embedding_size: 10          # (int) The embedding size of features.
attention_size: 16          # (int) The vector size used in the self-attention and graph layers.
n_layers: 2                 # (int) The number of graph interaction layers.
num_heads: 2                # (int) The number of attention heads.
hidden_dropout_prob: 0.2    # (float) The dropout rate of the hidden layer.
attn_dropout_prob: 0.2      # (float) The dropout rate of the multi-head self-attention layer.
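
Once the model file and the yaml above are in place, the model can be launched through RecBole's quick-start entry point. The snippet below is a usage sketch rather than part of this PR: the dataset name is illustrative, and any key from the yaml can be overridden via config_dict.

from recbole.quick_start import run_recbole

# hypothetical run: 'ml-100k' stands in for any dataset with a CTR-style label field
config_dict = {
    'attention_size': 16,
    'n_layers': 2,
    'num_heads': 2,
    'hidden_dropout_prob': 0.2,
    'attn_dropout_prob': 0.2,
}
run_recbole(model='FiGNN', dataset='ml-100k', config_dict=config_dict)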