This repository is a JAX implementation of the Denoising Diffusion Probabilistic Models paper. The model is trained on the anime face dataset from Kaggle.
The implementation should be simple to follow. The forward/backward diffusion implementation can be found in `src/diffusion.py`.
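For reference, here is a minimal, hypothetical sketch of the DDPM forward (noising) step and the reverse-step mean in JAX. The function names, shapes, and the linear beta schedule are illustrative assumptions, not the API of `src/diffusion.py`.

```python
# Hypothetical sketch; not the code in src/diffusion.py.
import jax
import jax.numpy as jnp

T = 1000
betas = jnp.linspace(1e-4, 0.02, T)    # linear schedule from the DDPM paper
alphas_bar = jnp.cumprod(1.0 - betas)  # \bar{alpha}_t = prod_{s<=t} (1 - beta_s)

def q_sample(x0, t, key):
    """Forward process: sample x_t ~ q(x_t | x_0) = N(sqrt(abar_t) x_0, (1 - abar_t) I)."""
    eps = jax.random.normal(key, x0.shape)
    abar = alphas_bar[t]
    return jnp.sqrt(abar) * x0 + jnp.sqrt(1.0 - abar) * eps, eps

def p_mean(xt, t, eps_pred):
    """Reverse (denoising) step mean, given the model's noise prediction eps_theta(x_t, t)."""
    beta, abar = betas[t], alphas_bar[t]
    return (xt - beta / jnp.sqrt(1.0 - abar) * eps_pred) / jnp.sqrt(1.0 - beta)
```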
The architecture is the U-ViT taken from this paper. I deviated from the usual U-Nets because I find U-ViT conceptually simpler and easier to manage (fewer hyperparameters, fewer choices overall).
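For intuition, here is a rough, hypothetical sketch of the U-ViT idea (written with `flax.linen`, which is an assumption; the repo's actual module code may differ): image patches and the timestep embedding are all treated as tokens, and the transformer blocks are wired with U-Net-style long skip connections instead of down/up-sampling stages.

```python
# Hypothetical U-ViT sketch with flax.linen; not the architecture code of this repo.
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
    dim: int

    @nn.compact
    def __call__(self, x):
        x = x + nn.SelfAttention(num_heads=4)(nn.LayerNorm()(x))
        h = nn.Dense(4 * self.dim)(nn.LayerNorm()(x))
        return x + nn.Dense(self.dim)(nn.gelu(h))

class UViT(nn.Module):
    dim: int = 256
    depth: int = 6  # even: depth//2 "down" blocks, depth//2 "up" blocks

    @nn.compact
    def __call__(self, patches, t_emb):
        # patches: (B, N, patch_dim), t_emb: (B, t_dim)
        x = nn.Dense(self.dim)(patches)
        t = nn.Dense(self.dim)(t_emb)[:, None, :]        # timestep becomes one extra token
        x = jnp.concatenate([t, x], axis=1)
        skips = []
        for _ in range(self.depth // 2):                 # first half: store long skips
            x = Block(self.dim)(x)
            skips.append(x)
        for _ in range(self.depth // 2):                 # second half: consume long skips
            x = nn.Dense(self.dim)(jnp.concatenate([x, skips.pop()], axis=-1))
            x = Block(self.dim)(x)
        return nn.Dense(patches.shape[-1])(x[:, 1:, :])  # drop time token, predict noise per patch
```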
You need Python 3.12 and pdm. Use `pdm sync` to download all Python dependencies.
You can download the dataset using the `kaggle` CLI. Then you can use the following:
```sh
mkdir -p datasets
kaggle datasets download -d splcher/animefacedataset
unzip animefacedataset.zip
mv images datasets/anime-faces
rm animefacedataset.zip
```
This repository uses hydra to set the configurations and wandb to log training metrics. You can find the default hyperparameters in the `default.yaml` files in the `config/` directory.
```sh
python3 main.py mode=[online|offline] dataset=[default|small] model=[default|small]
```
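For context, here is a hypothetical minimal sketch of how such a hydra entry point and wandb logging are typically wired together; the `config_path`/`config_name` values and the `cfg.mode` field are assumptions inferred from the layout above, not the repo's actual `main.py`.

```python
# Hypothetical sketch of a hydra + wandb entry point; not the repo's main.py.
import hydra
import wandb
from omegaconf import DictConfig, OmegaConf

@hydra.main(version_base=None, config_path="config", config_name="default")
def main(cfg: DictConfig) -> None:
    # mode=[online|offline] on the command line would override cfg.mode here
    wandb.init(mode=cfg.mode, config=OmegaConf.to_container(cfg, resolve=True))
    # ... build the dataset/model from cfg and run the training loop ...

if __name__ == "__main__":
    main()
```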
The main model has 15M parameters and was trained for 30 hours on a laptop RTX 3080. But I suspect that you could get great results with a smaller model as well.
One could implement the score-matching or differential-equation (SDE) version of the diffusion. For DDPM, a quick enhancement would be to fix the noise schedule following the ideas in this paper.
Huge shout-out to Stanley H. Chan for his tutorial on diffusion models. This is what motivated me to build a minimal implementation and helped me a lot in understanding the equations.