joseph-nagel/diffusion-demo

PyTorch denoising diffusion demo

This repository contains a simple PyTorch-based demonstration of denoising diffusion models. It aims to provide a first understanding of this generative modeling approach.

A short theoretical introduction to standard DDPMs can be found here. DDIMs for accelerated sampling are discussed in the companion notebook. Two example applications establish a small experimentation playground; they are designed to be easy to modify and extend.
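As background for the notebooks mentioned above: the standard DDPM forward process admits a closed-form expression for sampling a noised state directly from clean data, x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps. A minimal NumPy sketch, using an illustrative linear beta schedule that is not necessarily the one configured in this repository:

```python
import numpy as np

def make_alpha_bars(num_steps=1000, beta_start=1e-4, beta_end=2e-2):
    """Cumulative products abar_t for a linear beta schedule (illustrative values)."""
    betas = np.linspace(beta_start, beta_end, num_steps)
    return np.cumprod(1.0 - betas)

def diffuse(x0, t, alpha_bars, rng=np.random.default_rng(0)):
    """Sample x_t ~ q(x_t | x_0) in a single step via the closed form."""
    eps = rng.standard_normal(x0.shape)
    abar = alpha_bars[t]
    return np.sqrt(abar) * x0 + np.sqrt(1.0 - abar) * eps

alpha_bars = make_alpha_bars()
x0 = np.array([1.0, -1.0])
x_late = diffuse(x0, t=999, alpha_bars=alpha_bars)  # close to pure noise
```

Since abar_t decays toward zero, late-time samples are dominated by the Gaussian noise term, which is exactly what the reverse process learns to undo.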

Notebooks

Swiss roll

As a first example, a generative DDPM is trained on a 2D Swiss roll distribution. To that end, the main training script can be run with a config file that adjusts the problem setup and model definition:

python scripts/main.py fit --config config/swissroll.yaml
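For intuition about the target distribution, a 2D Swiss roll can be sampled as a noisy spiral. The following is an illustrative NumPy re-implementation, not the repository's actual data module (whose parameterization and scaling may differ):

```python
import numpy as np

def swiss_roll_2d(num_samples=1000, noise=0.1, rng=np.random.default_rng(0)):
    """Sample 2D points along a spiral with additive Gaussian noise."""
    # angle parameter controls both rotation and radius, yielding a spiral
    t = 1.5 * np.pi * (1.0 + 2.0 * rng.random(num_samples))
    data = np.stack([t * np.cos(t), t * np.sin(t)], axis=1)
    data += noise * rng.standard_normal(data.shape)
    return data / 10.0  # rough rescaling toward unit range

samples = swiss_roll_2d()
```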

After the training has finished, the final model can be tested and analyzed in this notebook.

For monitoring the experiment, one can run a local TensorBoard server via tensorboard --logdir run/swissroll/. By default, it can be reached at localhost:6006 in your browser. As an alternative, one may use MLflow for managing experiments. In this case, one launches the training with the appropriate settings and sets up a tracking server via mlflow server --backend-store-uri file:./run/mlruns/, which can then be reached at localhost:5000.

Forward process diffusing data into noise

Reverse process generating data from noise

MNIST

The second application is based on the MNIST dataset. Here, one can construct a DDPM that is either unconditional (generating randomly) or conditioned on the class label (generating controllably). Such models, which generate images of handwritten digits, can be trained by running the main script in one of the following ways:

python scripts/main.py fit --config config/mnist_uncond.yaml
python scripts/main.py fit --config config/mnist_cond.yaml
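A common way to realize the class conditioning in the second variant is to add a learned per-class embedding to the model's time embedding before it modulates the network. The sketch below illustrates that idea only; the names are hypothetical and the repository's actual conditioning mechanism may differ:

```python
import numpy as np

class ToyCondEmbedding:
    """Toy time/class embedding: sinusoidal time features plus an optional
    learned class embedding (here just random, for illustration)."""

    def __init__(self, num_classes=10, dim=16, rng=np.random.default_rng(0)):
        self.class_emb = rng.standard_normal((num_classes, dim))

    def __call__(self, t, y=None):
        half = self.class_emb.shape[1] // 2
        freqs = np.exp(-np.log(10000.0) * np.arange(half) / half)
        temb = np.concatenate([np.sin(t * freqs), np.cos(t * freqs)])
        if y is not None:
            # conditional case: shift the time embedding by the class embedding
            temb = temb + self.class_emb[y]
        return temb

emb = ToyCondEmbedding()
uncond = emb(t=10)          # unconditional embedding
cond = emb(t=10, y=3)       # embedding conditioned on class 3
```

With this design, the same network can serve both modes: passing no label reduces it to the unconditional model.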

Two dedicated notebooks, here and here, are provided for testing the unconditional and the conditional model after training, respectively.

Forward process diffusing data into noise

Reverse process generating data from noise