Official implementation of the paper "Reducing Spatial Fitting Error in Distillation of Denoising Diffusion Models" (AAAI 2024).
We propose the Spatial Fitting-Error Reduction Distillation model (SFERD).
A suitable conda environment named `SFERD` can be created and activated with:

```bash
conda env create -f environment.yml
conda activate SFERD
```
We provide the core code of the SFERD model, which includes:

- the teacher network with attention guidance (`./unet/teacher_unet.py`),
- the student network with the semantic gradient predictor (`./unet/student_unet.py`),
- the diffusion distillation training process (`./diffusion/gaussian_diffusion.py`; see the sketch below),
- the trainer definition (`./diffusion/train_utils.py`),
- the main script for distillation training (`train_diffusion_distillation.py`),
- and the sampling script (`sample.py`).
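For orientation, the following is a minimal, hypothetical sketch of a single distillation training step, in the spirit of what `./diffusion/gaussian_diffusion.py` implements. The function name, model signatures, and the plain match-the-teacher MSE objective are illustrative assumptions, not the repository's actual implementation, which additionally uses attention guidance and the semantic gradient predictor.

```python
import torch
import torch.nn.functional as F

def distillation_step(teacher, student, optimizer, x0, alphas_cumprod):
    """Hypothetical distillation step: the student learns to match the
    frozen teacher's denoising prediction at a randomly sampled timestep."""
    b, T = x0.shape[0], alphas_cumprod.shape[0]
    t = torch.randint(0, T, (b,), device=x0.device)          # random timesteps
    noise = torch.randn_like(x0)
    a_bar = alphas_cumprod[t].view(b, 1, 1, 1)               # cumulative alphas on x0's device
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise   # forward process q(x_t | x_0)

    with torch.no_grad():
        target = teacher(x_t, t)                              # teacher prediction (no gradients)
    pred = student(x_t, t)                                    # student prediction

    loss = F.mse_loss(pred, target)                           # fit the student to the teacher
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```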
Specifically, the main work of `./unet/teacher_unet.py` is extracting the attention maps of the middle or decoder blocks of the diffusion model. The main work of `./unet/student_unet.py` is adding the semantic encoder, gradient predictor, and latent diffusion modules, and further incorporating them into training together with the distilled student model. The main work of `./diffusion/gaussian_diffusion.py` is implementing the attention guidance method based on the teacher model, reformulating the training loss objective with the semantic gradient predictor, training the diffusion distillation model, training the latent diffusion model, and applying the necessary diffusion operations (including inference, the forward process, and noise-schedule setting).
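As a rough illustration of the attention-map extraction described above, the sketch below registers standard PyTorch forward hooks on attention layers in the middle and decoder (output) blocks of a guided-diffusion-style UNet. The `keywords` filter and the class-name check are assumptions about the module naming; the actual extraction logic in `./unet/teacher_unet.py` may differ.

```python
import torch

def register_attention_hooks(unet, keywords=("middle_block", "output_blocks")):
    """Collect intermediate activations of the UNet's attention layers via
    forward hooks. The `keywords` and the class-name check are assumptions
    based on the guided-diffusion UNet; adapt them to the teacher's layout."""
    captured = {}

    def make_hook(name):
        def hook(module, inputs, output):
            captured[name] = output.detach()   # store the block's output feature map
        return hook

    handles = []
    for name, module in unet.named_modules():
        in_target_stage = any(k in name for k in keywords)
        is_attention = "attention" in type(module).__name__.lower()
        if in_target_stage and is_attention:
            handles.append(module.register_forward_hook(make_hook(name)))
    return captured, handles

# Usage: run the teacher once, then read the captured maps.
# captured, handles = register_attention_hooks(teacher_unet)
# _ = teacher_unet(x_t, t)
# attention_features = captured          # dict: layer name -> activation tensor
# for h in handles: h.remove()           # detach hooks when done
```

The captured activations can then be reduced (for example, averaged over heads and channels) to form spatial masks for guidance; how exactly SFERD post-processes them is defined in the repository code.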
The detailed code will be released soon!
This implementation is based on the openai/guided-diffusion and openai/consistency_models repositories.