Skip to content

Commit

Permalink
Merge pull request #1885 from BishopLiu/master
Browse files Browse the repository at this point in the history
FEA: Add DiffRec and LDiffRec in general models
  • Loading branch information
zhengbw0324 authored Oct 9, 2023
2 parents 00c018e + c472300 commit 2911096
Show file tree
Hide file tree
Showing 16 changed files with 1,267 additions and 5 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ jobs:
pip install torch-scatter -f https://data.pyg.org/whl/torch-`python -c "import torch;print(torch.__version__)"`.html
pip install setuptools==59.5.0
pip install plotly
pip install kmeans-pytorch
# Use "python -m pytest" instead of "pytest" to fix imports
- name: Test Overall
run: |
Expand Down Expand Up @@ -90,4 +91,4 @@ jobs:
- name: Apply code-format changes
uses: stefanzweifel/git-auto-commit-action@v4
with:
commit_message: Format Python code according to PEP8
commit_message: Format Python code according to PEP8
Binary file added docs/source/asset/diffrec.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/source/asset/ldiffrec.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Introduction
RecBole is a unified, comprehensive and efficient framework developed based on PyTorch.
It aims to help the researchers to reproduce and develop recommendation models.

In the lastest release, our library includes 86 recommendation algorithms `[Model List]`_, covering four major categories:
In the lastest release, our library includes 88 recommendation algorithms `[Model List]`_, covering four major categories:

- General Recommendation
- Sequential Recommendation
Expand Down
94 changes: 94 additions & 0 deletions docs/source/user_guide/model/general/diffrec.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
DiffRec
===========

Introduction
---------------------

`[paper] <https://dl.acm.org/doi/10.1145/3539618.3591663>`_

**Title:** Diffusion Recommender Model

**Authors:** Wenjie Wang, Yiyan Xu, Fuli Feng, Xinyu Lin, Xiangnan He, Tat-Seng Chua

**Abstract:** Generative models such as Generative Adversarial Networks (GANs) and Variational Auto-Encoders (VAEs) are widely utilized to model the generative process of user interactions. However, they suffer from intrinsic limitations such as the instability of GANs and the restricted representation ability of VAEs. Such limitations hinder the accurate modeling of the complex user interaction generation procedure, such as noisy interactions caused by various interference factors. In light of the impressive advantages of Diffusion Models (DMs) over traditional generative models in image synthesis, we propose a novel Diffusion Recommender Model (named DiffRec) to learn the generative process in a denoising manner. To retain personalized information in user interactions, DiffRec reduces the added noises and avoids corrupting users’ interactions into pure noises like in image synthesis. In addition, we extend traditional DMs to tackle the unique challenges in recommendation: high resource costs for large-scale item prediction and temporal shifts of user preference. To this end, we propose two extensions of DiffRec: L-DiffRec clusters items for dimension compression and conducts the diffusion processes in the latent space; and T-DiffRec reweights user interactions based on the interaction timestamps to encode temporal information. We conduct extensive experiments on three datasets under multiple settings (e.g., clean training, noisy training, and temporal training). The empirical results validate the superiority of DiffRec with two extensions over competitive baselines.

.. image:: ../../../asset/diffrec.png
:width: 500
:align: center

Running with RecBole
-------------------------

**Model Hyper-Parameters:**

- ``noise_schedule (str)`` : The schedule for noise generating: ['linear', 'linear-var', 'cosine', 'binomial']. Defaults to ``'linear'``.
- ``noise_scale (int)`` : The scale for noise generating. Defaults to ``0.001``.
- ``noise_min (int)`` : Noise lower bound for noise generating. Defaults to ``0.0005``.
- ``noise_max (int)`` : 0.005 Noise upper bound for noise generating. Defaults to ``0.005``.
- ``sampling_noise (bool)`` : Whether to use sampling noise. Defaults to ``False``.
- ``sampling_steps (int)`` : Steps of the forward process during inference. Defaults to ``0``.
- ``reweight (bool)`` : Assign different weight to different timestep or not. Defaults to ``True``.
- ``mean_type (str)`` : MeanType for diffusion: ['x0', 'eps']. Defaults to ``'x0'``.
- ``steps (int)`` : Diffusion steps. Defaults to ``5``.
- ``history_num_per_term (int)`` : The number of history items needed to calculate loss weight. Defaults to ``10``.
- ``beta_fixed (bool)`` : Whether to fix the variance of the first step to prevent overfitting. Defaults to ``True``.
- ``dims_dnn (list of int)`` : The dims for the DNN. Defaults to ``[300]``.
- ``embedding_size (int)`` : Timestep embedding size. Defaults to ``10``.
- ``mlp_act_func (str)`` : Activation function for MLP. Defaults to ``'tanh'``.
- ``time-aware (bool)`` : T-DiffRec or not. Defaults to ``False``.
- ``w_max (int)`` : The upper bound of the time-aware interaction weight. Defaults to ``1``.
- ``w_min (int)`` : The lower bound of the time-aware interaction weight. Defaults to ``0.1``.


