Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEA: Rename GCNConv to LightGCNConv; Add BiGNNCov; Add NGCF; #6

Merged
merged 3 commits into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions recbole_graph/model/general_recommender/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from recbole_graph.model.general_recommender.lightgcn import LightGCN
from recbole_graph.model.general_recommender.ncl import NCL
from recbole_graph.model.general_recommender.ngcf import NGCF
from recbole_graph.model.general_recommender.sgl import SGL
4 changes: 2 additions & 2 deletions recbole_graph/model/general_recommender/lightgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from recbole.utils import InputType

from recbole_graph.model.abstract_recommender import GeneralGraphRecommender
from recbole_graph.model.layers import GCNConv
from recbole_graph.model.layers import LightGCNConv


class LightGCN(GeneralGraphRecommender):
Expand All @@ -31,7 +31,7 @@ def __init__(self, config, dataset):
# define layers and loss
self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim)
self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim)
self.gcn_conv = GCNConv(dim=self.latent_dim)
self.gcn_conv = LightGCNConv(dim=self.latent_dim)
self.mf_loss = BPRLoss()
self.reg_loss = EmbLoss()

Expand Down
8 changes: 4 additions & 4 deletions recbole_graph/model/general_recommender/ncl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from recbole.utils import InputType

from recbole_graph.model.abstract_recommender import GeneralGraphRecommender
from recbole_graph.model.layers import GCNConv
from recbole_graph.model.layers import LightGCNConv


class NCL(GeneralGraphRecommender):
Expand All @@ -42,7 +42,7 @@ def __init__(self, config, dataset):
# define layers and loss
self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim)
self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim)
self.gcn_conv = GCNConv(dim=self.latent_dim)
self.gcn_conv = LightGCNConv(dim=self.latent_dim)
self.mf_loss = BPRLoss()
self.reg_loss = EmbLoss()

Expand Down Expand Up @@ -94,11 +94,11 @@ def get_ego_embeddings(self):
def forward(self):
all_embeddings = self.get_ego_embeddings()
embeddings_list = [all_embeddings]
for layer_idx in range(max(self.n_layers, self.hyper_layers*2)):
for layer_idx in range(max(self.n_layers, self.hyper_layers * 2)):
all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)
embeddings_list.append(all_embeddings)

lightgcn_all_embeddings = torch.stack(embeddings_list[:self.n_layers+1], dim=1)
lightgcn_all_embeddings = torch.stack(embeddings_list[:self.n_layers + 1], dim=1)
lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)

user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])
Expand Down
136 changes: 136 additions & 0 deletions recbole_graph/model/general_recommender/ngcf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# @Time : 2022/3/8
# @Author : Changxin Tian
# @Email : cx.tian@outlook.com
r"""
NGCF
################################################
Reference:
Xiang Wang et al. "Neural Graph Collaborative Filtering." in SIGIR 2019.

Reference code:
https://github.com/xiangwang1223/neural_graph_collaborative_filtering

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import dropout_adj

from recbole.model.init import xavier_normal_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType

from recbole_graph.model.abstract_recommender import GeneralGraphRecommender
from recbole_graph.model.layers import BiGNNConv


class NGCF(GeneralGraphRecommender):
r"""NGCF is a model that incorporate GNN for recommendation.
We implement the model following the original author with a pairwise training mode.
"""
input_type = InputType.PAIRWISE

def __init__(self, config, dataset):
super(NGCF, self).__init__(config, dataset)

# load parameters info
self.embedding_size = config['embedding_size']
self.hidden_size_list = config['hidden_size_list']
self.hidden_size_list = [self.embedding_size] + self.hidden_size_list
self.node_dropout = config['node_dropout']
self.message_dropout = config['message_dropout']
self.reg_weight = config['reg_weight']

# define layers and loss
self.user_embedding = nn.Embedding(self.n_users, self.embedding_size)
self.item_embedding = nn.Embedding(self.n_items, self.embedding_size)
self.GNNlayers = torch.nn.ModuleList()
for input_size, output_size in zip(self.hidden_size_list[:-1], self.hidden_size_list[1:]):
self.GNNlayers.append(BiGNNConv(input_size, output_size))
self.mf_loss = BPRLoss()
self.reg_loss = EmbLoss()

# storage variables for full sort evaluation acceleration
self.restore_user_e = None
self.restore_item_e = None

# parameters initialization
self.apply(xavier_normal_initialization)
self.other_parameter_name = ['restore_user_e', 'restore_item_e']

def get_ego_embeddings(self):
r"""Get the embedding of users and items and combine to an embedding matrix.

