
An implementation of VL-VAE in PyTorch.

Primary language: Python. License: MIT.

VL-VAE

VL-VAE is a transformer-based VAE architecture that supports progressive decoding through variable-length latent embeddings.
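As a rough illustration of what progressive decoding means here, the sketch below decodes successively longer prefixes of a single latent sequence and yields one output per prefix length. It assumes that the latents form an ordered sequence and that "variable-length" means keeping a prefix of it. `progressive_decode` and the toy `decode` callable are hypothetical names used only for illustration, not functions from this repository; in practice `decode` would be the trained transformer decoder.

```python
from typing import Callable, Iterator, Sequence, Tuple, TypeVar

Latent = TypeVar("Latent")
Output = TypeVar("Output")


def progressive_decode(
    latents: Sequence[Latent],
    decode: Callable[[Sequence[Latent]], Output],
    lengths: Sequence[int],
) -> Iterator[Tuple[int, Output]]:
    """Decode successively longer prefixes of one latent sequence.

    `decode` is a hypothetical stand-in for the trained decoder. A decoder
    trained on randomly truncated latents is assumed to accept any prefix
    length between 1 and len(latents).
    """
    for k in lengths:
        if not 1 <= k <= len(latents):
            raise ValueError(f"prefix length {k} is out of range")
        # Only the first k latent embeddings are visible to the decoder.
        yield k, decode(latents[:k])


if __name__ == "__main__":
    # Toy example: "decoding" is just summing the visible latents.
    for k, out in progressive_decode([1, 2, 3, 4], sum, [1, 2, 4]):
        print(k, out)
```

With a real model, each yielded output would be a reconstruction that is expected to gain detail as `k` grows, which is what the example videos depict.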

Examples

Progressive decoding examples on CelebA-HQ at 256x256 resolution.

(Example videos: out2.mp4, out5.mp4, out6.mp4, out12.mp4)

Architecture

VL-VAE uses a straightforward architecture consisting of two headless transformers that implement the encoder and decoder networks, respectively. Unlike conventional autoencoders, the architecture does not necessarily include downsampling layers. Instead, compression is enforced by randomly truncating the encoder's output (i.e., the latent embeddings) during training, so the decoder must learn to reconstruct the input from a shortened latent sequence. We sample truncation lengths according to an exponential distribution.
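A minimal sketch of the training-time truncation step, assuming the latents are an ordered sequence and truncation keeps a prefix. The mean length, the shift by one (so at least one latent always survives), and the clamp at the full sequence length are illustrative assumptions, not details confirmed by this repository; `sample_truncation_length` and `truncate_latents` are hypothetical helpers. Note that Python's `random.expovariate` takes the rate λ = 1/mean rather than the mean itself.

```python
import random
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def sample_truncation_length(num_latents: int, mean_length: float, rng: random.Random) -> int:
    """Draw a prefix length in [1, num_latents] from an exponential distribution.

    Assumptions: the distribution is parameterized by its mean, the draw is
    floored and shifted by one so at least one latent is kept, and the result
    is clamped to the number of available latents.
    """
    if num_latents < 1:
        raise ValueError("num_latents must be >= 1")
    if mean_length <= 0:
        raise ValueError("mean_length must be positive")
    draw = rng.expovariate(1.0 / mean_length)  # rate = 1 / mean, draw >= 0
    length = 1 + int(draw)  # int() floors a non-negative float
    return min(length, num_latents)


def truncate_latents(latents: Sequence[T], length: int) -> List[T]:
    """Keep only the first `length` latent embeddings."""
    return list(latents[:length])


if __name__ == "__main__":
    rng = random.Random(0)
    latents = list(range(256))  # stand-in for 256 latent embeddings
    n = sample_truncation_length(len(latents), mean_length=32.0, rng=rng)
    print(n, len(truncate_latents(latents, n)))
```

Because the exponential distribution puts most of its mass on short lengths, short prefixes are seen often during training while long ones still occur occasionally. In a PyTorch training loop, and assuming a `(batch, tokens, dim)` layout, the sampled length would typically be applied as a slice such as `z[:, :n]` before the latents reach the decoder; whether the length is drawn per batch or per sample is not specified here.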

TODO

  • Experiment with alternative attention mechanisms (e.g., NAT, axial).
  • Experiment with alternative positional embedding methods.
  • Experiment with alternative patch embeddings.
  • Scale up to 1024x1024 resolution.