Skip to content

Commit

Permalink
Merge pull request #72 from downeykking/main
Browse files Browse the repository at this point in the history
FEA: add XSimGCL
  • Loading branch information
hyp1231 authored Oct 18, 2023
2 parents ebb0136 + 8e755af commit 7c6783f
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 15 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install pytest
pip install dgl
pip install dgl==0.9.1
pip install torch==${{ matrix.torch-version}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
pip install torch-scatter==2.0.9 torch-sparse==0.6.15 torch-cluster==1.6.0 torch-spline-conv==1.2.1 torch-geometric -f https://data.pyg.org/whl/torch-${{ matrix.torch-version }}+cpu.html
pip install recbole==1.1.1
conda install -c conda-forge faiss-cpu
# Use "python -m pytest" instead of "pytest" to fix imports
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ We list currently supported models according to category:
* **[HMLET](recbole_gnn/model/general_recommender/hmlet.py)** from Kong *et al.*: [Linear, or Non-Linear, That is the Question!](https://arxiv.org/abs/2111.07265) (WSDM 2022).
* **[NCL](recbole_gnn/model/general_recommender/ncl.py)** from Lin *et al.*: [Improving Graph Collaborative Filtering with Neighborhood-enriched Contrastive Learning](https://arxiv.org/abs/2202.06200) (TheWebConf 2022).
* **[SimGCL](recbole_gnn/model/general_recommender/simgcl.py)** from Yu *et al.*: [Are Graph Augmentations Necessary? Simple Graph Contrastive Learning for Recommendation](https://arxiv.org/abs/2112.08679) (SIGIR 2022).
* **[XSimGCL](recbole_gnn/model/general_recommender/xsimgcl.py)** from Yu *et al.*: [XSimGCL: Towards Extremely Simple Graph Contrastive Learning for Recommendation](https://arxiv.org/abs/2209.02544) (TKDE 2023).

**Sequential Recommendation**:

Expand Down
8 changes: 5 additions & 3 deletions recbole_gnn/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import recbole
from recbole.config.configurator import Config as RecBole_Config
from recbole.utils import ModelType as RecBoleModelType

Expand All @@ -16,7 +17,8 @@ def __init__(self, model=None, dataset=None, config_file_list=None, config_dict=
config_file_list (list of str): the external config file, it allows multiple config files, default is None.
config_dict (dict): the external parameter dictionaries, default is None.
"""
self.compatibility_settings()
if recbole.__version__ == "1.1.1":
self.compatibility_settings()
super(Config, self).__init__(model, dataset, config_file_list, config_dict)

def compatibility_settings(self):
Expand Down Expand Up @@ -59,7 +61,7 @@ def _get_model_and_dataset(self, model, dataset):
final_dataset = dataset

return final_model, final_model_class, final_dataset

def _load_internal_config_dict(self, model, model_class, dataset):
super()._load_internal_config_dict(model, model_class, dataset)
current_path = os.path.dirname(os.path.realpath(__file__))
Expand All @@ -75,4 +77,4 @@ def _load_internal_config_dict(self, model, model_class, dataset):
if self.internal_config_dict['MODEL_TYPE'] == RecBoleModelType.SEQUENTIAL:
self._update_internal_config_dict(sequential_base_init)
if self.internal_config_dict['MODEL_TYPE'] == ModelType.SOCIAL:
self._update_internal_config_dict(social_base_init)
self._update_internal_config_dict(social_base_init)
2 changes: 2 additions & 0 deletions recbole_gnn/model/general_recommender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
from recbole_gnn.model.general_recommender.ncl import NCL
from recbole_gnn.model.general_recommender.ngcf import NGCF
from recbole_gnn.model.general_recommender.sgl import SGL
from recbole_gnn.model.general_recommender.simgcl import SimGCL
from recbole_gnn.model.general_recommender.xsimgcl import XSimGCL
90 changes: 90 additions & 0 deletions recbole_gnn/model/general_recommender/xsimgcl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
r"""
XSimGCL
################################################
Reference:
Junliang Yu, Xin Xia, Tong Chen, Lizhen Cui, Nguyen Quoc Viet Hung, Hongzhi Yin. "XSimGCL: Towards Extremely Simple Graph Contrastive Learning for Recommendation" in TKDE 2023.
Reference code:
https://github.com/Coder-Yu/SELFRec/blob/main/model/graph/XSimGCL.py
"""


import torch
import torch.nn.functional as F

from recbole_gnn.model.general_recommender import LightGCN


class XSimGCL(LightGCN):
def __init__(self, config, dataset):
super(XSimGCL, self).__init__(config, dataset)

self.cl_rate = config['lambda']
self.eps = config['eps']
self.temperature = config['temperature']
self.layer_cl = config['layer_cl']

def forward(self, perturbed=False):
all_embs = self.get_ego_embeddings()
all_embs_cl = all_embs
embeddings_list = []

for layer_idx in range(self.n_layers):
all_embs = self.gcn_conv(all_embs, self.edge_index, self.edge_weight)
if perturbed:
random_noise = torch.rand_like(all_embs, device=all_embs.device)
all_embs = all_embs + torch.sign(all_embs) * F.normalize(random_noise, dim=-1) * self.eps
embeddings_list.append(all_embs)
if layer_idx == self.layer_cl - 1:
all_embs_cl = all_embs
lightgcn_all_embeddings = torch.stack(embeddings_list, dim=1)
lightgcn_all_embeddings = torch.mean(lightgcn_all_embeddings, dim=1)

user_all_embeddings, item_all_embeddings = torch.split(lightgcn_all_embeddings, [self.n_users, self.n_items])
user_all_embeddings_cl, item_all_embeddings_cl = torch.split(all_embs_cl, [self.n_users, self.n_items])
if perturbed:
return user_all_embeddings, item_all_embeddings, user_all_embeddings_cl, item_all_embeddings_cl
return user_all_embeddings, item_all_embeddings

def calculate_cl_loss(self, x1, x2):
x1, x2 = F.normalize(x1, dim=-1), F.normalize(x2, dim=-1)
pos_score = (x1 * x2).sum(dim=-1)
pos_score = torch.exp(pos_score / self.temperature)
ttl_score = torch.matmul(x1, x2.transpose(0, 1))
ttl_score = torch.exp(ttl_score / self.temperature).sum(dim=1)
return -torch.log(pos_score / ttl_score).mean()

def calculate_loss(self, interaction):
# clear the storage variable when training
if self.restore_user_e is not None or self.restore_item_e is not None:
self.restore_user_e, self.restore_item_e = None, None

user = interaction[self.USER_ID]
pos_item = interaction[self.ITEM_ID]
neg_item = interaction[self.NEG_ITEM_ID]

user_all_embeddings, item_all_embeddings, user_all_embeddings_cl, item_all_embeddings_cl = self.forward(perturbed=True)
u_embeddings = user_all_embeddings[user]
pos_embeddings = item_all_embeddings[pos_item]
neg_embeddings = item_all_embeddings[neg_item]

# calculate BPR Loss
pos_scores = torch.mul(u_embeddings, pos_embeddings).sum(dim=1)
neg_scores = torch.mul(u_embeddings, neg_embeddings).sum(dim=1)
mf_loss = self.mf_loss(pos_scores, neg_scores)

# calculate regularization Loss
u_ego_embeddings = self.user_embedding(user)
pos_ego_embeddings = self.item_embedding(pos_item)
neg_ego_embeddings = self.item_embedding(neg_item)
reg_loss = self.reg_loss(u_ego_embeddings, pos_ego_embeddings, neg_ego_embeddings, require_pow=self.require_pow)

user = torch.unique(interaction[self.USER_ID])
pos_item = torch.unique(interaction[self.ITEM_ID])

# calculate CL Loss
user_cl_loss = self.calculate_cl_loss(user_all_embeddings[user], user_all_embeddings_cl[user])
item_cl_loss = self.calculate_cl_loss(item_all_embeddings[pos_item], item_all_embeddings_cl[pos_item])

return mf_loss, self.reg_weight * reg_loss, self.cl_rate * (user_cl_loss + item_cl_loss)
9 changes: 9 additions & 0 deletions recbole_gnn/properties/model/XSimGCL.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
embedding_size: 64
n_layers: 2
reg_weight: 0.0001

lambda: 0.1
eps: 0.2
temperature: 0.2
layer_cl: 1
require_pow: True
22 changes: 12 additions & 10 deletions results/general/ml-1m.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,17 @@ embedding_size: 64
# Evaluation Results
| Method | Recall@10 | MRR@10 | NDCG@10 | Hit@10 | Precision@10 |
| -------------------- | --------- | ------ | ------- | ------ | ------------ |
| **BPR** | 0.1776 | 0.4187 | 0.2401 | 0.7199 | 0.1779 |
| **NeuMF** | 0.1651 | 0.4020 | 0.2271 | 0.7029 | 0.1700 |
| **NGCF** | 0.1814 | 0.4354 | 0.2508 | 0.7239 | 0.1850 |
| **LightGCN** | 0.1861 | 0.4388 | 0.2538 | 0.7330 | 0.1863 |
| **SGL** | 0.1889 | 0.4315 | 0.2505 | 0.7392 | 0.1843 |
| **HMLET** | 0.1847 | 0.4297 | 0.2490 | 0.7305 | 0.1836 |
| **NCL** | 0.2021 | 0.4599 | 0.2702 | 0.7565 | 0.1962 |
| **SimGCL** | 0.2029 | 0.4550 | 0.2667 | 0.7640 | 0.1933 |
| Method | Recall@10 | MRR@10 | NDCG@10 | Hit@10 | Precision@10 |
| ------------ | --------- | ------ | ------- | ------ | ------------ |
| **BPR** | 0.1776 | 0.4187 | 0.2401 | 0.7199 | 0.1779 |
| **NeuMF** | 0.1651 | 0.4020 | 0.2271 | 0.7029 | 0.1700 |
| **NGCF** | 0.1814 | 0.4354 | 0.2508 | 0.7239 | 0.1850 |
| **LightGCN** | 0.1861 | 0.4388 | 0.2538 | 0.7330 | 0.1863 |
| **SGL** | 0.1889 | 0.4315 | 0.2505 | 0.7392 | 0.1843 |
| **HMLET** | 0.1847 | 0.4297 | 0.2490 | 0.7305 | 0.1836 |
| **NCL** | 0.2021 | 0.4599 | 0.2702 | 0.7565 | 0.1962 |
| **SimGCL** | 0.2029 | 0.4550 | 0.2667 | 0.7640 | 0.1933 |
| **XSimGCL** | 0.2116 | 0.4638 | 0.2750 | 0.7743 | 0.1987 |
# Hyper-parameters
Expand All @@ -69,3 +70,4 @@ embedding_size: 64
| **HMLET** | learning_rate=0.002<br />n_layers=4<br />activation_function=leakyrelu | learning_rate choice [0.002, 0.001, 0.0005]<br/>n_layers choice [3, 4]<br/>activation_function choice ['elu', 'leakyrelu'] |
| **NCL** | learning_rate=0.002<br />n_layers=3<br />reg_weight=0.0001<br />ssl_temp=0.1<br />ssl_reg=1e-06<br />hyper_layers=1<br />alpha=1.5 | learning_rate choice [0.002]<br/>n_layers choice [3]<br/>reg_weight choice [1e-4]<br/>ssl_temp choice [0.1, 0.05]<br/>ssl_reg choice [1e-7, 1e-6]<br/>hyper_layers choice [1]<br/>alpha choice [1, 0.8, 1.5] |
| **SimGCL** | learning_rate=0.002<br />n_layers=2<br />reg_weight=0.0001<br />temperature=0.05<br />lambda=1e-5<br />eps=0.1 | learning_rate choice [0.002]<br/>n_layers choice [2, 3]<br/>reg_weight choice [1e-4]<br/>temperature choice [0.05, 0.1, 0.2]<br/>lambda choice [1e-5, 1e-6, 1e-7, 0.005, 0.01, 0.05]<br/>eps choice [0.1, 0.2] |
| **XSimGCL** | learning_rate=0.002<br />n_layers=2<br />reg_weight=0.0001<br />temperature=0.2<br />lambda=0.1<br />eps=0.2<br />layer_cl=1 | learning_rate choice [0.002]<br/>n_layers choice [2, 3]<br/>reg_weight choice [1e-4]<br/>temperature choice [0.05, 0.1, 0.2]<br/>lambda choice [1e-5, 1e-6, 1e-7, 1e-4, 0.005, 0.01, 0.05, 0.1]<br/>eps choice [0.1, 0.2]<br/>layer_cl choice [1] |
6 changes: 6 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ def test_simgcl(self):
'model': 'SimGCL'
}
quick_test(config_dict)

def test_xsimgcl(self):
config_dict = {
'model': 'XSimGCL'
}
quick_test(config_dict)


class TestSequentialRecommender(unittest.TestCase):
Expand Down

0 comments on commit 7c6783f

Please sign in to comment.