Returns:
Tensor of the embedding matrix. Shape of (n_items+n_users, embedding_dim)
"""
user_embeddings = self.user_embedding.weight
item_embeddings = self.item_embedding.weight
ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
return ego_embeddings

def forward(self):
if self.node_dropout == 0:
edge_index, edge_weight = self.edge_index, self.edge_weight
else:
edge_index, edge_weight = dropout_adj(edge_index=self.edge_index, edge_attr=self.edge_weight, p=self.node_dropout)

all_embeddings = self.get_ego_embeddings()
embeddings_list = [all_embeddings]
for gnn in self.GNNlayers:
all_embeddings = gnn(all_embeddings, edge_index, edge_weight)
all_embeddings = nn.LeakyReLU(negative_slope=0.2)(all_embeddings)
all_embeddings = nn.Dropout(self.message_dropout)(all_embeddings)
all_embeddings = F.normalize(all_embeddings, p=2, dim=1)
embeddings_list += [all_embeddings] # storage output embedding of each layer
ngcf_all_embeddings = torch.cat(embeddings_list, dim=1)

user_all_embeddings, item_all_embeddings = torch.split(ngcf_all_embeddings, [self.n_users, self.n_items])

return user_all_embeddings, item_all_embeddings

def calculate_loss(self, interaction):
# clear the storage variable when training
if self.restore_user_e is not None or self.restore_item_e is not None:
self.restore_user_e, self.restore_item_e = None, None

user = interaction[self.USER_ID]
pos_item = interaction[self.ITEM_ID]
neg_item = interaction[self.NEG_ITEM_ID]

user_all_embeddings, item_all_embeddings = self.forward()
u_embeddings = user_all_embeddings[user]
pos_embeddings = item_all_embeddings[pos_item]
neg_embeddings = item_all_embeddings[neg_item]

pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
mf_loss = self.mf_loss(pos_scores, neg_scores) # calculate BPR Loss

reg_loss = self.reg_loss(u_embeddings, pos_embeddings, neg_embeddings) # L2 regularization of embeddings

return mf_loss + self.reg_weight * reg_loss

def predict(self, interaction):
user = interaction[self.USER_ID]
item = interaction[self.ITEM_ID]

user_all_embeddings, item_all_embeddings = self.forward()

u_embeddings = user_all_embeddings[user]
i_embeddings = item_all_embeddings[item]
scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
return scores

def full_sort_predict(self, interaction):
user = interaction[self.USER_ID]
if self.restore_user_e is None or self.restore_item_e is None:
self.restore_user_e, self.restore_item_e = self.forward()
# get user embedding from storage variable
u_embeddings = self.restore_user_e[user]

# dot with all item embedding to accelerate
scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

return scores.view(-1)
4 changes: 2 additions & 2 deletions recbole_graph/model/general_recommender/sgl.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from recbole.utils import InputType

from recbole_graph.model.abstract_recommender import GeneralGraphRecommender
from recbole_graph.model.layers import GCNConv
from recbole_graph.model.layers import LightGCNConv


class SGL(GeneralGraphRecommender):
Expand Down Expand Up @@ -57,7 +57,7 @@ def __init__(self, config, dataset):
# define layers and loss
self.user_embedding = torch.nn.Embedding(self.n_users, self.latent_dim)
self.item_embedding = torch.nn.Embedding(self.n_items, self.latent_dim)
self.gcn_conv = GCNConv(dim=self.latent_dim)
self.gcn_conv = LightGCNConv(dim=self.latent_dim)
self.reg_loss = EmbLoss()

# storage variables for full sort evaluation acceleration
Expand Down
34 changes: 32 additions & 2 deletions recbole_graph/model/layers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops


class GCNConv(MessagePassing):
class LightGCNConv(MessagePassing):
def __init__(self, dim):
super(GCNConv, self).__init__(aggr='add')
super(LightGCNConv, self).__init__(aggr='add')
self.dim = dim

def forward(self, x, edge_index, edge_weight):
Expand All @@ -18,6 +19,35 @@ def __repr__(self):
return '{}({})'.format(self.__class__.__name__, self.dim)


class BiGNNConv(MessagePassing):
r"""Propagate a layer of Bi-interaction GNN

.. math::
output = (L+I)EW_1 + LE \otimes EW_2
"""
def __init__(self, in_channels, out_channels):
super().__init__(aggr='add')
self.in_channels, self.out_channels = in_channels, out_channels
self.lin1 = torch.nn.Linear(in_channels, out_channels)
self.lin2 = torch.nn.Linear(in_channels, out_channels)

def forward(self, x, edge_index, edge_weight):
return self.propagate(edge_index, x=x, edge_weight=edge_weight)

def message(self, x_i, x_j, edge_weight):
x_trans = self.lin1(x_j)
x_inter = self.lin2(torch.mul(x_j, x_i))
x = x_trans + x_inter
x_prop = edge_weight.view(-1, 1) * x
return x_prop

def update(self, aggr_out, x):
return aggr_out + self.lin1(x)

def __repr__(self):
return '{}({},{})'.format(self.__class__.__name__, self.in_channels, self.out_channels)


class SRGNNConv(MessagePassing):
def __init__(self, dim):
# mean aggregation to incorporate weight naturally
Expand Down