Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEA: Add NCL #5

Merged
merged 1 commit into from
Mar 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions recbole_graph/model/general_recommender/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from recbole_graph.model.general_recommender.lightgcn import LightGCN
from recbole_graph.model.general_recommender.ncl import NCL
from recbole_graph.model.general_recommender.sgl import SGL
223 changes: 223 additions & 0 deletions recbole_graph/model/general_recommender/ncl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
# -*- coding: utf-8 -*-
r"""
NCL
################################################
Reference:
Zihan Lin*, Changxin Tian*, Yupeng Hou*, Wayne Xin Zhao. "Improving Graph Collaborative Filtering with Neighborhood-enriched Contrastive Learning." in WWW 2022.
"""

import numpy as np
import torch
import torch.nn.functional as F

import faiss
from recbole.model.init import xavier_uniform_initialization
from recbole.model.loss import BPRLoss, EmbLoss
from recbole.utils import InputType

from recbole_graph.model.abstract_recommender import GeneralGraphRecommender
from recbole_graph.model.layers import GCNConv


class NCL(GeneralGraphRecommender):
input_type = InputType.PAIRWISE

def __init__(self, config, dataset):
super(NCL, self).__init__(config, dataset)

# load parameters info
self.latent_dim = config['embedding_size'] # int type: the embedding size of the base model
self.n_layers = config['n_layers'] # int type: the layer num of the base model
self.reg_weight = config['reg_weight'] # float32 type: the weight decay for l2 normalization

self.ssl_temp = config['ssl_temp']
self.ssl_reg = config['ssl_reg']
self.hyper_layers = config['hyper_layers']

self.alpha = config['alpha']

self.proto_reg = config['proto_reg']
self.k = config['num_clusters']

# define layers and loss
self.user_embedding = torch.nn.Embedding(num_embeddings=self.n_users, embedding_dim=self.latent_dim)
self.item_embedding = torch.nn.Embedding(num_embeddings=self.n_items, embedding_dim=self.latent_dim)
self.gcn_conv = GCNConv(dim=self.latent_dim)
self.mf_loss = BPRLoss()
self.reg_loss = EmbLoss()

# storage variables for full sort evaluation acceleration
self.restore_user_e = None
self.restore_item_e = None

# parameters initialization
self.apply(xavier_uniform_initialization)
self.other_parameter_name = ['restore_user_e', 'restore_item_e']

self.user_centroids = None
self.user_2cluster = None
self.item_centroids = None
self.item_2cluster = None

def e_step(self):
user_embeddings = self.user_embedding.weight.detach().cpu().numpy()
item_embeddings = self.item_embedding.weight.detach().cpu().numpy()
self.user_centroids, self.user_2cluster = self.run_kmeans(user_embeddings)
self.item_centroids, self.item_2cluster = self.run_kmeans(item_embeddings)

def run_kmeans(self, x):
"""Run K-means algorithm to get k clusters of the input tensor x
"""
kmeans = faiss.Kmeans(d=self.latent_dim, k=self.k, gpu=True)
kmeans.train(x)
cluster_cents = kmeans.centroids

_, I = kmeans.index.search(x, 1)

# convert to cuda Tensors for broadcast
centroids = torch.Tensor(cluster_cents).to(self.device)
centroids = F.normalize(centroids, p=2, dim=1)

node2cluster = torch.LongTensor(I).squeeze().to(self.device)
return centroids, node2cluster

def get_ego_embeddings(self):
r"""Get the embedding of users and items and combine to an embedding matrix.
Returns:
Tensor of the embedding matrix. Shape of [n_items+n_users, embedding_dim]
"""
user_embeddings = self.user_embedding.weight
item_embeddings = self.item_embedding.weight
ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
return ego_embeddings

def forward(self):
all_embeddings = self.get_ego_embeddings()
embeddings_list = [all_embeddings]
for layer_idx in range(max(self.n_layers, self.hyper_layers*2)):
all_embeddings = self.gcn_conv(all_embeddings, self.edge_index, self.edge_weight)
embeddings_list.append(all_embeddings)

lightgcn_all_embeddings = torch.stack(embeddings_list[:self.n_layers+1], dim=1)
lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)

user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])
return user_all_embeddings, item_all_embeddings, embeddings_list

def ProtoNCE_loss(self, node_embedding, user, item):
user_embeddings_all, item_embeddings_all = torch.split(node_embedding, [self.n_users, self.n_items])

user_embeddings = user_embeddings_all[user] # [B, e]
norm_user_embeddings = F.normalize(user_embeddings)

user2cluster = self.user_2cluster[user] # [B,]
user2centroids = self.user_centroids[user2cluster] # [B, e]
pos_score_user = torch.mul(norm_user_embeddings, user2centroids).sum(dim=1)
pos_score_user = torch.exp(pos_score_user / self.ssl_temp)
ttl_score_user = torch.matmul(norm_user_embeddings, self.user_centroids.transpose(0, 1))
ttl_score_user = torch.exp(ttl_score_user / self.ssl_temp).sum(dim=1)

proto_nce_loss_user = -torch.log(pos_score_user / ttl_score_user).sum()

item_embeddings = item_embeddings_all[item]
norm_item_embeddings = F.normalize(item_embeddings)

item2cluster = self.item_2cluster[item] # [B, ]
item2centroids = self.item_centroids[item2cluster] # [B, e]
pos_score_item = torch.mul(norm_item_embeddings, item2centroids).sum(dim=1)
pos_score_item = torch.exp(pos_score_item / self.ssl_temp)
ttl_score_item = torch.matmul(norm_item_embeddings, self.item_centroids.transpose(0, 1))
ttl_score_item = torch.exp(ttl_score_item / self.ssl_temp).sum(dim=1)
proto_nce_loss_item = -torch.log(pos_score_item / ttl_score_item).sum()

proto_nce_loss = self.proto_reg * (proto_nce_loss_user + proto_nce_loss_item)
return proto_nce_loss

def ssl_layer_loss(self, current_embedding, previous_embedding, user, item):
current_user_embeddings, current_item_embeddings = torch.split(current_embedding, [self.n_users, self.n_items])
previous_user_embeddings_all, previous_item_embeddings_all = torch.split(previous_embedding, [self.n_users, self.n_items])

current_user_embeddings = current_user_embeddings[user]
previous_user_embeddings = previous_user_embeddings_all[user]
norm_user_emb1 = F.normalize(current_user_embeddings)
norm_user_emb2 = F.normalize(previous_user_embeddings)
norm_all_user_emb = F.normalize(previous_user_embeddings_all)
pos_score_user = torch.mul(norm_user_emb1, norm_user_emb2).sum(dim=1)
ttl_score_user = torch.matmul(norm_user_emb1, norm_all_user_emb.transpose(0, 1))
pos_score_user = torch.exp(pos_score_user / self.ssl_temp)
ttl_score_user = torch.exp(ttl_score_user / self.ssl_temp).sum(dim=1)

ssl_loss_user = -torch.log(pos_score_user / ttl_score_user).sum()

current_item_embeddings = current_item_embeddings[item]
previous_item_embeddings = previous_item_embeddings_all[item]
norm_item_emb1 = F.normalize(current_item_embeddings)
norm_item_emb2 = F.normalize(previous_item_embeddings)
norm_all_item_emb = F.normalize(previous_item_embeddings_all)
pos_score_item = torch.mul(norm_item_emb1, norm_item_emb2).sum(dim=1)
ttl_score_item = torch.matmul(norm_item_emb1, norm_all_item_emb.transpose(0, 1))
pos_score_item = torch.exp(pos_score_item / self.ssl_temp)
ttl_score_item = torch.exp(ttl_score_item / self.ssl_temp).sum(dim=1)

ssl_loss_item = -torch.log(pos_score_item / ttl_score_item).sum()

ssl_loss = self.ssl_reg * (ssl_loss_user + self.alpha * ssl_loss_item)
return ssl_loss

