
FSQ-pytorch (Finite Scalar Quantization https://arxiv.org/abs/2309.15505)

An unofficial PyTorch implementation of Finite Scalar Quantization (https://arxiv.org/abs/2309.15505)


In our view, FSQ is a great idea, and we were able to implement a reproduction quickly in a minimal framework. We were impressed that FSQ is not only simple and effective in concept but also easy to optimize in actual training.

Experimental settings

We use the ImageNet dataset at 128×128 resolution for our experiments, with a downsampling factor of 8. The encoder is a simple neural network with four convolutional layers, and the decoder mirrors the encoder. This architecture closely follows that of CogView's VQ-VAE. The implementation of the FSQ quantizer is mainly adapted from another GitHub repository.
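For readers new to FSQ, the quantization step itself is short. The following is a minimal PyTorch sketch of the scheme described in the FSQ paper: each latent channel is bounded with tanh, rounded to one of L_i integer levels with a straight-through gradient, and rescaled into [-1, 1]. `FSQSketch`, `round_ste` and `codes_to_indices` are hypothetical names; this is not the quantizer code used in this repository, and the index ordering (first channel least significant) is an arbitrary choice.

```python
import torch
import torch.nn as nn


def round_ste(z: torch.Tensor) -> torch.Tensor:
    """Round to the nearest integer; the backward pass treats rounding as identity."""
    return z + (torch.round(z) - z).detach()


class FSQSketch(nn.Module):
    """Hypothetical minimal FSQ quantizer, not this repository's class."""

    def __init__(self, levels):
        super().__init__()
        # One entry per latent channel, e.g. [8, 8, 8, 5, 5, 5].
        self.register_buffer("levels", torch.tensor(levels, dtype=torch.float32))
        # Mixed-radix place values: the first channel is the least significant digit.
        basis = torch.cumprod(torch.tensor([1] + list(levels[:-1])), dim=0)
        self.register_buffer("basis", basis)

    def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
        # Squash each channel so that rounding yields exactly L_i distinct integers.
        half_l = (self.levels - 1) * (1 - eps) / 2
        # Even L_i needs a half-step offset to reach the integers -L_i/2 .. L_i/2 - 1.
        offset = (1 - self.levels % 2) * 0.5
        # The shift makes bound(0) == 0 for both odd and even L_i.
        shift = torch.atanh(offset / half_l)
        return torch.tanh(z + shift) * half_l - offset

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z: (..., d) with d == len(levels); the output lies on a fixed grid inside [-1, 1].
        quantized = round_ste(self.bound(z))
        half_width = torch.floor(self.levels / 2)
        return quantized / half_width

    def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
        # Undo the scaling to get per-channel digits in 0 .. L_i - 1, then flatten them.
        half_width = torch.floor(self.levels / 2)
        digits = (zhat * half_width + half_width).round().long()
        return (digits * self.basis).sum(dim=-1)


fsq = FSQSketch([8, 5, 5, 5])         # the "1k" setting: 8 * 5 * 5 * 5 = 1000 codes
z = torch.randn(2, 16 * 16, 4)        # e.g. two flattened 16x16 latent grids
zhat = fsq(z)
indices = fsq.codes_to_indices(zhat)  # integers in [0, 1000)
```

The implicit codebook is the grid of all level combinations, so the quantizer has no learnable codebook to update; gradients reach the encoder through the straight-through rounding.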

Training

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node 8 train.py --quantizer fsq --levels 8 8 8 5 5 5
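The `--levels` flag fully determines the quantizer: the implicit codebook size is the product of the levels, and the quantized latent has one channel per entry in the list. The plain-Python sketch below mirrors the two flags from the command above with a hypothetical argparse parser (the actual train.py may parse them differently) and also computes the number of latent tokens per image, assuming 128×128 inputs and a downsampling factor of 8.

```python
import argparse
import math

# Hypothetical parser that mirrors the flags in the training command;
# the real train.py may define them differently.
parser = argparse.ArgumentParser()
parser.add_argument("--quantizer", type=str, default="fsq")
parser.add_argument("--levels", type=int, nargs="+", default=[8, 8, 8, 5, 5, 5])


def summarize(levels, image_size=128, downsample=8):
    """Return (codebook size, latent channels, tokens per image) for an FSQ config."""
    codebook_size = math.prod(levels)  # implicit codebook = product of the levels
    latent_channels = len(levels)      # FSQ quantizes a d-dim latent, d = len(levels)
    side = image_size // downsample    # 128 // 8 = 16
    return codebook_size, latent_channels, side * side


args = parser.parse_args(["--quantizer", "fsq", "--levels", "8", "8", "8", "5", "5", "5"])
print(summarize(args.levels))  # (64000, 6, 256)
```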

The levels can also take on other values, as shown in the table below.

[Table: alternative level settings]
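Because any list of levels defines a valid codebook, converting between a flat codebook index and the per-channel code works the same way for every setting: the index is the digit string read in mixed radix. A stdlib sketch follows, assuming the first channel is the least significant digit (an arbitrary choice; other implementations may order channels differently). The function names are hypothetical.

```python
def index_to_digits(index, levels):
    """Split a flat codebook index into one digit per channel, digit i in [0, levels[i])."""
    digits = []
    for num_levels in levels:
        index, digit = divmod(index, num_levels)
        digits.append(digit)
    return digits


def digits_to_index(digits, levels):
    """Inverse of index_to_digits: read the digits as a mixed-radix number."""
    index, base = 0, 1
    for digit, num_levels in zip(digits, levels):
        index += digit * base
        base *= num_levels
    return index


def digits_to_values(digits, levels):
    """Map each digit to its quantized value, scaled to lie within [-1, 1]."""
    return [(d - L // 2) / (L // 2) for d, L in zip(digits, levels)]


levels = [7, 5, 5, 5, 5]  # the 4k setting
print(index_to_digits(4374, levels))              # [6, 4, 4, 4, 4]
print(digits_to_values([6, 4, 4, 4, 4], levels))  # [1.0, 1.0, 1.0, 1.0, 1.0]
```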

Quantitative Results

We evaluate several metrics on the validation set, and the results are shown in the table below.

Codebook Size   L1 loss   Perceptual loss   Codebook Usage   CKPT   Levels
1k              0.2319    0.2597            100%             CKPT   8 5 5 5
4k              0.2135    0.2299            100%             CKPT   7 5 5 5 5
16k             0.1917    0.1931            100%             CKPT   8 8 8 6 5
64k             0.1807    0.1761            99.94%           CKPT   8 8 8 5 5 5
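The codebook sizes in the first column are rounded labels: each exact size is the product of its levels, so the 16k setting has 8·8·8·6·5 = 15,360 entries. The stdlib sketch below computes the exact sizes and a codebook-usage figure, assuming "Codebook Usage" means the fraction of codebook entries selected at least once over the validation set (the README does not define it). `CONFIGS` and `codebook_usage` are hypothetical names.

```python
import math

# Level settings copied from the table above; "1k" .. "64k" are rounded labels.
CONFIGS = {
    "1k": [8, 5, 5, 5],
    "4k": [7, 5, 5, 5, 5],
    "16k": [8, 8, 8, 6, 5],
    "64k": [8, 8, 8, 5, 5, 5],
}


def codebook_usage(indices, codebook_size):
    """Share of codebook entries hit at least once (assumed definition of the metric)."""
    return len(set(indices)) / codebook_size


for name, levels in CONFIGS.items():
    print(name, math.prod(levels))
# 1k 1000
# 4k 4375
# 16k 15360
# 64k 64000
```

In practice `indices` would be the flattened code indices collected from every validation image.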

Qualitative Results

Comparison of input images and reconstructed images. The images come from the validation set and are not cherry-picked.


Acknowledgement

Our code draws heavily on the first stage (VQ-VAE training) of CogView2 and on vector-quantize-pytorch, and we thank these teams for generously sharing their work. We also thank Wendi Zheng and Ming Ding for their very constructive suggestions.