
Official implementation for Estimating the Optimal Covariance with Imperfect Mean in Diffusion Probabilistic Models (ICML 2022), and a reimplementation of Analytic-DPM: an Analytic Estimate of the Optimal Reverse Variance in Diffusion Probabilistic Models (ICLR 2022)


baofff/Extended-Analytic-DPM


Extended Analytic-DPM

  • This is the official implementation of Estimating the Optimal Covariance with Imperfect Mean in Diffusion Probabilistic Models (accepted at ICML 2022). It extends Analytic-DPM to the following two settings:

    • The reverse process adopts state-dependent covariance matrices instead of simple scalar variances (this motivates SN-DPM in the paper).
    • The score-based model has some error with respect to the exact score function (this motivates NPR-DPM in the paper).
  • This codebase also reimplements Analytic-DPM and reproduces most of its results. The pretrained DPMs used in the Analytic-DPM paper are provided here and have already been converted to a format that can be used directly with this codebase. We additionally apply Analytic-DPM to score-based SDEs.

  • Models and FID statistics for reproducing the results in this paper are available here.
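To make the two settings concrete, the sketch below computes a per-dimension (diagonal) reverse variance for one step n -> n-1. It is not taken from this codebase: it assumes a DDIM-style reverse step with noise scale lam_sq (λ_n²), the notation ᾱ_n, β̄_n = 1 - ᾱ_n, α_n = ᾱ_n / ᾱ_{n-1}, and the coefficient γ_n = sqrt(β̄_n / α_n) - sqrt(β̄_{n-1} - λ_n²) from the Analytic-DPM derivation. Following our reading of the method names, the SN-style variant recovers Var[ε | x_n] from predicted first and second moments of the noise (cf. pred_eps_eps2), while the NPR-style variant uses a predicted squared residual of the imperfect noise prediction (cf. pred_eps_epsc). Network outputs are plain Python lists and all function names are hypothetical.

```python
import math


def gamma_sq(alpha_bar_n, alpha_bar_prev, lam_sq):
    """(sqrt(beta_bar_n / alpha_n) - sqrt(beta_bar_{n-1} - lam_sq))^2.

    Assumes alpha_n = alpha_bar_n / alpha_bar_{n-1} and beta_bar = 1 - alpha_bar.
    """
    alpha_n = alpha_bar_n / alpha_bar_prev
    beta_bar_n, beta_bar_prev = 1.0 - alpha_bar_n, 1.0 - alpha_bar_prev
    return (math.sqrt(beta_bar_n / alpha_n) - math.sqrt(beta_bar_prev - lam_sq)) ** 2


def sn_diag_variance(eps_mean, eps_sq, alpha_bar_n, alpha_bar_prev, lam_sq):
    """SN-style sketch: per-dimension variance from E[eps|x_n] and E[eps^2|x_n]."""
    g2 = gamma_sq(alpha_bar_n, alpha_bar_prev, lam_sq)
    # Var[eps_i | x_n] = E[eps_i^2 | x_n] - E[eps_i | x_n]^2, floored at 0
    # because two separately predicted moments can be slightly inconsistent.
    return [lam_sq + g2 * max(s - m * m, 0.0) for m, s in zip(eps_mean, eps_sq)]


def npr_diag_variance(residual_sq, alpha_bar_n, alpha_bar_prev, lam_sq):
    """NPR-style sketch: per-dimension variance from E[(eps - eps_hat(x_n))^2 | x_n]."""
    g2 = gamma_sq(alpha_bar_n, alpha_bar_prev, lam_sq)
    return [lam_sq + g2 * max(r, 0.0) for r in residual_sq]
```

As a sanity check: for standard Gaussian data, Var[ε | x_n] = ᾱ_n exactly, and with the DDPM choice λ_n² = β̃_n = β̄_{n-1} β_n / β̄_n both functions return β_n, which is the exact reverse variance in that case.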

Dependencies

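The codebase is based on PyTorch. The dependencies are listed below and can be installed with pip (the PyPI package for PyTorch is named torch, and the version specifier is quoted so the shell does not treat > as a redirect):

pip install "torch>=1.9.0" torchvision ml-collections ninja tensorboard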
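Before launching the scripts, it can help to confirm that the dependencies are present and that PyTorch is at least 1.9.0. The following check is a minimal sketch, not part of the codebase; the function names are hypothetical, and it only relies on the standard-library importlib.metadata module (Python 3.8+).

```python
import re
from importlib import metadata

MIN_TORCH = (1, 9, 0)


def parse_version(version):
    """Turn a version string such as '1.9.0+cu111' into (1, 9, 0).

    The local build tag after '+' is dropped, only the first three numeric
    components are kept, and missing components are padded with 0.
    """
    parts = [int(p) for p in re.findall(r"\d+", version.split("+")[0])[:3]]
    return tuple(parts + [0] * (3 - len(parts)))


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is missing.

    Both the hyphen and underscore spellings are tried (e.g. ml-collections).
    """
    for name in (dist_name, dist_name.replace("-", "_")):
        try:
            return metadata.version(name)
        except metadata.PackageNotFoundError:
            continue
    return None


def check_environment(packages=("torch", "torchvision", "ml-collections", "ninja", "tensorboard")):
    """Return a list of problems; an empty list means every check passed."""
    problems = []
    for pkg in packages:
        version = installed_version(pkg)
        if version is None:
            problems.append(f"{pkg} is not installed")
        elif pkg == "torch" and parse_version(version) < MIN_TORCH:
            problems.append(f"torch {version} is older than 1.9.0")
    return problems


if __name__ == "__main__":
    for problem in check_environment():
        print(problem)
```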

Basic usage

The basic usage for training is

python run_train.py --pretrained_path path/to/pretrained_dpm --dataset dataset --workspace path/to/working_directory $train_hparams
  • pretrained_path is the path to a pretrained diffusion probabilistic model (DPM). All pretrained DPMs used in this work are provided here.
  • dataset is the training dataset, one of <cifar10|celeba64|imagenet64|lsun_bedroom>.
  • workspace is the directory for training outputs, e.g., logs and intermediate checkpoints.
  • train_hparams specifies other hyperparameters used in training. The train_hparams for all models are listed here.
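When scripting several training runs, the command above can be assembled programmatically. This is a hedged sketch that only mirrors the documented flags; build_train_cmd and the example paths are hypothetical. Because $train_hparams appears unquoted in the shell usage, the hyperparameter string is split into separate arguments with shlex.

```python
import shlex

DATASETS = ("cifar10", "celeba64", "imagenet64", "lsun_bedroom")


def build_train_cmd(pretrained_path, dataset, workspace, train_hparams=""):
    """Assemble the argument list for run_train.py.

    train_hparams is one string, e.g. "--method pred_eps_epsc_pretrained";
    like the unquoted $train_hparams in the shell usage, it is split into
    separate arguments.
    """
    if dataset not in DATASETS:
        raise ValueError(f"dataset must be one of {DATASETS}, got {dataset!r}")
    cmd = [
        "python", "run_train.py",
        "--pretrained_path", pretrained_path,
        "--dataset", dataset,
        "--workspace", workspace,
    ]
    cmd.extend(shlex.split(train_hparams))
    return cmd


if __name__ == "__main__":
    # Hypothetical paths; substitute your own checkpoint and output directory.
    cmd = build_train_cmd(
        "path/to/pretrained_dpm",
        "cifar10",
        "workspace/cifar10_ls_npr",
        "--method pred_eps_epsc_pretrained",
    )
    print(shlex.join(cmd))
    # To launch: import subprocess; subprocess.run(cmd, check=True)
```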

The basic usage for evaluation is

python run_eval.py --pretrained_path path/to/evaluated_model --dataset dataset --workspace path/to/working_directory \
    --phase phase --sample_steps sample_steps --batch_size batch_size --method method $eval_hparams
  • pretrained_path is the path to the model to evaluate. All models evaluated in this work are provided here.
  • dataset is the dataset the model was trained on, one of <cifar10|celeba64|imagenet64|lsun_bedroom>.
  • workspace is the directory for evaluation outputs, e.g., logs, samples and bpd values.
  • phase selects sampling or likelihood evaluation, one of <sample4test|nll4test>.
  • sample_steps is the number of timesteps used during inference; the smaller this value, the faster the inference.
  • batch_size is the batch size, e.g., 500.
  • method specifies the type of the model, one of:
    • pred_eps: the original DPM (i.e., a noise prediction model) with discrete timesteps
    • pred_eps_eps2_pretrained: the SN-DPM with discrete timesteps
    • pred_eps_epsc_pretrained: the NPR-DPM with discrete timesteps
    • pred_eps_ct2dt: the original DPM (i.e., a noise prediction model) with continuous timesteps (i.e., a score-based SDE)
    • pred_eps_eps2_pretrained_ct2dt: the SN-DPM with continuous timesteps
    • pred_eps_epsc_pretrained_ct2dt: the NPR-DPM with continuous timesteps
  • eval_hparams specifies other optional hyperparameters used in evaluation.
  • The method and eval_hparams for the NPR/SN-DPM and Analytic-DPM results in this paper are listed here.
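sample_steps controls how many of the N training timesteps the sampler actually visits; each visited timestep costs one network evaluation, so sampling time grows roughly linearly with sample_steps. This README does not state which subsequence the codebase picks. The sketch below shows one common choice, an evenly spaced subsequence of 1..N that always includes N; the function name is hypothetical.

```python
def even_timesteps(num_train_steps, sample_steps):
    """Return `sample_steps` increasing timesteps from 1..num_train_steps.

    The first is 1 and the last is num_train_steps; the rest are spaced as
    evenly as integer division allows. Sampling visits them in reverse order.
    """
    if not 1 <= sample_steps <= num_train_steps:
        raise ValueError("need 1 <= sample_steps <= num_train_steps")
    if sample_steps == 1:
        return [num_train_steps]
    span = num_train_steps - 1
    # Integer arithmetic keeps the result exact and strictly increasing,
    # since span / (sample_steps - 1) >= 1.
    return [1 + (i * span) // (sample_steps - 1) for i in range(sample_steps)]


if __name__ == "__main__":
    print(even_timesteps(1000, 10))  # [1, 112, 223, 334, 445, 556, 667, 778, 889, 1000]
```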

Models and FID statistics

Here is the list of NPR-DPMs and SN-DPMs trained in this work. These models only train an additional prediction head in the last layer of a pretrained diffusion probabilistic model (DPM).

| NPR/SN-DPM | Pretrained DPM | train_hparams |
|---|---|---|
| CIFAR10 (LS), NPR-DPM | CIFAR10 (LS) | "--method pred_eps_epsc_pretrained" |
| CIFAR10 (LS), SN-DPM | CIFAR10 (LS) | "--method pred_eps_eps2_pretrained" |
| CIFAR10 (CS), NPR-DPM | CIFAR10 (CS) | "--method pred_eps_epsc_pretrained --schedule cosine_1000" |
| CIFAR10 (CS), SN-DPM | CIFAR10 (CS) | "--method pred_eps_eps2_pretrained --schedule cosine_1000" |
| CIFAR10 (VP SDE), NPR-DPM | CIFAR10 (VP SDE) | "--method pred_eps_epsc_pretrained_ct --sde vpsde" |
| CIFAR10 (VP SDE), SN-DPM | CIFAR10 (VP SDE) | "--method pred_eps_eps2_pretrained_ct --sde vpsde" |
| CelebA 64x64, NPR-DPM | CelebA 64x64 | "--method pred_eps_epsc_pretrained" |
| CelebA 64x64, SN-DPM | CelebA 64x64 | "--method pred_eps_eps2_pretrained" |
| ImageNet 64x64, NPR-DPM | ImageNet 64x64 | "--method pred_eps_epsc_pretrained --mode simple" |
| ImageNet 64x64, SN-DPM | ImageNet 64x64 | "--method pred_eps_eps2_pretrained --mode complex" |
| LSUN Bedroom, NPR-DPM | LSUN Bedroom | "--method pred_eps_epsc_pretrained --mode simple" |
| LSUN Bedroom, SN-DPM | LSUN Bedroom | "--method pred_eps_eps2_pretrained --mode complex" |

Here is the list of pretrained DPMs, collected from prior works. They are converted to a format that can be directly used for this codebase.

| Pretrained DPM | Expected mean squared norm (ms_eps), used in Analytic-DPM | From |
|---|---|---|
| CIFAR10 (LS) | Link | Analytic-DPM |
| CIFAR10 (CS) | Link | Analytic-DPM |
| CIFAR10 (VP SDE) | Link | score-sde |
| CelebA 64x64 | Link | DDIM |
| ImageNet 64x64 | Link | Improved DDPM |
| LSUN Bedroom | Link | pytorch_diffusion |

This link provides precalculated FID statistics on CIFAR10, CelebA 64x64, ImageNet 64x64 and LSUN Bedroom. They are computed following Appendix F.2 in Analytic-DPM.
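For reference, FID fits a Gaussian to Inception features of each image set. Assuming the precalculated statistics hold the mean μ_r and covariance Σ_r of the reference features (the usual format for such files), the score of a set of generated samples whose features have mean μ_g and covariance Σ_g is:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```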

Evaluation Hyperparameters for NPR/SN-DPM and Analytic-DPM

Note: Analytic-DPM requires the precalculated expected mean squared norm of the noise prediction model (ms_eps), which is provided here. Specify its path with --ms_eps_path.
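The role of ms_eps can be sketched as follows. Assuming ms_eps at timestep n is stored per dimension, i.e. E||ε_θ(x_n)||² / d, it can be estimated by Monte Carlo over samples of x_n, and Analytic-DPM plugs it into the optimal scalar reverse variance σ_n² = λ_n² + γ_n² (1 - ms_eps), with γ_n = sqrt(β̄_n / α_n) - sqrt(β̄_{n-1} - λ_n²). The helper names are hypothetical, the model is a plain Python callable rather than a PyTorch network, and the clipping to [0, 1] is a simple safeguard that is not necessarily what the codebase does.

```python
import math
import random


def estimate_ms_eps(eps_model, sample_xn, dim, num_samples=1000, seed=0):
    """Monte Carlo estimate of E[||eps_model(x_n)||^2] / dim at one timestep n.

    eps_model maps a list of floats to a list of floats (the noise prediction);
    sample_xn(rng) draws one x_n from the forward marginal q(x_n).
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        eps = eps_model(sample_xn(rng))
        total += sum(e * e for e in eps)
    return total / (num_samples * dim)


def analytic_variance(ms_eps, alpha_bar_n, alpha_bar_prev, lam_sq):
    """Analytic-DPM optimal scalar reverse variance for the step n -> n-1."""
    alpha_n = alpha_bar_n / alpha_bar_prev
    beta_bar_n, beta_bar_prev = 1.0 - alpha_bar_n, 1.0 - alpha_bar_prev
    gamma_sq = (math.sqrt(beta_bar_n / alpha_n) - math.sqrt(beta_bar_prev - lam_sq)) ** 2
    # Clip 1 - ms_eps to [0, 1] as a simple safeguard against estimation error.
    return lam_sq + gamma_sq * min(max(1.0 - ms_eps, 0.0), 1.0)


if __name__ == "__main__":
    # Toy check with standard Gaussian data: x_n ~ N(0, I) and the exact noise
    # prediction is eps(x_n) = sqrt(beta_bar_n) * x_n.
    alpha_bar_prev, alpha_bar_n, dim = 0.9, 0.8, 4
    beta_n = 1.0 - alpha_bar_n / alpha_bar_prev
    beta_tilde = (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_n) * beta_n  # DDPM lambda_n^2
    ms = estimate_ms_eps(
        lambda x: [math.sqrt(1.0 - alpha_bar_n) * v for v in x],
        lambda rng: [rng.gauss(0.0, 1.0) for _ in range(dim)],
        dim,
        num_samples=20000,
    )
    print(ms, analytic_variance(ms, alpha_bar_n, alpha_bar_prev, beta_tilde), beta_n)
```

In this Gaussian toy case the estimate is close to β̄_n = 0.2, and the resulting variance is close to β_n, which is the exact reverse variance when the data are standard Gaussian.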

  • Sampling experiments on CIFAR10 (LS) or CelebA 64x64, Table 1 in the paper:

| Model | method | eval_hparams |
|---|---|---|
| NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2" |
| SN-DDPM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2" |
| Analytic-DDPM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --ms_eps_path ms_eps_path" |
| NPR-DDIM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0" |
| SN-DDIM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0" |
| Analytic-DDIM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --ms_eps_path ms_eps_path" |

  • Sampling experiments on CIFAR10 (CS), Table 1 in the paper:

| Model | method | eval_hparams |
|---|---|---|
| NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --schedule cosine_1000" |
| SN-DDPM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --schedule cosine_1000" |
| Analytic-DDPM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --schedule cosine_1000 --ms_eps_path ms_eps_path" |
| NPR-DDIM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule cosine_1000" |
| SN-DDIM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule cosine_1000" |
| Analytic-DDIM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule cosine_1000 --ms_eps_path ms_eps_path" |

  • Sampling experiments on CIFAR10 (VP SDE), Table 1 in the paper:

| Model | method | eval_hparams |
|---|---|---|
| NPR-DDPM | pred_eps_epsc_pretrained_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --schedule vpsde_1000" |
| SN-DDPM | pred_eps_eps2_pretrained_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --schedule vpsde_1000" |
| Analytic-DDPM | pred_eps_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --schedule vpsde_1000 --ms_eps_path ms_eps_path" |
| NPR-DDIM | pred_eps_epsc_pretrained_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule vpsde_1000" |
| SN-DDIM | pred_eps_eps2_pretrained_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule vpsde_1000" |
| Analytic-DDIM | pred_eps_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule vpsde_1000 --ms_eps_path ms_eps_path" |

  • Sampling experiments on ImageNet 64x64, Table 1 in the paper:

| Model | method | eval_hparams |
|---|---|---|
| NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --mode simple" |
| SN-DDPM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --mode complex" |
| Analytic-DDPM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --ms_eps_path ms_eps_path" |
| NPR-DDIM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --mode simple" |
| SN-DDIM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --mode complex" |
| Analytic-DDIM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --ms_eps_path ms_eps_path" |

  • Likelihood experiments on CIFAR10 (LS) or CelebA 64x64, Table 3 in the paper:

| Model | method | eval_hparams |
|---|---|---|
| NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal" |
| Analytic-DDPM | pred_eps | "--rev_var_type optimal --ms_eps_path ms_eps_path" |

  • Likelihood experiments on CIFAR10 (CS), Table 3 in the paper:

| Model | method | eval_hparams |
|---|---|---|
| NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --schedule cosine_1000" |
| Analytic-DDPM | pred_eps | "--rev_var_type optimal --schedule cosine_1000 --ms_eps_path ms_eps_path" |

  • Likelihood experiments on ImageNet 64x64, Table 3 in the paper:

| Model | method | eval_hparams |
|---|---|---|
| NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --mode simple" |
| Analytic-DDPM | pred_eps | "--rev_var_type optimal --ms_eps_path ms_eps_path" |
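The nll4test phase writes bpd values: bits per dimension, i.e., the negative log-likelihood of an image divided by its number of dimensions and expressed in bits rather than nats. When comparing against likelihoods reported elsewhere in nats, a conversion like the sketch below can be used; the helper names are hypothetical and it assumes a per-image NLL in nats.

```python
import math


def nats_to_bpd(nll_nats, num_dims):
    """Per-image negative log-likelihood in nats -> bits per dimension."""
    return nll_nats / (num_dims * math.log(2.0))


def bpd_to_nats(bpd, num_dims):
    """Inverse conversion: bits per dimension -> per-image NLL in nats."""
    return bpd * num_dims * math.log(2.0)


# Number of dimensions (channels x height x width) per image for the
# datasets used in the likelihood experiments.
NUM_DIMS = {
    "cifar10": 3 * 32 * 32,
    "celeba64": 3 * 64 * 64,
    "imagenet64": 3 * 64 * 64,
}
```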

This implementation is based on / inspired by