def calculate_loss(self, interaction):
# clear the storage variable when training
if self.restore_user_e is not None or self.restore_item_e is not None:
self.restore_user_e, self.restore_item_e = None, None

user = interaction[self.USER_ID]
pos_item = interaction[self.ITEM_ID]
neg_item = interaction[self.NEG_ITEM_ID]

user_all_embeddings, item_all_embeddings, embeddings_list = self.forward()

center_embedding = embeddings_list[0]
context_embedding = embeddings_list[self.hyper_layers * 2]

ssl_loss = self.ssl_layer_loss(context_embedding, center_embedding, user, pos_item)
proto_loss = self.ProtoNCE_loss(center_embedding, user, pos_item)

u_embeddings = user_all_embeddings[user]
pos_embeddings = item_all_embeddings[pos_item]
neg_embeddings = item_all_embeddings[neg_item]

# calculate BPR Loss
pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)

mf_loss = self.mf_loss(pos_scores, neg_scores)

u_ego_embeddings = self.user_embedding(user)
pos_ego_embeddings = self.item_embedding(pos_item)
neg_ego_embeddings = self.item_embedding(neg_item)

reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings)

return mf_loss + self.reg_weight * reg_loss, ssl_loss, proto_loss

def predict(self, interaction):
user = interaction[self.USER_ID]
item = interaction[self.ITEM_ID]

user_all_embeddings, item_all_embeddings, embeddings_list = self.forward()

u_embeddings = user_all_embeddings[user]
i_embeddings = item_all_embeddings[item]
scores = torch.mul(u_embeddings, i_embeddings).sum(dim=1)
return scores

def full_sort_predict(self, interaction):
user = interaction[self.USER_ID]
if self.restore_user_e is None or self.restore_item_e is None:
self.restore_user_e, self.restore_item_e, embedding_list = self.forward()
# get user embedding from storage variable
u_embeddings = self.restore_user_e[user]

# dot with all item embedding to accelerate
scores = torch.matmul(u_embeddings, self.restore_item_e.transpose(0, 1))

return scores.view(-1)
4 changes: 2 additions & 2 deletions recbole_graph/quick_start.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from logging import getLogger
from recbole.utils import init_logger, get_trainer, init_seed, set_color
from recbole.utils import init_logger, init_seed, set_color

from recbole_graph.config import Config
from recbole_graph.utils import create_dataset, data_preparation, get_model
from recbole_graph.utils import create_dataset, data_preparation, get_model, get_trainer


def run_recbole_graph(model=None, dataset=None, config_file_list=None, config_dict=None, saved=True):
Expand Down
143 changes: 143 additions & 0 deletions recbole_graph/trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from time import time
from torch.nn.utils.clip_grad import clip_grad_norm_
from tqdm import tqdm
from recbole.trainer import Trainer
from recbole.utils import early_stopping, dict2str, set_color, get_gpu_usage


class NCLTrainer(Trainer):
def __init__(self, config, model):
super(NCLTrainer, self).__init__(config, model)

self.num_m_step = config['m_step']
assert self.num_m_step is not None

def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None):
r"""Train the model based on the train data and the valid data.
Args:
train_data (DataLoader): the train data
valid_data (DataLoader, optional): the valid data, default: None.
If it's None, the early_stopping is invalid.
verbose (bool, optional): whether to write training and evaluation information to logger, default: True
saved (bool, optional): whether to save the model parameters, default: True
show_progress (bool): Show the progress of training epoch and evaluate epoch. Defaults to ``False``.
callback_fn (callable): Optional callback function executed at end of epoch.
Includes (epoch_idx, valid_score) input arguments.
Returns:
(float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None)
"""
if saved and self.start_epoch >= self.epochs:
self._save_checkpoint(-1)

self.eval_collector.data_collect(train_data)

for epoch_idx in range(self.start_epoch, self.epochs):

