vqvae-pytorch

Van den Oord, Aaron, Oriol Vinyals, and Koray Kavukcuoglu. "Neural discrete representation learning." Advances in Neural Information Processing Systems 30 (2017).

Unofficial PyTorch implementation of VQVAE.


Installation

Clone this repo:

git clone https://github.com/xyfJASON/vqvae-pytorch.git
cd vqvae-pytorch

Create and activate a conda environment:

conda create -n vqvae python=3.11
conda activate vqvae

Install dependencies:

pip install -r requirements.txt

VQ Model

Training

accelerate-launch scripts/train_vqvae.py -c CONFIG [-e EXP_DIR] [--xxx.yyy zzz ...]
  • This repo uses the 🤗 Accelerate library for multi-GPU and mixed-precision (fp16) support. Please refer to its documentation for how to launch the scripts on different platforms.
  • Results (logs, checkpoints, tensorboard, etc.) of each run will be saved to EXP_DIR. If EXP_DIR is not specified, they will be saved to runs/exp-{current time}/.
  • To modify configuration items without creating a new configuration file, pass --key value pairs to the script (see the example below).

For example, to train a VQVAE on CelebA with the default configuration:

accelerate-launch scripts/train_vqvae.py -c ./configs/vqvae-celeba.yaml -e ./runs/vqvae-celeba
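Internally, the dotted --key value pairs are merged on top of the YAML file before training starts. A minimal sketch of how such merging can work, assuming an OmegaConf-style config system (the repo's actual config handling may differ; the keys train.batch_size and model.codebook_size below are purely illustrative):

from omegaconf import OmegaConf

# Load the base YAML config, e.g. ./configs/vqvae-celeba.yaml.
base = OmegaConf.load("configs/vqvae-celeba.yaml")

# "--key value" pairs on the command line correspond to a dotlist of
# "key=value" overrides; dotted keys become nested config entries.
overrides = OmegaConf.from_dotlist(["train.batch_size=128", "model.codebook_size=1024"])

# Later configs win, so CLI overrides take precedence over the YAML file.
config = OmegaConf.merge(base, overrides)
print(OmegaConf.to_yaml(config))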

Evaluation

accelerate-launch scripts/evaluate_vqmodel.py -c CONFIG \
                                              --weights WEIGHTS \
                                              [--bspp BSPP] \
                                              [--save_dir SAVE_DIR]
  • -c: path to the configuration file
  • --weights: path to the model weights
  • --bspp: batch size per process
  • --save_dir: directory to save the reconstructed images
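The evaluation metrics themselves are standard. A minimal sketch of PSNR and codebook usage in PyTorch, assuming images scaled to [0, 1] and a tensor of code indices gathered over the whole dataset (illustrative, not this repo's exact code):

import torch

def psnr(x: torch.Tensor, x_rec: torch.Tensor, max_val: float = 1.0) -> torch.Tensor:
    # PSNR = 10 * log10(MAX^2 / MSE), computed per image and averaged over the batch.
    mse = torch.mean((x - x_rec) ** 2, dim=(1, 2, 3))  # (B, C, H, W) -> (B,)
    return (10 * torch.log10(max_val ** 2 / mse)).mean()

def codebook_usage(all_indices: torch.Tensor, codebook_size: int) -> float:
    # Fraction of codebook entries hit at least once across the dataset.
    used = torch.unique(all_indices).numel()
    return used / codebook_size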

Results

CelebA (64×64):

Model            Codebook usage  PSNR     SSIM    rFID
VQVAE (VQ loss)  56.45%          31.5486  0.9389  16.8227
VQVAE (EMA)      100%            32.0708  0.9459  15.5629
[Reconstruction samples: VQVAE (VQ loss) | VQVAE (EMA)]

The EMA codebook achieves both higher codebook usage and better reconstruction quality than the VQ-loss codebook.
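The two variants differ only in how the codebook is updated: with the VQ loss, the codebook is a learnable parameter trained by the ||sg[z_e] - e||^2 term; with EMA, each code is an exponential moving average of the encoder outputs assigned to it and receives no gradient. A minimal sketch of both variants (illustrative shapes and names, not this repo's code):

import torch
import torch.nn.functional as F

class VectorQuantizer(torch.nn.Module):
    def __init__(self, num_codes=512, dim=64, use_ema=False, decay=0.99, eps=1e-5, beta=0.25):
        super().__init__()
        self.use_ema, self.decay, self.eps, self.beta = use_ema, decay, eps, beta
        embed = torch.randn(num_codes, dim)
        if use_ema:
            # EMA codebook lives in buffers: updated from running statistics, not gradients.
            self.register_buffer("embed", embed)
            self.register_buffer("cluster_size", torch.zeros(num_codes))
            self.register_buffer("embed_sum", embed.clone())
        else:
            self.embed = torch.nn.Parameter(embed)

    def forward(self, z_e):
        # z_e: (B, N, dim) encoder outputs, flattened over spatial positions.
        flat = z_e.reshape(-1, z_e.shape[-1])
        idx = torch.cdist(flat, self.embed).argmin(dim=1)   # nearest code per vector
        z_q = self.embed[idx].reshape_as(z_e)

        commit = F.mse_loss(z_e, z_q.detach())              # pulls encoder toward codes
        if self.use_ema:
            if self.training:
                with torch.no_grad():
                    one_hot = F.one_hot(idx, self.embed.shape[0]).to(flat.dtype)
                    # EMA of per-code counts and of summed assigned encoder outputs.
                    self.cluster_size.mul_(self.decay).add_(one_hot.sum(0), alpha=1 - self.decay)
                    self.embed_sum.mul_(self.decay).add_(one_hot.t() @ flat, alpha=1 - self.decay)
                    # Laplace smoothing keeps rarely used codes from dividing by zero.
                    n = self.cluster_size.sum()
                    size = (self.cluster_size + self.eps) / (n + self.embed.shape[0] * self.eps) * n
                    self.embed.copy_(self.embed_sum / size.unsqueeze(1))
            loss = self.beta * commit                       # no codebook-gradient term
        else:
            # VQ loss: the codebook is a Parameter trained by ||sg[z_e] - e||^2.
            loss = F.mse_loss(z_q, z_e.detach()) + self.beta * commit
        z_q = z_e + (z_q - z_e).detach()                    # straight-through estimator
        return z_q, loss, idx

Note that the EMA update runs under torch.no_grad(), so only the commitment term contributes to the loss in that variant.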


Prior Model

Training

accelerate-launch scripts/train_prior_transformer.py -c CONFIG [-e EXP_DIR] --vqmodel.pretrained /path/to/vqmodel/checkpoint [--xxx.yyy zzz ...]
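The prior is an autoregressive model over the discrete code indices produced by the pretrained, frozen VQ model referenced by --vqmodel.pretrained. A minimal sketch of one training step with teacher forcing (encode_to_indices and the transformer interface are hypothetical, not this repo's API):

import torch
import torch.nn.functional as F

def prior_step(transformer, vqmodel, images, optimizer):
    with torch.no_grad():
        # Map images to a flat sequence of code indices with the frozen VQ model.
        indices = vqmodel.encode_to_indices(images)      # (B, L), hypothetical API
    logits = transformer(indices[:, :-1])                # predict the next index
    loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]),
                           indices[:, 1:].reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()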

Sampling

accelerate-launch scripts/sample.py -c CONFIG \
                                    --vqmodel_weights VQMODEL_WEIGHTS \
                                    --prior_weights PRIOR_WEIGHTS \
                                    --n_samples N_SAMPLES \
                                    --save_dir SAVE_DIR \
                                    [--bspp BSPP] \
                                    [--topk TOPK]
  • -c: path to the configuration file
  • --vqmodel_weights: path to the VQ model weights
  • --prior_weights: path to the prior model weights
  • --n_samples: number of samples to generate
  • --save_dir: directory to save the samples
  • --bspp: batch size per process
  • --topk: keep only the k most probable codes at each sampling step (top-k sampling; see the sketch below)
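Top-k sampling truncates the predicted distribution to the k highest-probability codes at each autoregressive step before sampling, trading some diversity for sample quality. A minimal sketch of the filtering step (a standard technique, not necessarily this repo's exact implementation):

import torch

def sample_top_k(logits: torch.Tensor, k: int) -> torch.Tensor:
    # logits: (B, vocab_size) for the next code index.
    topk_vals, _ = torch.topk(logits, k, dim=-1)
    # Mask out everything below the k-th largest logit.
    logits = logits.masked_fill(logits < topk_vals[..., -1:], float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)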

Results

CelebA (64×64):

Model                          FID
Transformer + VQVAE (VQ loss)  23.7055
Transformer + VQVAE (EMA)      25.5597
[Samples: Transformer + VQVAE (VQ loss) | Transformer + VQVAE (EMA)]