**A Running Example:**

Write the following code to a python file, such as `run.py`

.. code:: python
from recbole.quick_start import run_recbole
run_recbole(model='DiffRec', dataset='ml-100k')
And then:

.. code:: bash
python run.py
**Notes:**

- ``w_max`` and ``w_min`` are unused when ``time-aware`` is False.

Tuning Hyper Parameters
-------------------------

If you want to use ``HyperTuning`` to tune hyper parameters of this model, you can copy the following settings and name it as ``hyper.test``.

.. code:: bash
learning_rate choice [1e-3,1e-4,1e-5]
dims_dnn choice ['[300]','[200,600]','[1000]']
steps choice [2,5,10,50]
noice_scale choice [0,1e-5,1e-4,1e-3,1e-2,1e-1]
noice_min choice [5e-4,1e-3,5e-3]
noice_max choice [5e-3,1e-2]
w_min choice [0.1,0.2,0.3]
Note that we just provide these hyper parameter ranges for reference only, and we can not guarantee that they are the optimal range of this model.

Then, with the source code of RecBole (you can download it from GitHub), you can run the ``run_hyper.py`` to tuning:

.. code:: bash
python run_hyper.py --model=[model_name] --dataset=[dataset_name] --config_files=[config_files_path] --params_file=hyper.test
For more details about Parameter Tuning, refer to :doc:`../../../user_guide/usage/parameter_tuning`.


If you want to change parameters, dataset or evaluation settings, take a look at

- :doc:`../../../user_guide/config_settings`
- :doc:`../../../user_guide/data_intro`
- :doc:`../../../user_guide/train_eval_intro`
- :doc:`../../../user_guide/usage`
106 changes: 106 additions & 0 deletions docs/source/user_guide/model/general/ldiffrec.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
LDiffRec
===========

Introduction
---------------------

`[paper] <https://dl.acm.org/doi/10.1145/3539618.3591663>`_

**Title:** Diffusion Recommender Model

**Authors:** Wenjie Wang, Yiyan Xu, Fuli Feng, Xinyu Lin, Xiangnan He, Tat-Seng Chua

**Abstract:** Generative models such as Generative Adversarial Networks (GANs) and Variational Auto-Encoders (VAEs) are widely utilized to model the generative process of user interactions. However, they suffer from intrinsic limitations such as the instability of GANs and the restricted representation ability of VAEs. Such limitations hinder the accurate modeling of the complex user interaction generation procedure, such as noisy interactions caused by various interference factors. In light of the impressive advantages of Diffusion Models (DMs) over traditional generative models in image synthesis, we propose a novel Diffusion Recommender Model (named DiffRec) to learn the generative process in a denoising manner. To retain personalized information in user interactions, DiffRec reduces the added noises and avoids corrupting users’ interactions into pure noises like in image synthesis. In addition, we extend traditional DMs to tackle the unique challenges in recommendation: high resource costs for large-scale item prediction and temporal shifts of user preference. To this end, we propose two extensions of DiffRec: L-DiffRec clusters items for dimension compression and conducts the diffusion processes in the latent space; and T-DiffRec reweights user interactions based on the interaction timestamps to encode temporal information. We conduct extensive experiments on three datasets under multiple settings (e.g., clean training, noisy training, and temporal training). The empirical results validate the superiority of DiffRec with two extensions over competitive baselines.

.. image:: ../../../asset/ldiffrec.png
:width: 500
:align: center

Running with RecBole
-------------------------

**Model Hyper-Parameters:**