# only differences from the original trainer
if epoch_idx % self.num_m_step == 0:
self.logger.info("Running E-step ! ")
self.model.e_step()
# train
training_start_time = time()
train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress)
self.train_loss_dict[epoch_idx] = sum(train_loss) if isinstance(train_loss, tuple) else train_loss
training_end_time = time()
train_loss_output = \
self._generate_train_loss_output(epoch_idx, training_start_time, training_end_time, train_loss)
if verbose:
self.logger.info(train_loss_output)
self._add_train_loss_to_tensorboard(epoch_idx, train_loss)

# eval
if self.eval_step <= 0 or not valid_data:
if saved:
self._save_checkpoint(epoch_idx)
update_output = set_color('Saving current', 'blue') + ': %s' % self.saved_model_file
if verbose:
self.logger.info(update_output)
continue
if (epoch_idx + 1) % self.eval_step == 0:
valid_start_time = time()
valid_score, valid_result = self._valid_epoch(valid_data, show_progress=show_progress)
self.best_valid_score, self.cur_step, stop_flag, update_flag = early_stopping(
valid_score,
self.best_valid_score,
self.cur_step,
max_step=self.stopping_step,
bigger=self.valid_metric_bigger
)
valid_end_time = time()
valid_score_output = (set_color("epoch %d evaluating", 'green') + " [" + set_color("time", 'blue')
+ ": %.2fs, " + set_color("valid_score", 'blue') + ": %f]") % \
(epoch_idx, valid_end_time - valid_start_time, valid_score)
valid_result_output = set_color('valid result', 'blue') + ': \n' + dict2str(valid_result)
if verbose:
self.logger.info(valid_score_output)
self.logger.info(valid_result_output)
self.tensorboard.add_scalar('Vaild_score', valid_score, epoch_idx)

if update_flag:
if saved:
self._save_checkpoint(epoch_idx)
update_output = set_color('Saving current best', 'blue') + ': %s' % self.saved_model_file
if verbose:
self.logger.info(update_output)
self.best_valid_result = valid_result

if callback_fn:
callback_fn(epoch_idx, valid_score)

if stop_flag:
stop_output = 'Finished training, best eval result in epoch %d' % \
(epoch_idx - self.cur_step * self.eval_step)
if verbose:
self.logger.info(stop_output)
break
self._add_hparam_to_tensorboard(self.best_valid_score)
return self.best_valid_score, self.best_valid_result

def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
r"""Train the model in an epoch
Args:
train_data (DataLoader): The train data.
epoch_idx (int): The current epoch id.
loss_func (function): The loss function of :attr:`model`. If it is ``None``, the loss function will be
:attr:`self.model.calculate_loss`. Defaults to ``None``.
show_progress (bool): Show the progress of training epoch. Defaults to ``False``.
Returns:
float/tuple: The sum of loss returned by all batches in this epoch. If the loss in each batch contains
multiple parts and the model return these multiple parts loss instead of the sum of loss, it will return a
tuple which includes the sum of loss in each part.
"""
self.model.train()
loss_func = loss_func or self.model.calculate_loss
total_loss = None
iter_data = (
tqdm(
train_data,
total=len(train_data),
ncols=100,
desc=set_color(f"Train {epoch_idx:>5}", 'pink'),
) if show_progress else train_data
)
for batch_idx, interaction in enumerate(iter_data):
interaction = interaction.to(self.device)
self.optimizer.zero_grad()
losses = loss_func(interaction)
if isinstance(losses, tuple):
if epoch_idx < self.config['warm_up_step']:
losses = losses[:-1]
loss = sum(losses)
loss_tuple = tuple(per_loss.item() for per_loss in losses)
total_loss = loss_tuple if total_loss is None else tuple(map(sum, zip(total_loss, loss_tuple)))
else:
loss = losses
total_loss = losses.item() if total_loss is None else total_loss + losses.item()
self._check_nan(loss)
loss.backward()
if self.clip_grad_norm:
clip_grad_norm_(self.model.parameters(), **self.clip_grad_norm)
self.optimizer.step()
if self.gpu_available and show_progress:
iter_data.set_postfix_str(set_color('GPU RAM: ' + get_gpu_usage(self.device), 'yellow'))
return total_loss
Loading