Simple VAE implementation in TensorFlow, JAX, and PyTorch where both the encoder and decoder model Gaussian distributions.

VAE_celeba

Jovana Gentić 🦆


In this notebook, we implement a VAE where both the encoder and decoder model Gaussian distributions. The model is trained on CelebA_10 64x64 images. The model is trained in TensorFlow and supports multi-GPU; we also created JAX and PyTorch versions of the code for learning purposes.

Images before and after cropping and resizing for model training

About the model

The encoder is made of convolutions that downsample the image resolution up to a certain point, after which we flatten the activation and use a stack of dense layers to produce the posterior distribution q(z|x).
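As a minimal NumPy sketch of the encoder head described above (all names and weights here are illustrative, not the repo's API): dense layers map the flattened features to the posterior parameters mu and sigma, and a sample z is drawn with the reparameterization trick.

```python
import numpy as np

def encode(x_flat, w_mu, b_mu, w_logvar, b_logvar, rng):
    """Sketch of the encoder head: dense layers map flattened features to
    the posterior mean and log-variance, then we draw a reparameterized
    sample z = mu + sigma * eps. Names and weights are illustrative."""
    mu = x_flat @ w_mu + b_mu
    logvar = x_flat @ w_logvar + b_logvar
    sigma = np.exp(0.5 * logvar)
    eps = rng.standard_normal(mu.shape)
    z = mu + sigma * eps  # reparameterization trick: gradients flow through mu, sigma
    return z, mu, sigma
```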

The decoder starts off with dense layers that process the sample z, followed by an unflatten (reshape) operation into an activation of shape (B, h, w, C). The activation is then upsampled back to the original image size using a stack of resize-conv blocks. A resize-conv block is simple nearest-neighbor upsampling followed by convolutions, used instead of deconvolution layers. This block is useful to avoid checkerboard artifacts: https://distill.pub/2016/deconv-checkerboard/
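The nearest-neighbor upsampling at the heart of a resize-conv block can be sketched in a few lines of NumPy (a sketch of the operation, not the repo's implementation):

```python
import numpy as np

def nearest_neighbor_upsample(x, factor=2):
    """Nearest-neighbor upsampling for a (B, H, W, C) activation: each
    pixel is repeated `factor` times along H and W. In a resize-conv
    block this is followed by an ordinary convolution, which avoids the
    checkerboard artifacts of transposed convolutions."""
    x = np.repeat(x, factor, axis=1)  # repeat rows
    x = np.repeat(x, factor, axis=2)  # repeat columns
    return x
```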

For the loss, we use the negative ELBO = -likelihood + KL_div, where:

  • likelihood = decoder_dist.log_pdf(targets)
  • KL_div = KL(posterior_dist || prior_dist)
  • The posterior_dist is the encoder distribution.
  • For simplicity, we set the prior distribution to be a simple standard Gaussian N(0, 1).

To help the model avoid posterior collapse, we warm up the KL_div by linearly scaling it from 0 to 1 over the first 10000 steps.
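The warm-up schedule is just a linear ramp on the KL weight (a sketch of the schedule described above, not the repo's exact code):

```python
def kl_warmup(step, warmup_steps=10_000):
    """Linear KL warm-up: the KL weight ramps from 0 to 1 over
    `warmup_steps` training steps, then stays at 1. The training loss
    would then be: -likelihood + kl_warmup(step) * kl_div."""
    return min(step / warmup_steps, 1.0)
```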

Generate

Pick a prior distribution temperature (z_temp) and a decoder distribution temperature (x_temp) to generate new images from the prior distribution: pictures = model.generate(z_temp=1., x_temp=0.3)

z_temp: float, defines the temperature multiplier of the prior stddev. Smaller z_temp makes the generated samples less diverse and more generic.

x_temp: float, defines the temperature multiplier of the decoder stddev. Smaller x_temp makes the generated samples smoother, at the cost of losing some fine detail.
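Temperature-controlled generation can be sketched as follows (a NumPy sketch under assumed names; `decoder` standing in for a trained decoder that returns per-pixel mean and stddev, not the repo's actual `model.generate`):

```python
import numpy as np

def generate(decoder, z_temp=1.0, x_temp=0.3, n=4, latent_dim=128, rng=None):
    """Sketch of temperature sampling: z_temp scales the prior stddev
    before sampling z; x_temp scales the decoder stddev before sampling
    pixels. All names here are illustrative."""
    if rng is None:
        rng = np.random.default_rng()
    z = z_temp * rng.standard_normal((n, latent_dim))        # z ~ N(0, z_temp^2 I)
    mu, sigma = decoder(z)                                   # decoder predicts a Gaussian per pixel
    x = mu + x_temp * sigma * rng.standard_normal(mu.shape)  # x ~ N(mu, (x_temp * sigma)^2)
    return x
```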
