
A simple VAE implementation in TensorFlow, JAX and PyTorch, where both the encoder and the decoder model Gaussian distributions.

Primary language: Jupyter Notebook

VAE_celeba

Jovana Gentić 🦆
TensorFlow notebook
JAX notebook
PyTorch notebook


In this notebook, we implement a VAE in which both the encoder and the decoder model Gaussian distributions. The model is trained on 64x64 CelebA_10 images. The TensorFlow version is the one used for training and supports multi-GPU; we created the JAX and PyTorch versions of the code for learning purposes.

Images before and after cropping and resizing for model training

About the model

The encoder is a stack of convolutions that downsample the image to a certain resolution, after which we flatten the activation and use a stack of dense layers to produce the parameters (mean and standard deviation) of the Gaussian posterior distribution q(z|x).
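The encoder pipeline can be sketched as a plain NumPy forward pass. This is a minimal sketch under stated assumptions, not the repository's code: the helper names (`conv2x2_stride2`, `init_encoder`, `encode`), the 2x2 stride-2 kernels, the channel widths, the ReLU activations, the latent size of 32, and outputting a log-stddev are all illustrative choices. Each strided convolution halves the resolution (64 -> 32 -> 16 -> 8), then the activation is flattened and passed through dense layers whose final output is split into the mean and stddev of q(z|x).

```python
import numpy as np

rng = np.random.default_rng(0)


def relu(x):
    return np.maximum(x, 0.0)


def conv2x2_stride2(x, w, b):
    """2x2 convolution with stride 2 and no padding: halves H and W.

    x: (B, H, W, C_in) with even H and W; w: (2, 2, C_in, C_out); b: (C_out,)
    """
    B, H, W, C = x.shape
    # Split the image into non-overlapping 2x2 patches: (B, H/2, 2, W/2, 2, C)
    patches = x.reshape(B, H // 2, 2, W // 2, 2, C)
    # Apply the kernel to every patch -> (B, H/2, W/2, C_out)
    return np.einsum("bhiwjc,ijco->bhwo", patches, w) + b


def init_encoder(image_size=64, c_in=3, widths=(16, 32, 64), hidden=128, latent_dim=32):
    """Random parameters for a hypothetical encoder (layer sizes are illustrative)."""
    convs, c = [], c_in
    for c_out in widths:
        convs.append((rng.normal(0, 0.1, (2, 2, c, c_out)), np.zeros(c_out)))
        c = c_out
    side = image_size // 2 ** len(widths)  # spatial size after downsampling
    flat = side * side * c
    return {
        "convs": convs,
        "dense": (rng.normal(0, 0.01, (flat, hidden)), np.zeros(hidden)),
        # Last dense layer outputs the mean and log-stddev of q(z|x)
        "head": (rng.normal(0, 0.01, (hidden, 2 * latent_dim)), np.zeros(2 * latent_dim)),
    }


def encode(params, x):
    """x: (B, 64, 64, 3) images -> (mean, std) of the diagonal Gaussian q(z|x)."""
    h = x
    for w, b in params["convs"]:
        h = relu(conv2x2_stride2(h, w, b))  # 64 -> 32 -> 16 -> 8
    h = h.reshape(len(h), -1)  # flatten to (B, 8 * 8 * 64)
    w, b = params["dense"]
    h = relu(h @ w + b)
    w, b = params["head"]
    mean, log_std = np.split(h @ w + b, 2, axis=-1)
    return mean, np.exp(log_std)
```

For a batch `x` of shape (B, 64, 64, 3), `encode(init_encoder(), x)` returns a mean and a stddev, each of shape (B, 32). Predicting the log-stddev and exponentiating it is a common way to keep the stddev positive.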

The decoder starts with dense layers that process the sample z, followed by an unflatten (reshape) operation into an activation of shape (B, h, w, C). This activation is then upsampled back to the original image size by a stack of resize-conv blocks. A resize-conv block is simply nearest-neighbor upsampling followed by convolutions, and it is used to upsample instead of deconvolution (transposed convolution) layers. This helps avoid checkerboard artifacts: https://distill.pub/2016/deconv-checkerboard/
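The decoder described above can be sketched in NumPy as follows. This is an illustrative sketch, not the repository's code: the names (`upsample_nearest`, `conv3x3_same`, `resize_conv_block`, `init_decoder`, `decode`), the 8x8x64 starting shape, the 3x3 kernels, the channel widths and the log-stddev output are assumptions. Each resize-conv block doubles the resolution (8 -> 16 -> 32 -> 64) with nearest-neighbor upsampling and then applies an ordinary stride-1 convolution.

```python
import numpy as np

rng = np.random.default_rng(0)


def upsample_nearest(x, factor=2):
    """Nearest-neighbor upsampling: (B, H, W, C) -> (B, H*factor, W*factor, C)."""
    return x.repeat(factor, axis=1).repeat(factor, axis=2)


def conv3x3_same(x, w, b):
    """3x3 convolution with zero padding that keeps H and W; w: (3, 3, C_in, C_out)."""
    B, H, W, _ = x.shape
    xp = np.pad(x, ((0, 0), (1, 1), (1, 1), (0, 0)))
    out = np.zeros((B, H, W, w.shape[-1]))
    for i in range(3):
        for j in range(3):
            out += xp[:, i:i + H, j:j + W, :] @ w[i, j]
    return out + b


def resize_conv_block(x, w, b):
    # Upsample first, then apply a stride-1 convolution: the kernel overlaps
    # every output pixel evenly, unlike a strided transposed convolution.
    return np.maximum(conv3x3_same(upsample_nearest(x), w, b), 0.0)


def init_decoder(latent_dim=32, start=(8, 8, 64), widths=(64, 32, 16), out_channels=3):
    """Random parameters for a hypothetical decoder (layer sizes are illustrative)."""
    h, w, c = start
    params = {
        "start": start,
        "dense": (rng.normal(0, 0.01, (latent_dim, h * w * c)), np.zeros(h * w * c)),
        "blocks": [],
    }
    for c_out in widths:
        params["blocks"].append((rng.normal(0, 0.1, (3, 3, c, c_out)), np.zeros(c_out)))
        c = c_out
    # Final conv outputs the per-pixel mean and log-stddev of p(x|z)
    params["out"] = (rng.normal(0, 0.1, (3, 3, c, 2 * out_channels)), np.zeros(2 * out_channels))
    return params


def decode(params, z):
    """z: (B, latent_dim) -> (mean, std) of the Gaussian p(x|z), each (B, 64, 64, 3)."""
    w, b = params["dense"]
    h = np.maximum(z @ w + b, 0.0)
    h = h.reshape(len(z), *params["start"])  # unflatten to (B, h, w, C)
    for w, b in params["blocks"]:
        h = resize_conv_block(h, w, b)  # 8 -> 16 -> 32 -> 64
    w, b = params["out"]
    mean, log_std = np.split(conv3x3_same(h, w, b), 2, axis=-1)
    return mean, np.exp(log_std)
```

Because the upsampling step only copies pixels, all learned filtering happens in the regular convolution, which is what removes the uneven kernel overlap that causes checkerboard patterns in deconvolutions.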

For the loss, we use the negative ELBO: loss = -likelihood + KL_div, where:

  • likelihood = decoder_dist.log_pdf(targets)
  • KL_div = KL(posterior_dist || prior_dist)
  • The posterior_dist is the encoder distribution.
  • For simplicity, we set the prior distribution to a standard Gaussian N(0, 1).
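With Gaussians on both sides, both loss terms have simple forms: the likelihood is a Gaussian log-density of the targets, and the KL against N(0, 1) has the closed form KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (mu^2 + sigma^2 - 1) - log(sigma). A minimal NumPy sketch, assuming diagonal Gaussians given by a mean and a stddev, and assuming the likelihood is summed over pixels and the KL over latent dimensions before averaging over the batch (the repository may reduce differently). The function names are hypothetical, and `kl_weight` defaults to 1 so that the KL warmup can scale it.

```python
import numpy as np


def gaussian_log_pdf(x, mean, std):
    """Elementwise log N(x; mean, std^2)."""
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((x - mean) / std) ** 2


def kl_to_standard_normal(mean, std):
    """Closed-form KL(N(mean, std^2) || N(0, 1)), elementwise."""
    return 0.5 * (mean ** 2 + std ** 2 - 1.0) - np.log(std)


def negative_elbo(targets, dec_mean, dec_std, post_mean, post_std, kl_weight=1.0):
    """Batch-averaged -likelihood + kl_weight * KL_div (reductions are assumptions)."""
    batch = len(targets)
    # likelihood: decoder log-density of the targets, summed over pixels and channels
    likelihood = gaussian_log_pdf(targets, dec_mean, dec_std).reshape(batch, -1).sum(-1)
    # KL_div: KL(posterior_dist || prior_dist), summed over latent dimensions
    kl_div = kl_to_standard_normal(post_mean, post_std).sum(-1)
    return np.mean(-likelihood + kl_weight * kl_div)
```

Here `dec_mean` and `dec_std` come from decoding a sample z drawn from the posterior, typically with the reparameterization z = post_mean + post_std * eps, eps ~ N(0, 1), so gradients can flow back into the encoder.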

To help the model avoid posterior collapse, we warm up the KL_div term by linearly scaling its weight up over the first 10000 training steps.
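The warmup amounts to multiplying KL_div by a weight that grows linearly with the training step. A minimal sketch, assuming the weight starts at 0 and is held at 1 once 10000 steps have passed; the function name `kl_warmup_weight` is hypothetical.

```python
def kl_warmup_weight(step, warmup_steps=10_000):
    """Linear KL warmup: 0 at step 0, 1 at warmup_steps, then held at 1 (assumed schedule)."""
    return min(1.0, step / warmup_steps)
```

In a training step this would be used as `loss = -likelihood + kl_warmup_weight(step) * kl_div`. Early in training the reconstruction term dominates, so the encoder learns informative latents before the KL term starts pulling the posterior toward the prior.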

Generate

To generate new images from the prior distribution, pick a prior temperature (z_temp) and a decoder temperature (x_temp), then call: pictures = model.generate(z_temp=1., x_temp=0.3)

z_temp: float, the temperature multiplier applied to the prior stddev when sampling z. A smaller z_temp makes the generated samples less diverse and more generic.

x_temp: float, the temperature multiplier applied to the decoder stddev. A smaller x_temp makes the generated samples smoother, at the cost of losing a small amount of detail.
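Temperature-controlled generation can be sketched as two Gaussian samples: z from the N(0, 1) prior with its stddev multiplied by z_temp, then x from the decoder distribution with its stddev multiplied by x_temp. This NumPy sketch is not the repository's `model.generate`; `decode` is a hypothetical callable that maps a batch of z to the decoder mean and stddev, and the `seed` argument is added only for reproducibility.

```python
import numpy as np


def generate(decode, num_images, latent_dim, z_temp=1.0, x_temp=0.3, seed=0):
    """Sample images from the prior with temperature-scaled stddevs (illustrative)."""
    rng = np.random.default_rng(seed)
    # z ~ N(0, z_temp^2): the prior stddev is scaled by z_temp
    z = z_temp * rng.standard_normal((num_images, latent_dim))
    mean, std = decode(z)
    # x ~ N(mean, (x_temp * std)^2): the decoder stddev is scaled by x_temp
    return mean + x_temp * std * rng.standard_normal(mean.shape)
```

With x_temp=0 the sketch returns the decoder mean, which is the smoothest possible output because no per-pixel noise is added. With z_temp=0 every image is decoded from z = 0, the mode of the prior, so all samples come from the same latent point.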