PyTorch implementation for "Generative Modeling on Manifolds Through Mixture of Riemannian Diffusion Processes" (ICML 2024).

harryjo97/riemannian-diffusion-mixture-torch

Riemannian Diffusion Mixture

This repo contains a PyTorch implementation for the paper Generative Modeling on Manifolds Through Mixture of Riemannian Diffusion Processes.

The official JAX implementation is available in the riemannian-diffusion-mixture repository.

Why Riemannian Diffusion Mixture?

  • Simple design of the generative process as a mixture of Riemannian bridge processes, which, unlike previous denoising approaches, does not require heat kernel estimation.
  • Geometric interpretation of the mixture process as the weighted mean of tangent directions on manifolds.
  • Scales to higher dimensions with significantly faster training than previous diffusion models.
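To make the "weighted mean of tangent directions" picture concrete, here is a minimal NumPy sketch on the unit sphere. It assumes, purely for illustration, that each bridge pulls the current point along the Riemannian log map toward its endpoint and that the mixture weights are fixed numbers; in the paper the weights and drifts come from the bridge processes themselves, so this is not the repo's implementation. The helper names `sphere_log` and `mixture_direction` are hypothetical.

```python
import numpy as np

def sphere_log(x, y):
    """Log map on the unit sphere: tangent vector at x pointing toward y."""
    cos_theta = np.clip(np.dot(x, y), -1.0, 1.0)
    theta = np.arccos(cos_theta)          # geodesic distance from x to y
    v = y - cos_theta * x                 # component of y orthogonal to x
    norm_v = np.linalg.norm(v)
    if norm_v < 1e-12:                    # y == x (or antipodal): no unique direction
        return np.zeros_like(x)
    return theta * v / norm_v

def mixture_direction(x, endpoints, weights):
    """Weighted mean of tangent directions from x toward each endpoint (hypothetical helper)."""
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()     # normalize to a probability vector
    logs = np.stack([sphere_log(x, y) for y in endpoints])
    return weights @ logs                 # a convex combination stays in the tangent space at x

x = np.array([0.0, 0.0, 1.0])             # north pole
endpoints = [np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0])]
d = mixture_direction(x, endpoints, [0.5, 0.5])
print(d)  # approximately [0.785, 0.785, 0.0], i.e. (pi/4, pi/4, 0)
```

Because every log vector lies in the tangent space at `x`, any weighted average of them does too, which is what makes this averaging geometrically well defined.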

Dependencies

Create an environment with Python 3.9.0 and PyTorch 2.0.0. Install the requirements with the following commands:

pip install -r requirements.txt
conda install -c conda-forge cartopy python-kaleido

Manifolds

The following manifolds are supported in this repo:

  • Euclidean
  • Hypersphere
  • Torus
  • Hyperboloid
  • Triangular mesh
  • Special orthogonal group

To implement a new manifold, add a Python file that defines its geometry in /geomstats/geometry.

Refer to the existing files in geomstats/geometry for examples.
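As a rough illustration of what "defining the geometry" typically involves, the sketch below shows a hypothetical minimal interface with an exponential map and a log map, plus a flat-torus implementation. The class names `ToyManifold` and `FlatTorus` are invented for this example and are NOT the geomstats classes used by this repo; check the files in geomstats/geometry for the actual base classes and method signatures.

```python
from abc import ABC, abstractmethod
import numpy as np

class ToyManifold(ABC):
    """Hypothetical minimal interface; NOT the geomstats API used by this repo."""

    @abstractmethod
    def exp(self, tangent_vec, base_point):
        """Follow the geodesic from base_point along tangent_vec for unit time."""

    @abstractmethod
    def log(self, point, base_point):
        """Return the tangent vector at base_point whose exp reaches point."""

class FlatTorus(ToyManifold):
    """n-dimensional flat torus with angles represented in [-pi, pi)."""

    def __init__(self, dim):
        self.dim = dim

    @staticmethod
    def _wrap(angles):
        # Map any real angle to the interval [-pi, pi)
        return (angles + np.pi) % (2 * np.pi) - np.pi

    def exp(self, tangent_vec, base_point):
        # Geodesics on the flat torus are straight lines, wrapped around each circle
        return self._wrap(base_point + tangent_vec)

    def log(self, point, base_point):
        # Shortest signed angular difference, coordinate by coordinate
        return self._wrap(point - base_point)

torus = FlatTorus(dim=2)
a = np.array([3.0, -3.0])
b = np.array([-3.0, 3.0])
v = torus.log(b, a)        # short way around: about (+0.283, -0.283), not (-6, +6)
print(torus.exp(v, a))     # recovers b up to floating-point error
```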

Running Experiments

This repo supports experiments on the following datasets:

  • Protein datasets (General, Glycine, Proline, and Pre-Pro) and the RNA dataset
  • High-dimensional tori

Please refer to riemannian-diffusion-mixture for running experiments on the earth and climate science datasets, triangular mesh datasets, and hyperboloid datasets.

1. Dataset preparation

For experiments on the protein datasets, create the .tsv file in the /data/top500 directory with the following commands:

cd data/top500
bash batch_download.sh -f list_file.txt -p
python get_torsion_angle.py

For experiments on the RNA dataset, create the .tsv file in the /data/rna directory with the following commands:

cd data/rna
bash batch_download.sh -f list_file.txt -p
python get_torsion_angles.py
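Torsion angles are periodic, which is why these datasets live on a torus. The sketch below shows one way to read angle columns from TSV text and wrap them to [-pi, pi). It assumes a header row, tab separation, and angles in degrees; the column names `phi`/`psi` and the helper `load_torsion_angles` are hypothetical, and the actual files produced by the scripts above may use a different layout.

```python
import csv
import io
import numpy as np

def load_torsion_angles(tsv_text, columns):
    """Read angle columns (assumed in degrees) from TSV text and wrap them to [-pi, pi)."""
    reader = csv.DictReader(io.StringIO(tsv_text), delimiter="\t")
    rows = [[float(row[c]) for c in columns] for row in reader]
    radians = np.deg2rad(np.array(rows))
    # Periodic wrap: 180 degrees and -180 degrees map to the same point, -pi
    return (radians + np.pi) % (2 * np.pi) - np.pi

# Hypothetical file contents; the real column layout may differ.
sample = "phi\tpsi\n-60.0\t-45.0\n180.0\t135.0\n"
angles = load_torsion_angles(sample, ["phi", "psi"])
print(angles.shape)  # (2, 2): two residues, two angles each
```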

2. Configurations

The configurations are provided in the config/ directory in YAML format.

3. Experiments

CUDA_VISIBLE_DEVICES=0 python main.py -m \
    experiment=<exp> \
    seed=0,1,2,3,4 \
    n_jobs=5

where <exp> is the name of one of the experiment configs in config/experiment/*.yaml.

For example,

CUDA_VISIBLE_DEVICES=0 python main.py -m \
    experiment=rna \
    seed=0,1,2,3,4 \
    n_jobs=5
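The -m flag together with key=value overrides suggests that main.py uses Hydra's multirun mode, in which comma-separated values are expanded into one run per combination (a cartesian product). Assuming that behavior, the sketch below reproduces the expansion to show that seed=0,1,2,3,4 yields five separate runs; n_jobs=5 presumably lets them execute in parallel, but its exact meaning depends on the launcher configured in config/. The helper `expand_sweep` is hypothetical and not part of this repo or of Hydra.

```python
from itertools import product

def expand_sweep(overrides):
    """Expand comma-separated override values into one override list per run."""
    keys = list(overrides)
    value_lists = [str(overrides[k]).split(",") for k in keys]
    # Cartesian product over all swept keys, one combination per run
    return [
        [f"{k}={v}" for k, v in zip(keys, combo)]
        for combo in product(*value_lists)
    ]

runs = expand_sweep({"experiment": "rna", "seed": "0,1,2,3,4"})
for r in runs:
    print(" ".join(r))  # "experiment=rna seed=0", ..., "experiment=rna seed=4"
```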

To run experiments on high-dimensional tori, use experiment=htori with n=$DIM, where $DIM denotes the dimension of the torus.

Citation

If you find this code or our paper useful in your work, please cite:

@inproceedings{jo2024riemannian,
  author    = {Jaehyeong Jo and
               Sung Ju Hwang},
  title     = {Generative Modeling on Manifolds Through Mixture of Riemannian Diffusion Processes},
  booktitle = {International Conference on Machine Learning},
  year      = {2024},
}

Acknowledgments

Our code builds upon geomstats. We thank the authors of Riemannian Score-Based Generative Modelling and Riemannian Flow Matching for their work.
