
Basic Denoising Diffusion Probabilistic Model image generator implemented in PyTorch



LarsDu/DiffuMon


DiffuMon


Reproduces the Denoising Diffusion Probabilistic Models paper (Ho et al., 2020).

Developed as an educational project, with the aim of providing a simpler PyTorch implementation and development setup than other available DDPM implementations. It is small and lean enough to train on a commodity GPU (in this case, my GeForce RTX 4070 Ti).

The basic idea is to train a model to denoise images. New images are generated by starting from a pure random noise image and using the trained model to iteratively remove noise until a coherent image forms.
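The noising half of this idea has a closed form, which is what makes training cheap: any noise level can be reached in one jump. The sketch below follows the DDPM paper rather than DiffuMon's own code. It assumes the paper's linear schedule (beta from 1e-4 to 0.02 over T = 1000 steps), represents an image as a flat list of floats instead of a PyTorch tensor, and uses hypothetical helper names (`linear_betas`, `q_sample`, `noise_prediction_loss`).

```python
import math
import random


def linear_betas(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise variances beta_t (the DDPM paper's default schedule)."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def cumulative_alpha_bars(betas):
    """alpha_bar_t = product over s <= t of (1 - beta_s)."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def q_sample(x0, t, alpha_bars, rng):
    """Jump straight to step t: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    a = alpha_bars[t]
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, eps)]
    return x_t, eps


def noise_prediction_loss(predicted_eps, true_eps):
    """Simplified DDPM objective: mean squared error between true and predicted noise."""
    return sum((p - e) ** 2 for p, e in zip(predicted_eps, true_eps)) / len(true_eps)


# One training example: pick a random step, noise the image, and score a noise prediction.
# A real model would predict eps from (x_t, t); a zero prediction stands in for it here.
rng = random.Random(0)
betas = linear_betas()
alpha_bars = cumulative_alpha_bars(betas)
x0 = [0.5, -0.5, 1.0, -1.0]  # a tiny "image" with pixel values in [-1, 1]
t = rng.randrange(len(betas))
x_t, eps = q_sample(x0, t, alpha_bars, rng)
loss = noise_prediction_loss([0.0] * len(eps), eps)
```

Training repeats this for many random images and steps, regressing the network's output onto `eps`; generation then runs the process in reverse.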

Two pretrained models are provided in the checkpoints/ directory: one for Fashion MNIST and one for an 11k Pokemon dataset.

Features

  • Reproducible environment with rye. Get set up with a single command.
  • Automatic dataset download and preprocessing for certain preloaded datasets.
  • Example notebook for sampling and gif generation.
  • Train on your own dataset by providing image files in a --data-dir directory.

Example Generations

Fashion MNIST: sample generations

Pokemon 11k: sample generations

NOTE: With small images and a high number of training epochs, the model likely overfits and can memorize training samples.
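One rough way to check for this kind of memorization is to compare each generated sample with its nearest training image: a near-zero distance suggests a copy. The sketch below is an illustration, not part of DiffuMon. It works on flat pixel lists and uses plain Euclidean distance in pixel space (crude, since a small shift produces a large distance); the helper names and the `threshold` cutoff are hypothetical.

```python
import math


def nearest_training_image(sample, training_images):
    """Return (index, distance) of the training image closest to `sample` in pixel space."""
    distances = [math.dist(sample, image) for image in training_images]
    best = min(range(len(distances)), key=distances.__getitem__)
    return best, distances[best]


def looks_memorized(sample, training_images, threshold=0.5):
    """Flag a sample whose nearest training image is closer than `threshold` (hypothetical cutoff)."""
    _, distance = nearest_training_image(sample, training_images)
    return distance < threshold


training_images = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
generated = [0.9, 1.0, 1.1]
print(nearest_training_image(generated, training_images))  # index 1, distance ~0.141
```

This brute-force scan costs one distance per training image, which is manageable for datasets of Fashion MNIST or Pokemon 11k size.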

Denoising in action

Denoising Fashion MNIST

Denoising Pokemon 11k

Getting started

Setting up environment

This repo uses rye as the package/environment manager. Make sure to install it before proceeding.

The following commands install the packages and set up a virtual environment:

# Install packages
rye sync

# Activate the virtual environment
. .venv/bin/activate
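To confirm that the environment is active before running anything, a quick check from inside Python is sketched below. It relies on the standard-library behavior that `sys.prefix` differs from `sys.base_prefix` inside a venv-style environment; with this setup, `sys.prefix` should point at the `.venv` directory activated above (an assumption based on the activate path, not on rye internals). The function name is hypothetical.

```python
import sys


def in_virtual_environment():
    """True when the running interpreter belongs to a venv/virtualenv-style environment."""
    return sys.prefix != sys.base_prefix


print(f"virtual environment active: {in_virtual_environment()} (prefix: {sys.prefix})")
```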

Access the entrypoint

Once installed, the model can be trained and sampled via the diffumon command:

diffumon --help

Train a model

diffumon train --help

Train a Fashion MNIST model

diffumon train --preloaded fashion_mnist --num-epochs 100 --checkpoint-path checkpoints/fashion_mnist_100epochs.pth

Train a Pokemon Generative Model on the 11k Pokemon dataset (downscaled to 64x64 pixels)

diffumon train --preloaded pokemon_11k --num-epochs 800 --img-dim 64 --batch-size 64 --checkpoint-path checkpoints/pokemon_11k_800epochs_64dim.pth
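Besides resizing (here to 64x64), preprocessing usually rescales pixel values. The DDPM paper scales 8-bit values linearly to [-1, 1] for training, and generated samples are mapped back to 0 to 255 for saving. Assuming DiffuMon follows the same convention (not confirmed in this README), the two conversions can be sketched as below; the function names are hypothetical and images are flat lists rather than tensors.

```python
def to_model_range(pixels):
    """Map 8-bit pixel values in {0, ..., 255} linearly onto [-1, 1]."""
    return [p / 127.5 - 1.0 for p in pixels]


def to_pixel_range(values):
    """Map model outputs back to 8-bit values, clipping anything outside [-1, 1] first."""
    return [round((min(max(v, -1.0), 1.0) + 1.0) * 127.5) for v in values]


print(to_model_range([0, 255]))  # [-1.0, 1.0]
print(to_pixel_range([-1.0, 1.0, 1.7]))  # [0, 255, 255]
```

The clipping step matters because sampled values can overshoot [-1, 1] slightly.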

Train a model on a dataset of your choice

diffumon train --data-dir /path/to/dataset --num-epochs 100 --checkpoint-path checkpoints/my_dataset_100_epochs.pth

Here, /path/to/dataset should have a directory structure like the following:

/path/to/dataset/
  train/
    class_0/
      img_0.png
      img_1.png
  test/
    class_0/
      img_0.png
      img_1.png
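This layout (a split directory containing one subdirectory per class, each holding image files) follows the same convention as torchvision's `ImageFolder`. The hypothetical helper below, which is not part of DiffuMon, checks that a directory matches it before a long training run starts; the set of accepted file suffixes is an assumption.

```python
import tempfile
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}  # assumption: which image types count


def check_dataset_layout(root):
    """Return image counts per split/class, raising if a split or its class folders are missing."""
    root = Path(root)
    counts = {}
    for split in ("train", "test"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        class_dirs = sorted(p for p in split_dir.iterdir() if p.is_dir())
        if not class_dirs:
            raise ValueError(f"{split_dir} contains no class subdirectories")
        for class_dir in class_dirs:
            n_images = sum(1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_SUFFIXES)
            counts[f"{split}/{class_dir.name}"] = n_images
    return counts


# Build the example tree from above in a temporary directory and check it.
with tempfile.TemporaryDirectory() as tmp:
    for split in ("train", "test"):
        class_dir = Path(tmp) / split / "class_0"
        class_dir.mkdir(parents=True)
        for i in range(2):
            (class_dir / f"img_{i}.png").touch()
    demo_counts = check_dataset_layout(tmp)

print(demo_counts)  # {'train/class_0': 2, 'test/class_0': 2}
```

Running it against /path/to/dataset before `diffumon train --data-dir ...` surfaces layout mistakes early.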

Generate samples

diffumon sample --help

Generate samples from the trained Fashion MNIST model

diffumon sample --checkpoint-path checkpoints/fashion_mnist_100epochs.pth --num-samples 32 --output-dir samples/fashion_mnist_100epochs

Generate samples from the trained Pokemon Generative Model

diffumon sample --checkpoint-path checkpoints/pokemon_11k_800epochs_32dim.pth --num-samples 32 --output-dir samples/pokemon_11k_800epochs_32dim
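Sampling runs the noising process backwards, one step at a time. The sketch below follows Algorithm 2 of the DDPM paper with the sigma_t^2 = beta_t variance choice; whether `diffumon sample` uses exactly this variance is an assumption. The trained network is replaced by a `predict_noise(x_t, t)` callable (with `t` as a 0-based step index), images are flat lists of floats, and all names are hypothetical.

```python
import math
import random


def linear_betas(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise variances beta_t (the DDPM paper's default schedule)."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def cumulative_alpha_bars(betas):
    """alpha_bar_t = product over s <= t of (1 - beta_s)."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def ddpm_sample(predict_noise, num_values, betas, rng):
    """Ancestral sampling: start from pure noise x_T and denoise one step at a time."""
    alphas = [1.0 - b for b in betas]
    alpha_bars = cumulative_alpha_bars(betas)
    x = [rng.gauss(0.0, 1.0) for _ in range(num_values)]  # x_T ~ N(0, I)
    for t in reversed(range(len(betas))):
        eps = predict_noise(x, t)
        coef = betas[t] / math.sqrt(1.0 - alpha_bars[t])
        # Mean of p(x_{t-1} | x_t): subtract the scaled predicted noise, then rescale.
        mean = [(xi - coef * ei) / math.sqrt(alphas[t]) for xi, ei in zip(x, eps)]
        if t > 0:
            sigma = math.sqrt(betas[t])  # sigma_t^2 = beta_t
            x = [m + sigma * rng.gauss(0.0, 1.0) for m in mean]
        else:
            x = mean  # no noise is added on the final step
    return x


# Sanity check: if the "dataset" were the single point x0, the ideal noise predictor
# is known in closed form, and the final step lands exactly on x0.
betas = linear_betas()
alpha_bars = cumulative_alpha_bars(betas)
x0 = [0.3, -0.7, 0.9]


def ideal_predict_noise(x_t, t):
    """Exact noise for a point-mass dataset: invert x_t = sqrt(a) * x0 + sqrt(1 - a) * eps."""
    a = alpha_bars[t]
    return [(xi - math.sqrt(a) * x0i) / math.sqrt(1.0 - a) for xi, x0i in zip(x_t, x0)]


recovered = ddpm_sample(ideal_predict_noise, num_values=len(x0), betas=betas, rng=random.Random(0))
print([round(v, 6) for v in recovered])  # [0.3, -0.7, 0.9]
```

With a trained network in place of `ideal_predict_noise`, the same loop turns random noise into a new image.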

Useful resources

Developer notes

black, ruff, isort, and pre-commit should come preinstalled as development packages in the virtual environment.

It's strongly recommended to install the pre-commit hooks, which automatically run the formatters (but not the linters) before each commit to keep the code consistent:

pre-commit install

Jupyter notebooks

There are also example notebook(s) in the notebooks/ directory.

To run the notebooks, install the diffumon virtual environment as a Jupyter kernel:

python -m ipykernel install --user --name diffumon --display-name "Python Diffumon"

TODOs

  • Add support for more preloaded datasets
  • Add smarter periodic checkpointing
  • Add logging
  • Improve learning rate scheduling
  • Add DDIM (Denoising Diffusion Implicit Models) support
  • Add (Hydra-based?) preconfigured training options
