Code repository for the paper Think While You Generate: Discrete Diffusion with Planned Denoising, by Sulin Liu, Juno Nam, Andrew Campbell, Hannes Stärk, Yilun Xu, Tommi Jaakkola, Rafael Gómez-Bombarelli.
Sampling process of DDPD: A planner is first used to determine which positions are most likely to be noise and should be denoised next. The denoiser is then applied to the selected positions conditioned on all the input tokens.
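A minimal PyTorch-style sketch of one such sampling step is below. The `planner` and `denoiser` call signatures, the top-k selection rule, and the tensor shapes are illustrative assumptions, not the exact implementation in this repo:

```python
import torch

@torch.no_grad()
def ddpd_sampling_step(x, planner, denoiser, k):
    # x: (B, L) integer token ids of the current noisy sequence.
    # Planner scores, per position, how likely the current token is noise.
    noise_prob = torch.sigmoid(planner(x))          # (B, L), assumed output shape
    # Pick the k positions most likely to be noise (illustrative top-k rule).
    _, idx = noise_prob.topk(k, dim=-1)             # (B, k)
    # Denoiser predicts clean tokens conditioned on ALL input tokens;
    # we resample only at the selected positions.
    logits = denoiser(x)                            # (B, L, V), assumed output shape
    sampled = torch.distributions.Categorical(logits=logits).sample()  # (B, L)
    return x.scatter(1, idx, sampled.gather(1, idx))
```

In practice the number of positions re-generated per step is scheduled over the reverse process; see sample_text8.py for the actual procedure.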
Training objectives of the planner and denoiser: the planner is trained with a cross-entropy loss to predict the binary mask of noise tokens, and the denoiser with a cross-entropy loss to predict the original token values.
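As a rough illustration, the two objectives might look as follows (variable names, shapes, and the restriction of the denoiser loss to corrupted positions are assumptions of this sketch):

```python
import torch.nn.functional as F

def planner_loss(planner, x_noisy, noise_mask):
    # noise_mask: (B, L) float in {0, 1}; 1 marks a corrupted position.
    logits = planner(x_noisy)                       # (B, L), assumed output shape
    # Binary cross-entropy on the per-position "is noise" prediction.
    return F.binary_cross_entropy_with_logits(logits, noise_mask)

def denoiser_loss(denoiser, x_noisy, x_clean, noise_mask):
    logits = denoiser(x_noisy)                      # (B, L, V), assumed output shape
    # Cross-entropy against the original tokens; computing it only at
    # corrupted positions is an assumption of this sketch.
    ce = F.cross_entropy(logits.transpose(1, 2), x_clean, reduction="none")  # (B, L)
    return (ce * noise_mask).sum() / noise_mask.sum().clamp(min=1)
```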
Package requirements are listed in `ddpd_text.yml`. Mamba is recommended for faster installation (replace `conda` with `mamba` in the command below).
conda env create -f ddpd_text.yml
Our pretrained models can be downloaded at this link.
First we download the text8 data. Set the `DATA_DIR` variable within the `text8/data/download.sh` script to the location of this repository's `data/text8` directory, then run:
bash text8/data/download.sh
Then we pre-process the downloaded data. Set the `text8_file_path` variable within the `text8/data/prepare.py` script to the location of the downloaded data, then run:
python text8/data/prepare.py
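For reference, preprocessing text8 at the character level typically looks like the sketch below. The 27-character vocabulary, the 90M/5M/5M split, and the nanoGPT-style `.bin` outputs are assumptions based on the standard text8 setup; defer to `text8/data/prepare.py` for the actual logic:

```python
import numpy as np

# Assumed standard text8 setup: 27-character vocabulary (space + a-z)
# and a 90M/5M/5M train/val/test split of the 100M-character file.
chars = " abcdefghijklmnopqrstuvwxyz"
stoi = {ch: i for i, ch in enumerate(chars)}

with open("text8", "r") as f:               # path to the downloaded file
    data = f.read()

ids = np.array([stoi[c] for c in data], dtype=np.uint8)
splits = {"train": ids[:90_000_000],
          "val":   ids[90_000_000:95_000_000],
          "test":  ids[95_000_000:]}
for name, arr in splits.items():
    arr.tofile(f"{name}.bin")               # nanoGPT-style binary dump
```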
To train the denoiser, run:
torchrun --standalone --nproc_per_node=4 train_denoiser.py text8/config/train_denoiser.py --batch_size=512 --gradient_accumulation_steps=4 --resume_dir=None --wandb_run_name='ddpd_denoiser_mask' --model_type='ddpd_denoiser_mask'
The command above assumes a single node with 4 GPUs with 80GB of memory each; `batch_size` and `gradient_accumulation_steps` can be adjusted to fit smaller GPU memory.
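For example, on smaller GPUs one might halve `--batch_size` to 256 and double `--gradient_accumulation_steps` to 8 (a hypothetical setting), which keeps their product, and hence the effective batch size per update, unchanged.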
To train the planner, run:
torchrun --standalone --nproc_per_node=4 train_planner.py text8/config/train_planner.py --batch_size=512 --gradient_accumulation_steps=4 --resume_dir=None --wandb_run_name='ddpd_planner' --model_type='ddpd_planner'
To generate samples with the trained planner and denoiser, run:
python sample_text8.py text8/config/sample.py
To replicate the results of DDPD-MaskD and DDPD-UniD shown in the figure below, first generate samples:
bash text8/scripts/generate_samples_ddpd_maskD.sh
bash text8/scripts/generate_samples_ddpd_uniD.sh
Then evaluate the generated samples:
bash text8/scripts/evaluate_samples_ddpd_maskD.sh
bash text8/scripts/evaluate_samples_ddpd_uniD.sh
Results on the text8 conditional generation task:
Results on the OpenWebText unconditional language modeling task:
Results on the ImageNet 256x256 token generation task with an increasing number of sampling steps:
If you find this work useful, please cite:

@article{liu2024ddpd,
title={Think While You Generate: Discrete Diffusion with Planned Denoising},
author={Liu, Sulin and Nam, Juno and Campbell, Andrew and Stärk, Hannes and Xu, Yilun and Jaakkola, Tommi and Gómez-Bombarelli, Rafael},
journal={arXiv preprint arXiv:2410.06264},
year={2024}
}
This repo is built on top of nanoGPT and discrete_flow_models.