- ``noise_schedule (str)`` : The schedule for noise generating: [linear, linear-var, cosine, binomial]. Defaults to ``'linear'``.
- ``noise_scale (int)`` : The scale for noise generating. Defaults to ``0.1``.
- ``noise_min (int)`` : Noise lower bound for noise generating. Defaults to ``0.001``.
- ``noise_max (int)`` : 0.005 Noise upper bound for noise generating. Defaults to ``0.005``.
- ``sampling_noise (bool)`` : Whether to use sampling noise. Defaults to ``False``.
- ``sampling_steps (int)`` : Steps of the forward process during inference. Defaults to ``0``.
- ``reweight (bool)`` : Assign different weight to different timestep or not. Defaults to ``True``.
- ``mean_type (str)`` : MeanType for diffusion: ['x0', 'eps']. Defaults to ``'x0'``.
- ``steps (int)`` : Diffusion steps. Defaults to ``5``.
- ``history_num_per_term (int)`` : The number of history items needed to calculate loss weight. Defaults to ``10``.
- ``beta_fixed (bool)`` : Whether to fix the variance of the first step to prevent overfitting. Defaults to ``True``.
- ``dims_dnn (list of int)`` : The dims for the DNN. Defaults to ``[300]``.
- ``embedding_size (int)`` : Timestep embedding size. Defaults to ``10``.
- ``mlp_act_func (str)`` : Activation function for MLP. Defaults to ``'tanh'``.
- ``time-aware (bool)`` : LT-DiffRec or not. Defaults to ``False``.
- ``w_max (int)`` : The upper bound of the time-aware interaction weight. Defaults to ``1``.
- ``w_min (int)`` : The lower bound of the time-aware interaction weight. Defaults to ``0.1``.
- ``n_cate (int)`` : Category num of items. Defaults to ``1``.
- ``reparam (bool) `` : Autoencoder with variational inference or not. Defaults to ``True``.
- ``in_dims (list of int)`` : The dims for the encoder. Defaults to ``[300]``.
- ``out_dims (list of int)`` : The hidden dims for the decoder. Defaults to ``[]``.
- ``ae_act_func (str)`` : Activation function for AutoEncoder. Defaults to ``'tanh'``.
- ``lamda (float)`` : Hyper-parameter of multinomial log-likelihood for AE. Defaults to ``0.03``.
- ``anneal_cap (float)`` : The upper bound of the annealing weight. Defaults to ``0.005``.
- ``anneal_steps (int)`` : The steps of annealing. Defaults to ``1000``.
- ``vae_anneal_cap (float)`` : The upper bound of the VAE annealing weight. Defaults to ``0.3``.
- ``vae_anneal_steps (int)`` : The steps of VAE annealing. Defaults to ``200``.


**A Running Example:**

Write the following code to a python file, such as `run.py`

.. code:: python
from recbole.quick_start import run_recbole
run_recbole(model='LDiffRec', dataset='ml-100k')
And then:

.. code:: bash
python run.py
**Notes:**

- ``w_max`` and ``w_min`` are unused when ``time-aware`` is False.

- The item embedding file is needed if ``n_cate`` is greater than 1.

Tuning Hyper Parameters
-------------------------

If you want to use ``HyperTuning`` to tune hyper parameters of this model, you can copy the following settings and name it as ``hyper.test``.

.. code:: bash
learning_rate choice [1e-3,1e-4,1e-5]
dims_dnn choice ['[300]','[200,600]','[1000]']
steps choice [2,5,10,50]
noice_scale choice [0,1e-5,1e-4,1e-3,1e-2,1e-1]
noice_min choice [5e-4,1e-3,5e-3]
noice_max choice [5e-3,1e-2]
w_min choice [0.1,0.2,0.3]
Note that we just provide these hyper parameter ranges for reference only, and we can not guarantee that they are the optimal range of this model.

Then, with the source code of RecBole (you can download it from GitHub), you can run the ``run_hyper.py`` to tuning:

.. code:: bash
python run_hyper.py --model=[model_name] --dataset=[dataset_name] --config_files=[config_files_path] --params_file=hyper.test
For more details about Parameter Tuning, refer to :doc:`../../../user_guide/usage/parameter_tuning`.


If you want to change parameters, dataset or evaluation settings, take a look at

- :doc:`../../../user_guide/config_settings`
- :doc:`../../../user_guide/data_intro`
- :doc:`../../../user_guide/train_eval_intro`
- :doc:`../../../user_guide/usage`
4 changes: 3 additions & 1 deletion docs/source/user_guide/model_intro.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Model Introduction
=====================
We implement 86 recommendation models covering general recommendation, sequential recommendation,
We implement 88 recommendation models covering general recommendation, sequential recommendation,
context-aware recommendation and knowledge-based recommendation. A brief introduction to these models are as follows:


Expand Down Expand Up @@ -42,6 +42,8 @@ task of top-n recommendation. All the collaborative filter(CF) based models are
model/general/nceplrec
model/general/simplex
model/general/ncl
model/general/diffrec
model/general/ldiffrec


Context-aware Recommendation
Expand Down
2 changes: 2 additions & 0 deletions recbole/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@ def get_dataloader(config, phase: Literal["train", "valid", "test", "evaluation"
"ENMF": _get_AE_dataloader,
"RaCT": _get_AE_dataloader,
"RecVAE": _get_AE_dataloader,
"DiffRec": _get_AE_dataloader,
"LDiffRec": _get_AE_dataloader,
}

if config["model"] in register_table:
Expand Down
2 changes: 2 additions & 0 deletions recbole/model/general_recommender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,5 @@
from recbole.model.general_recommender.sgl import SGL
from recbole.model.general_recommender.admmslim import ADMMSLIM
from recbole.model.general_recommender.simplex import SimpleX
from recbole.model.general_recommender.diffrec import DiffRec
from recbole.model.general_recommender.ldiffrec import LDiffRec
Loading

0 comments on commit 2911096

Please sign in to comment.