-
Notifications
You must be signed in to change notification settings - Fork 40
/
simgcl.py
60 lines (45 loc) · 2.46 KB
/
simgcl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# -*- coding: utf-8 -*-
r"""
SimGCL
################################################
Reference:
Junliang Yu, Hongzhi Yin, Xin Xia, Tong Chen, Lizhen Cui, Quoc Viet Hung Nguyen. "Are Graph Augmentations Necessary? Simple Graph Contrastive Learning for Recommendation." in SIGIR 2022.
"""
import torch
import torch.nn.functional as F
from recbole_gnn.model.general_recommender import LightGCN
class SimGCL(LightGCN):
    r"""LightGCN trained with an additional noise-based contrastive objective.

    Two views of the node embeddings are produced by running the LightGCN-style
    propagation with random noise injected after every layer. An InfoNCE loss
    between the two views is added, weighted by ``cl_rate``, to the loss
    computed by the parent class.

    Config keys read here:
        lambda: weight of the contrastive loss (stored as ``self.cl_rate``).
        eps: scale of the noise added after each propagation layer.
        temperature: InfoNCE temperature.
    """

    def __init__(self, config, dataset):
        super().__init__(config, dataset)
        self.cl_rate = config['lambda']
        self.eps = config['eps']
        self.temperature = config['temperature']

    def forward(self, perturbed=False):
        """Propagate embeddings over the graph and average the layer outputs.

        Args:
            perturbed (bool): if True, add random noise to the output of every
                propagation layer (used to build the contrastive views).

        Returns:
            tuple: ``(user_embeddings, item_embeddings)``, obtained by splitting
            the averaged embeddings into ``[self.n_users, self.n_items]`` rows.
        """
        all_embs = self.get_ego_embeddings()
        embeddings_list = []
        for _ in range(self.n_layers):
            all_embs = self.gcn_conv(all_embs, self.edge_index, self.edge_weight)
            if perturbed:
                # Uniform noise, L2-normalised per row, given the sign of the
                # current embedding, then scaled by eps. rand_like already
                # matches the device/dtype of all_embs.
                random_noise = torch.rand_like(all_embs)
                all_embs = all_embs + torch.sign(all_embs) * F.normalize(random_noise, dim=-1) * self.eps
            embeddings_list.append(all_embs)
        # Only the propagated layers are averaged; the ego (layer-0)
        # embeddings are not appended to embeddings_list.
        lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1).mean(dim=1)
        user_all_embeddings, item_all_embeddings = torch.split(
            lightgcn_all_embeddings, [self.n_users, self.n_items]
        )
        return user_all_embeddings, item_all_embeddings

    def calculate_cl_loss(self, x1, x2):
        """InfoNCE loss between two views, using in-batch negatives.

        Rows are L2-normalised, so similarities are cosine similarities.
        Row ``i`` of ``x1`` and row ``i`` of ``x2`` form the positive pair;
        every other row of ``x2`` acts as a negative for row ``i`` of ``x1``.

        Returns:
            torch.Tensor: scalar loss, summed (not averaged) over the rows.
        """
        x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1)
        logits = torch.matmul(x1, x2.transpose(0, 1)) / self.temperature
        # -log(exp(l_ii) / sum_j exp(l_ij)) is exactly cross-entropy with
        # target class i. cross_entropy uses log-softmax internally, so it
        # does not overflow the way explicit exp() does for small temperatures.
        labels = torch.arange(logits.size(0), device=logits.device)
        return F.cross_entropy(logits, labels, reduction='sum')

    def calculate_loss(self, interaction):
        """Parent-class loss plus the weighted contrastive loss.

        The contrastive term is computed on the distinct users and distinct
        items of the batch, between two independently perturbed forward passes.
        """
        loss = super().calculate_loss(interaction)

        user = torch.unique(interaction[self.USER_ID])
        pos_item = torch.unique(interaction[self.ITEM_ID])

        # Each call draws fresh noise, giving two different views.
        perturbed_user_embs_1, perturbed_item_embs_1 = self.forward(perturbed=True)
        perturbed_user_embs_2, perturbed_item_embs_2 = self.forward(perturbed=True)

        user_cl_loss = self.calculate_cl_loss(perturbed_user_embs_1[user], perturbed_user_embs_2[user])
        item_cl_loss = self.calculate_cl_loss(perturbed_item_embs_1[pos_item], perturbed_item_embs_2[pos_item])

        return loss + self.cl_rate * (user_cl_loss + item_cl_loss)