A PyTorch implementation of Continuous Relaxation Training of Discrete Latent Variable Image Models.
Ensure you have Python 3.7 and PyTorch 1.2 or greater.
To train the VQVAE model with 8 categorical dimensions and 128 codes per dimension, run the following command:
python train.py --model=VQVAE --latent-dim=8 --num-embeddings=128
To train the GS-Soft model, use --model=GSSOFT.
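To make the two options concrete, here is a minimal sketch of how a hard (VQ-VAE style) and a soft (Gumbel-Softmax style) latent could be formed with 8 codebooks of 128 codes each. It is an illustration only, not the code in train.py: the embedding size `D`, the batch size `B`, the temperature `tau=0.5`, and the random tensors standing in for encoder outputs are all assumptions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# N codebooks ("categorical dimensions") with M codes each, as in the command above.
# D (code size) and B (batch size) are assumed values for illustration.
N, M, D, B = 8, 128, 32, 4

codebook = torch.randn(N, M, D)  # one (M, D) codebook per latent dimension

# Hard assignment: replace each encoder vector with its nearest code.
z_e = torch.randn(B, N, D)  # hypothetical encoder output
# dist[b, n, m] = squared distance between z_e[b, n] and codebook[n, m]
dist = ((z_e.unsqueeze(2) - codebook.unsqueeze(0)) ** 2).sum(dim=-1)
indices = dist.argmin(dim=-1)             # (B, N) code index per latent dimension
z_q = codebook[torch.arange(N), indices]  # (B, N, D) quantized latents

# Soft assignment: sample relaxed one-hot weights and mix the codes.
logits = torch.randn(B, N, M)  # hypothetical encoder logits over the M codes
weights = F.gumbel_softmax(logits, tau=0.5, hard=False, dim=-1)  # (B, N, M)
z_soft = torch.einsum("bnm,nmd->bnd", weights, codebook)         # (B, N, D)
```

In the hard case each latent dimension selects exactly one code, while in the soft case it receives a convex combination of all 128 codes whose weights sharpen as the temperature decreases.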
Pretrained weights for the VQVAE and GS-Soft models can be found here.
The VQVAE model gets ~4.82 bits per dimension (bpd) while the GS-Soft model gets ~4.6 bpd.
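For reference, bits per dimension is a negative log-likelihood (or a bound on it) divided by the number of data dimensions and converted from nats to bits. The sketch below shows only that conversion; the helper name `nats_to_bpd` and the 32x32 RGB image size are assumptions for illustration and are not taken from this repository.

```python
import math


def nats_to_bpd(nll_nats, num_dims):
    """Convert a total negative log-likelihood in nats to bits per dimension."""
    return nll_nats / (num_dims * math.log(2))


# Assumed example: a 32x32 RGB image has 32 * 32 * 3 = 3072 dimensions.
num_dims = 32 * 32 * 3

# Build the per-image loss in nats that corresponds to 4.82 bpd, then convert back.
nll = 4.82 * num_dims * math.log(2)
bpd = nats_to_bpd(nll, num_dims)
```

Lower is better: a model at 4.6 bpd spends fewer bits on each pixel channel than one at 4.82 bpd.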
As demonstrated in the paper, the codebook matrices are low-dimensional: the codes span only a few dimensions of the embedding space.
Projecting the codes onto the first 3 principal components shows that the codes typically tile continuous 1- or 2-D manifolds.
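A projection of this kind can be sketched with a plain SVD-based PCA. The example below uses a synthetic codebook that lies close to a 2-D subspace as a stand-in for trained weights; the shape `(128, 32)` and the noise level are assumptions, and nothing here loads the pretrained models.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for one trained codebook: M codes of size D (assumed shape)
# that lie close to a 2-D subspace, plus a little noise.
M, D = 128, 32
basis = rng.normal(size=(2, D))
codebook = rng.normal(size=(M, 2)) @ basis + 0.01 * rng.normal(size=(M, D))

# PCA via SVD of the mean-centered codebook.
centered = codebook - codebook.mean(axis=0)
U, S, Vt = np.linalg.svd(centered, full_matrices=False)

# Fraction of variance carried by each principal component.
explained = S**2 / np.sum(S**2)

# Coordinates of every code along the first 3 principal components.
projected = centered @ Vt[:3].T  # shape (M, 3)
```

If a real codebook is low-dimensional in the sense described above, `explained` should be dominated by its first few entries, and a 3-D scatter plot of `projected` would show the curve or surface that the codes tile.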