Deep Attentive VAE is the first work that proposes attention mechanisms to build more expressive variational distributions in deep probabilistic models. It achieves state-of-the-art log-likelihoods while using fewer latent layers and requiring less training time than prior hierarchical VAEs.
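The key architectural idea is depth-wise attention: when forming its variational posterior and prior, each latent layer attends over contexts produced by all earlier layers instead of conditioning only on the immediately preceding one. The snippet below is a rough, illustrative sketch of one such depth-wise attention step; the function name, shapes, and internals are assumptions for exposition only, not this repository's implementation (see src/run.py and src/model/util.py for the actual code).

import tensorflow as tf

# Illustrative sketch only: attend over feature maps emitted by earlier latent layers.
# current:  [B, H, W, C] features of the current latent layer.
# contexts: list of [B, H, W, C] features from previous layers.
def depthwise_attention(current, contexts, attn_dim=20):
    q = tf.keras.layers.Conv2D(attn_dim, 1)(current)                    # query projection
    keys = [tf.keras.layers.Conv2D(attn_dim, 1)(c) for c in contexts]   # key projections
    # One score per previous layer, per spatial location.
    scores = [tf.reduce_sum(q * k, axis=-1, keepdims=True) for k in keys]
    weights = tf.nn.softmax(tf.stack(scores, axis=0), axis=0)           # softmax over depth
    values = tf.stack(contexts, axis=0)                                 # [L, B, H, W, C]
    return tf.reduce_sum(weights * values, axis=0)                      # [B, H, W, C]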
Deep Attentive VAE is built in Python 3.7 using TensorFlow 2.3.0. Use the following command to install the requirements:
pip install -r requirements.txt
Below, we provide the training commands for several model configurations on public benchmarks. Detailed documentation can be found in src/run.py and src/model/util.py.
16-layer CIFAR-10 Attentive VAE for 900 epochs
- Number of trainable parameters: 118.97M
python3 run.py --mode=train --dataset=cifar10 --train_log_dir=../cifar10/ --train_log_subdir=16layer_900_epochs \
--train_hparams=num_gpus=4,batch_size=40,epochs=900,learning_rate=0.01,learning_rate_schedule=cosine_restart_warmup,lr_first_decay_epochs=300 \
--hparams=layer_latent_shape=[16,16,20],num_layers=16,data_distribution=discretized_logistic_mixture \
--encoder_attention_hparams=key_dim=20,query_dim=20,use_layer_norm=True \
--decoder_attention_hparams=key_dim=20,query_dim=20,use_layer_norm=True \
--decoder_hparams=scale=[0,0],use_nonlocal=True,nonlocop_hparams=[key_dim=32,query_dim=32],num_blocks=2,num_filters=128 \
--encoder_hparams=scale=[0,0],use_nonlocal=True,nonlocop_hparams=[key_dim=32,query_dim=32],num_blocks=2,num_filters=128 \
--preproc_encoder_hparams=scale=[0,-2],use_nonlocal=True,nonlocop_hparams=[key_dim=32,query_dim=32],num_blocks=2,num_filters=128 \
--postproc_decoder_hparams=scale=[2,0],num_blocks=2,num_filters=128 \
--posterior_hparams=log_scale_upper_bound=5,log_scale_low_bound=-5,noise_stddev=0.001 \
--prior_hparams=log_scale_upper_bound=5,log_scale_low_bound=-1.0,noise_stddev=0.001
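The learning_rate_schedule=cosine_restart_warmup,lr_first_decay_epochs=300 flags in the command above refer to a cosine-with-restarts schedule preceded by a warm-up phase. The sketch below shows one plausible realization of such a schedule; the class name, warm-up length, and steps-per-epoch conversion are assumptions, and the repository defines its own schedule in the training code.

import tensorflow as tf

# Hedged sketch of a "cosine restarts with warm-up" learning-rate schedule.
class WarmupCosineRestarts(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, base_lr, warmup_steps, first_decay_steps):
        super().__init__()
        self.base_lr = base_lr
        self.warmup_steps = float(warmup_steps)
        self.cosine = tf.keras.experimental.CosineDecayRestarts(
            initial_learning_rate=base_lr, first_decay_steps=first_decay_steps)

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup_lr = self.base_lr * (step + 1.0) / self.warmup_steps      # linear warm-up
        decayed_lr = self.cosine(tf.maximum(step - self.warmup_steps, 0.0))
        return tf.where(step < self.warmup_steps, warmup_lr, decayed_lr)

# e.g. with learning_rate=0.01 and lr_first_decay_epochs=300 as above (steps_per_epoch
# and the warm-up length are placeholders, not repository values):
# schedule = WarmupCosineRestarts(0.01, warmup_steps=5 * steps_per_epoch,
#                                 first_decay_steps=300 * steps_per_epoch)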
16-layer CIFAR-10 Attentive VAE for 400 epochs
- Number of trainable parameters: 118.97M
python3 run.py --mode=train --dataset=cifar10 --train_log_dir=../cifar10/ --train_log_subdir=16layer_400_epochs \
--train_hparams=num_gpus=4,batch_size=40,epochs=400,learning_rate=0.01,learning_rate_schedule=cosine_warmup \
--hparams=layer_latent_shape=[16,16,20],num_layers=16,data_distribution=discretized_logistic_mixture \
--encoder_attention_hparams=key_dim=20,query_dim=20,use_layer_norm=True \
--decoder_attention_hparams=key_dim=20,query_dim=20,use_layer_norm=True \
--decoder_hparams=scale=[0,0],use_nonlocal=True,nonlocop_hparams=[key_dim=32,query_dim=32],num_blocks=2,num_filters=128 \
--encoder_hparams=scale=[0,0],use_nonlocal=True,nonlocop_hparams=[key_dim=32,query_dim=32],num_blocks=2,num_filters=128 \
--preproc_encoder_hparams=scale=[0,-2],use_nonlocal=True,nonlocop_hparams=[key_dim=32,query_dim=32],num_blocks=2,num_filters=128 \
--postproc_decoder_hparams=scale=[2,0],num_blocks=2,num_filters=128 \
--posterior_hparams=log_scale_upper_bound=5,log_scale_low_bound=-5,noise_stddev=0.001 \
--prior_hparams=log_scale_upper_bound=5,log_scale_low_bound=-1.0,noise_stddev=0.001
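All --*_hparams flags above are comma-separated key=value strings, with bracketed lists such as scale=[0,-2] and nested attention hparams such as nonlocop_hparams=[key_dim=32,query_dim=32]. The toy parser below only illustrates this format (the real parsing lives in src/run.py and src/model/util.py); nested bracketed values are simply left as strings here.

import ast
import re

# Toy parser for the hparams string format used by the commands above.
def parse_hparams(spec):
    fields = re.split(r',(?![^\[]*\])', spec)        # split on commas outside [...]
    out = {}
    for field in fields:
        key, value = field.split('=', 1)
        try:
            out[key] = ast.literal_eval(value)       # numbers, booleans, numeric lists
        except (ValueError, SyntaxError):
            out[key] = value                         # schedule names, nested hparams, etc.
    return out

print(parse_hparams("num_gpus=4,batch_size=40,epochs=400,learning_rate=0.01,"
                    "learning_rate_schedule=cosine_warmup"))
# {'num_gpus': 4, 'batch_size': 40, 'epochs': 400, 'learning_rate': 0.01,
#  'learning_rate_schedule': 'cosine_warmup'}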
15-layer OMNIGLOT Attentive VAE for 400 epochs
- Number of trainable parameters: 7.87M
python3 run.py --mode=train --dataset=omniglot --dataset_path=../omniglot_data/chardata.mat --train_log_dir=../omniglot/ --train_log_subdir=15layer_400_epochs \
--train_hparams=num_gpus=2,batch_size=16,epochs=400,learning_rate=0.01,learning_rate_schedule=cosine_warmup \
--hparams=layer_latent_shape=[16,16,20],num_layers=15,num_proc_blocks=1,data_distribution=bernoulli \
--decoder_attention_hparams=key_dim=8,query_dim=8,use_layer_norm=True \
--encoder_attention_hparams=key_dim=8,query_dim=8,use_layer_norm=False \
--merge_encoder_hparams=use_nonlocal=True,nonlocop_hparams=[key_dim=8,query_dim=8] \
--merge_decoder_hparams=use_nonlocal=True,nonlocop_hparams=[key_dim=8,query_dim=8] \
--decoder_hparams=scale=[0,0],num_blocks=2,num_filters=32 \
--encoder_hparams=scale=[0,0],num_blocks=2,num_filters=32 \
--preproc_encoder_hparams=scale=[0,0,-2],num_nodes=2,num_blocks=3,num_filters=32 \
--postproc_decoder_hparams=scale=[2,0,0],num_nodes=2,use_nonlocal=True,nonlocop_hparams=[key_dim=8,query_dim=8],num_blocks=3,num_filters=32 \
--posterior_hparams=log_scale_upper_bound=5,log_scale_low_bound=-5 \
--prior_hparams=log_scale_upper_bound=5,log_scale_low_bound=-5
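The use_nonlocal=True and nonlocop_hparams=[key_dim=...,query_dim=...] options above enable non-local operations (self-attention over spatial positions) inside the processing blocks. The block below is a generic non-local sketch in that spirit, with assumed internals (1x1 convolutions, a residual connection, and key_dim equal to query_dim); the repository's block may differ.

import tensorflow as tf

# Generic non-local block sketch: self-attention across the spatial positions of x.
def nonlocal_block(x, key_dim=8, query_dim=8):
    h, w, c = x.shape[1], x.shape[2], x.shape[3]
    q = tf.keras.layers.Conv2D(query_dim, 1)(x)      # queries
    k = tf.keras.layers.Conv2D(key_dim, 1)(x)        # keys (dim must match queries in this sketch)
    v = tf.keras.layers.Conv2D(c, 1)(x)              # values
    q = tf.reshape(q, [-1, h * w, query_dim])
    k = tf.reshape(k, [-1, h * w, key_dim])
    v = tf.reshape(v, [-1, h * w, c])
    attn = tf.nn.softmax(tf.matmul(q, k, transpose_b=True) /
                         tf.sqrt(tf.cast(key_dim, tf.float32)), axis=-1)
    out = tf.reshape(tf.matmul(attn, v), [-1, h, w, c])
    return x + out                                   # residual connection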
16-layer CIFAR-10 NVAE for 400 epochs
- Number of trainable parameters: 79.21M
python3 run.py --mode=train --dataset=cifar10 --train_log_dir=../cifar10/ --train_log_subdir=nvae_400epochs \
--train_hparams=num_gpus=4,batch_size=40,epochs=400,learning_rate=0.01,learning_rate_schedule=cosine_warmup \
--hparams=layer_latent_shape=[16,16,20],num_layers=16,data_distribution=discretized_logistic_mixture \
--decoder_hparams=scale=[0,0],num_blocks=2,num_filters=128 \
--encoder_hparams=scale=[0,0],num_blocks=2,num_filters=128 \
--preproc_encoder_hparams=scale=[0,-2],num_blocks=2,num_filters=128 \
--postproc_decoder_hparams=scale=[2,0],num_blocks=2,num_filters=128 \
--posterior_hparams=log_scale_upper_bound=5,log_scale_low_bound=-5 \
--prior_hparams=log_scale_upper_bound=5,log_scale_low_bound=-5.0
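The parameter counts quoted in the configuration headers above can be checked on a built model object (whatever run.py constructs, or final_model.h5 loaded with the appropriate custom layers); the helper name below is illustrative, not part of the codebase.

import numpy as np

# Illustrative sanity check of the "Number of trainable parameters" figures above.
def count_trainable_params(model):
    return sum(int(np.prod(w.shape)) for w in model.trainable_weights)

# print(f"{count_trainable_params(model) / 1e6:.2f}M trainable parameters")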
The outputs include:
- keras_trainer_log.csv: the learning curves and the performance (in terms of NELBO) on the test data (see the plotting sketch after this list).
  - val_loss: the total loss of the training objective, including the regularization penalty.
  - val_loss_1: the negative conditional likelihood term of the NELBO.
  - val_loss_2: the KL regularization penalty of the NELBO.
- final_model.h5: the weights of the learned model.
- train.sh: the command for training the model from scratch.
- eval_command.sh: the command for evaluating the model.
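For quick inspection of the learning curves, keras_trainer_log.csv can be loaded with pandas; the sketch below assumes an 'epoch' column (the default for Keras CSV logging) and the log-directory layout used by the evaluation commands further down.

import pandas as pd
import matplotlib.pyplot as plt

# Plot the NELBO terms recorded during training (column names as listed above).
log = pd.read_csv("../cifar10/0_16layer_900_epochs/keras_trainer_log.csv")
log.plot(x="epoch", y=["val_loss", "val_loss_1", "val_loss_2"])
plt.xlabel("epoch")
plt.ylabel("NELBO terms")
plt.show()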
The training command writes the corresponding evaluation script to eval_command.sh. Sample evaluation commands are provided below; please refer to src/util/eval.py for detailed documentation.
The evaluation script produces the following files in the designated folder (eval_log_dir):
- the marginal log-likelihood, computed by importance sampling, if hparams.compute_logp=True (see the estimator sketch after this list).
- sample_image_x.png: sample images drawn from the model, if hparams.generate=True.
- [original_test_image_x.png, reconstructed_test_image_x.png]: the original vs. reconstructed images (from the latent codes), if hparams.reconstruct=True.
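For background, the importance-sampled marginal log-likelihood mentioned above is the standard estimator log p(x) ~= logsumexp_k [log p(x, z_k) - log q(z_k | x)] - log K with z_k ~ q(z | x); the helper below simply restates it (names are illustrative, not the repository's API).

import tensorflow as tf

# Standard K-sample importance-sampling estimate of log p(x) for one datapoint.
# log_joint[k] = log p(x, z_k), log_q[k] = log q(z_k | x), z_k ~ q(z | x).
def iw_log_likelihood(log_joint, log_q):
    k = tf.cast(tf.shape(log_joint)[0], tf.float32)
    return tf.reduce_logsumexp(log_joint - log_q) - tf.math.log(k)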
CIFAR-10 Attentive VAE
python3 run.py --mode=eval --eval_log_dir=../cifar10/0_16layer_900_epochs/ --model_filepath=../cifar10/0_16layer_900_epochs/final_model.h5 --dataset=cifar10 --eval_hparams=num_gpus=2 \
--preproc_encoder_hparams=scale=[0,-2],use_nonlocal=True,nonlocop_hparams=[key_dim=32,query_dim=32],num_blocks=2,num_filters=128 \
--postproc_decoder_hparams=scale=[2,0],num_blocks=2,num_filters=128 \
--encoder_hparams=scale=[0,0],use_nonlocal=True,nonlocop_hparams=[key_dim=32,query_dim=32],num_blocks=2,num_filters=128 \
--decoder_hparams=scale=[0,0],use_nonlocal=True,nonlocop_hparams=[key_dim=32,query_dim=32],num_blocks=2,num_filters=128 \
--posterior_hparams=log_scale_upper_bound=5,log_scale_low_bound=-5,noise_stddev=0.001 \
--prior_hparams=log_scale_upper_bound=5,log_scale_low_bound=-1.0,noise_stddev=0.001 \
--decoder_attention_hparams=key_dim=20,query_dim=20,use_layer_norm=True \
--encoder_attention_hparams=key_dim=20,query_dim=20,use_layer_norm=True \
--hparams=layer_latent_shape=[16,16,20],num_layers=16,data_distribution=discretized_logistic_mixture
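CIFAR-10 likelihoods are commonly reported in bits per dimension; assuming the scripts report the NELBO / log-likelihood in nats per image, the conversion is the usual one:

import math

# Convert nats per image to bits per dimension for 32x32x3 CIFAR-10 images.
def nats_to_bits_per_dim(nats_per_image, num_dims=32 * 32 * 3):
    return nats_per_image / (num_dims * math.log(2))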
OMNIGLOT Attentive VAE
python3 run.py --mode=eval --eval_log_dir=../omniglot/0_15layer_400_epochs/ --model_filepath=../omniglot/0_15layer_400_epochs/final_model.h5 --dataset=omniglot --dataset_path=../omniglot_data/chardata.mat \
--preproc_encoder_hparams=scale=[0,0,-2],num_nodes=2,num_blocks=3,num_filters=32 \
--postproc_decoder_hparams=scale=[2,0,0],num_nodes=2,use_nonlocal=True,nonlocop_hparams=[key_dim=8,query_dim=8],num_blocks=3,num_filters=32 \
--merge_encoder_hparams=use_nonlocal=True,nonlocop_hparams=[key_dim=8,query_dim=8] \
--encoder_hparams=scale=[0,0],num_blocks=2,num_filters=32 \
--merge_decoder_hparams=use_nonlocal=True,nonlocop_hparams=[key_dim=8,query_dim=8] \
--decoder_hparams=scale=[0,0],num_blocks=2,num_filters=32 \
--posterior_hparams=log_scale_upper_bound=5,log_scale_low_bound=-5 \
--prior_hparams=log_scale_upper_bound=5,log_scale_low_bound=-5 \
--decoder_attention_hparams=key_dim=8,query_dim=8,use_layer_norm=True \
--encoder_attention_hparams=key_dim=8,query_dim=8,use_layer_norm=False \
--hparams=layer_latent_shape=[16,16,20],num_layers=15,num_proc_blocks=1,data_distribution=bernoulli
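To verify that the OMNIGLOT file passed via --dataset_path is readable, it can be inspected with scipy (an extra dependency assumed here; the exact keys depend on the chardata.mat release):

from scipy.io import loadmat

# Print the array keys and shapes stored in the expected OMNIGLOT .mat file.
omni = loadmat("../omniglot_data/chardata.mat")
print({k: v.shape for k, v in omni.items() if not k.startswith("__")})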
Saved models, learning curves, and qualitative evaluations of the models can be found below:
- The OMNIGLOT dataset was partitioned and preprocessed as described here, and it can also be downloaded here.
- The remaining datasets can be downloaded through TensorFlow Datasets (see the loading example below).
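For example, CIFAR-10 can be fetched directly with the tensorflow-datasets package (assumed installed alongside the requirements):

import tensorflow_datasets as tfds

# Download and load the CIFAR-10 train/test splits from TensorFlow Datasets.
train_ds, test_ds = tfds.load("cifar10", split=["train", "test"], as_supervised=True)
print(train_ds.element_spec)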
Please cite our paper if you use this codebase:
@inproceedings{apostolopoulou2021deep,
  title={Deep Attentive Variational Inference},
  author={Apostolopoulou, Ifigeneia and Char, Ian and Rosenfeld, Elan and Dubrawski, Artur},
  booktitle={International Conference on Learning Representations},
  year={2021}
}
Funding for the development of this library has been generously provided by the following sponsors: