VQ-VAE

Minimalist implementation of VQ-VAE in PyTorch

Primary language: Python. License: BSD 3-Clause "New" or "Revised" License (BSD-3-Clause).

CVAE and VQ-VAE

This is an implementation of the VQ-VAE (Vector Quantized Variational Autoencoder), from the paper Neural Discrete Representation Learning, and of a Convolutional Variational Autoencoder (CVAE), applied to compressing MNIST and CIFAR10. The code is based on pytorch/examples/vae.
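For the CVAE side, the standard VAE objective (the one used in pytorch/examples/vae) is a reconstruction term plus the closed-form KL divergence between the encoder's diagonal Gaussian and a standard normal prior, with latents sampled via the reparameterization trick. The NumPy sketch below illustrates those pieces under that assumption. It is not this repository's PyTorch code, the repository's CVAE loss may differ, the helper names are hypothetical, and the BCE term assumes pixel values scaled to [0, 1].

```python
import numpy as np


def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, I) and sigma = exp(logvar / 2)."""
    std = np.exp(0.5 * logvar)
    eps = rng.standard_normal(mu.shape)
    return mu + std * eps


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over all elements."""
    return 0.5 * np.sum(mu ** 2 + np.exp(logvar) - 1.0 - logvar)


def binary_cross_entropy(recon, target, eps=1e-12):
    """Summed BCE between reconstructed pixel probabilities and targets in [0, 1]."""
    recon = np.clip(recon, eps, 1.0 - eps)  # avoid log(0)
    return -np.sum(target * np.log(recon) + (1.0 - target) * np.log(1.0 - recon))


def vae_loss(recon, target, mu, logvar):
    # Reconstruction term plus KL regularizer (the usual VAE objective).
    return binary_cross_entropy(recon, target) + kl_to_standard_normal(mu, logvar)


rng = np.random.default_rng(0)
mu = np.zeros((4, 20))      # batch of 4, hypothetical 20-dimensional latent
logvar = np.zeros((4, 20))
z = reparameterize(mu, logvar, rng)
print(z.shape)                          # (4, 20)
print(kl_to_standard_normal(mu, logvar))  # 0.0: posterior equals the prior
```

The KL term is zero exactly when the encoder outputs the prior (mu = 0, logvar = 0) and grows as the posterior moves away from it.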

pip install -r requirements.txt
python main.py

Requirements

  • Python 3.6 (3.5 may also work)
  • PyTorch 0.4
  • Additional requirements in requirements.txt

Usage

# For example
python3 main.py --dataset=cifar10 --model=vqvae --data-dir=~/.datasets --epochs=3

Results

All images are taken from the test set. Top row is the original image. Bottom row is the reconstruction.
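A figure with this layout can be assembled by placing the originals side by side in one row and their reconstructions in the row beneath. pytorch/examples/vae does this with torchvision's `save_image` and `nrow=n`; the NumPy sketch below shows the same idea with a hypothetical helper, `make_comparison_grid`, which is not part of this repository. Random arrays stand in for real test images.

```python
import numpy as np


def make_comparison_grid(originals, reconstructions):
    """Tile n originals (top row) above their n reconstructions (bottom row).

    originals, reconstructions: arrays of shape (n, H, W) with matching shapes.
    Returns an array of shape (2 * H, n * W).
    """
    if originals.shape != reconstructions.shape:
        raise ValueError("originals and reconstructions must have the same shape")
    top = np.concatenate(list(originals), axis=1)           # (H, n * W)
    bottom = np.concatenate(list(reconstructions), axis=1)  # (H, n * W)
    return np.concatenate([top, bottom], axis=0)            # (2 * H, n * W)


# Usage with random stand-ins for 8 MNIST-sized (28 x 28) test images.
rng = np.random.default_rng(0)
x = rng.random((8, 28, 28))
x_hat = np.clip(x + 0.05 * rng.standard_normal(x.shape), 0.0, 1.0)
grid = make_comparison_grid(x, x_hat)
print(grid.shape)  # (56, 224)
```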

k: number of elements in the dictionary (codebook). d: dimension of each dictionary element (the number of channels in the bottleneck).
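In this notation the dictionary is a k × d matrix, and, as described in Neural Discrete Representation Learning, the encoder's d-channel output is quantized by replacing the d-dimensional vector at every spatial position with its nearest dictionary element (squared Euclidean distance). The NumPy sketch below shows that lookup on a (B, d, H, W) bottleneck. It is forward-only and not this repository's PyTorch code: in an autodiff framework the quantized output is usually passed to the decoder with a straight-through estimator (in PyTorch, `z_e + (z_q - z_e).detach()`), which is omitted here. The 8 × 8 latent grid is an assumption for illustration.

```python
import numpy as np


def quantize(z_e, codebook):
    """Nearest-neighbour vector quantization.

    z_e: encoder output of shape (B, d, H, W), d = channels in the bottleneck.
    codebook: dictionary of shape (k, d).
    Returns (z_q, indices): z_q has the shape of z_e, indices has shape (B, H, W).
    """
    b, d, h, w = z_e.shape
    # Move channels last and flatten so each row is one d-dimensional vector.
    flat = z_e.transpose(0, 2, 3, 1).reshape(-1, d)  # (B*H*W, d)
    # Squared distances ||z - e_j||^2 = ||z||^2 - 2 z.e_j + ||e_j||^2.
    dists = (
        np.sum(flat ** 2, axis=1, keepdims=True)
        - 2.0 * flat @ codebook.T
        + np.sum(codebook ** 2, axis=1)
    )  # (B*H*W, k)
    indices = np.argmin(dists, axis=1)  # index of the nearest dictionary element
    z_q = codebook[indices].reshape(b, h, w, d).transpose(0, 3, 1, 2)
    return z_q, indices.reshape(b, h, w)


rng = np.random.default_rng(0)
k, d = 128, 256                           # the CIFAR10 setting listed below
codebook = rng.standard_normal((k, d))
z_e = rng.standard_normal((2, d, 8, 8))   # hypothetical 8 x 8 latent grid
z_q, idx = quantize(z_e, codebook)
print(z_q.shape, idx.shape)               # (2, 256, 8, 8) (2, 8, 8)
```

Only the (B, H, W) grid of indices needs to be stored; the decoder recovers z_q by looking the indices up in the shared dictionary.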

  • MNIST (k=10, d=64)

[Image: MNIST test-set originals and reconstructions]

  • CIFAR10 (k=128, d=256)

[Image: CIFAR10 test-set originals and reconstructions]

  • ImageNet (k=512, d=128)

[Image: ImageNet test-set originals and reconstructions]
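One way to read k in these settings is as a storage cost: each latent position holds a single dictionary index, which takes ceil(log2 k) bits regardless of d, since the dictionary itself is shared. The sketch below computes this for the three configurations above. The latent grid size is not stated in this README, so the 8 × 8 grid in the last line is purely hypothetical, and raw images are assumed to use 8 bits per channel.

```python
import math


def bits_per_index(k):
    """Bits needed to store one index into a dictionary of k elements."""
    return math.ceil(math.log2(k))


def compressed_bits(k, latent_h, latent_w):
    """Bits for one image's grid of indices (the shared dictionary is not counted)."""
    return latent_h * latent_w * bits_per_index(k)


def raw_bits(height, width, channels, bits_per_channel=8):
    """Bits for an uncompressed image at the given bit depth."""
    return height * width * channels * bits_per_channel


for name, k in [("MNIST", 10), ("CIFAR10", 128), ("ImageNet", 512)]:
    print(name, k, bits_per_index(k))
# MNIST 10 4
# CIFAR10 128 7
# ImageNet 512 9

# Hypothetical: a 32 x 32 RGB CIFAR10 image encoded as an 8 x 8 grid of indices.
print(raw_bits(32, 32, 3), compressed_bits(128, 8, 8))  # 24576 448
```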

TODO:

Acknowledgement

Thanks to tf-vaevae for serving as a good reference.