PyTorch implementation of OpenAI's Image GPT, trained on MNIST and Fashion-MNIST

Image GPT

PyTorch implementation of Image GPT, based on the paper Generative Pretraining from Pixels (Chen et al.) and its accompanying code.


Model-generated completions of half-images from the test set. The first column is the input; the last column is the original image.

Differences from original paper:

  • Uses 4-bit grayscale images instead of 9-bit RGB
  • 28x28 images are used instead of 32x32
  • Quantization is done naively using division, not k-means clustering
  • Model is much smaller and can be trained with much less compute

According to their blog post, the largest model, iGPT-L (1.4 B parameters), was trained for 2500 V100-days. By greatly reducing the number of attention heads, the number of layers, and the input size (which affects the cost of self-attention quadratically), we can train our own model (26 K parameters) on Fashion-MNIST on a single NVIDIA 2070 in less than 2 hours.
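As a rough sanity check on these sizes, the sketch below applies the common rule of thumb of about 12 * num_layers * embed_dim^2 parameters for the transformer blocks, using the values from the hyperparameter table in this README. It assumes a GPT-2 style block with a 4x-wide MLP and ignores embeddings, biases, and layer norms, so it is an estimate rather than the exact count of this repository's model. The helper names are made up for illustration.

```python
def approx_block_params(num_layers, embed_dim):
    # Assumption: 4*d^2 for the attention projections plus 8*d^2 for a
    # 4x-wide MLP per layer; embeddings, biases and layer norms are ignored.
    return 12 * num_layers * embed_dim ** 2


def attention_pairs(num_pixels):
    # One token per pixel, so the sequence length is num_pixels squared,
    # and full self-attention compares every token with every token.
    seq_len = num_pixels ** 2
    return seq_len ** 2


configs = [
    ("default (this repo)", 8, 16, 28),
    ("iGPT-S (Chen et al.)", 24, 512, 32),
]
for name, num_layers, embed_dim, num_pixels in configs:
    params = approx_block_params(num_layers, embed_dim)
    pairs = attention_pairs(num_pixels)
    print(f"{name}: ~{params:,} block parameters, {pairs:,} attention pairs per head")
```

The default configuration comes out at roughly 25 K block parameters, in the same range as the 26 K quoted above.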

Usage

Pre-trained Models

Pre-trained models are located in the models directory.

Prepare Data

To download and prepare data, run src/prepare_data.py. Omitting the --fashion argument will download standard MNIST instead. Images are downloaded and encoded with a 4-bit grayscale palette.

python src/prepare_data.py --fashion
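For illustration, quantization "by division" can be sketched as below. This is a minimal example under the assumption that the inputs are 8-bit grayscale arrays; the function names are hypothetical and the actual code in src/prepare_data.py may differ.

```python
import numpy as np


def quantize_4bit(image_u8):
    """Map 8-bit grayscale values (0..255) to 16 levels (0..15)."""
    # Integer division by 16 drops the four least significant bits.
    return (image_u8 // 16).astype(np.uint8)


def dequantize_4bit(tokens):
    """Spread the 16 levels back over 0..255 for display."""
    # 15 * 17 == 255, so the brightest level maps back to pure white.
    return (tokens * 17).astype(np.uint8)


pixels = np.array([[0, 15, 16, 255]], dtype=np.uint8)
tokens = quantize_4bit(pixels)
print(tokens)
print(dequantize_4bit(tokens))
```

Each pixel then becomes one of 16 tokens, which is where the default --num_vocab of 16 comes from.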

Training

Models can be trained using src/run.py with the train subcommand.

Generative Pre-training

python src/run.py train --name fmnist_gen

The following hyperparameters can also be provided. The smallest model from the paper is shown for comparison.

Argument          Default   iGPT-S (Chen et al.)
--embed_dim       16        512
--num_heads       2         8
--num_layers      8         24
--num_pixels      28        32
--num_vocab       16        512
--batch_size      64        128
--learning_rate   0.01      0.01
--steps           25000     1000000
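Generative pre-training optimizes a next-pixel prediction objective: the model sees the pixels so far and is scored on the token that follows. The sketch below computes that cross-entropy loss with NumPy for a single flattened image. It is an illustration, not the code in src/run.py; the function name is hypothetical, and for simplicity it skips the first pixel rather than prepending a start-of-sequence token.

```python
import numpy as np


def next_pixel_loss(logits, tokens):
    """Mean cross-entropy of predicting tokens[t + 1] from logits[t].

    logits: float array of shape (seq_len, num_vocab)
    tokens: int array of shape (seq_len,)
    """
    logits = logits[:-1]   # the last position has no next pixel to predict
    targets = tokens[1:]   # the first pixel has no context in this sketch
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()


rng = np.random.default_rng(0)
tokens = rng.integers(0, 16, size=28 * 28)   # one quantized 28x28 image
uniform_logits = np.zeros((28 * 28, 16))     # an untrained, uniform guess
print(next_pixel_loss(uniform_logits, tokens))
```

A model that guesses uniformly over 16 tokens scores ln(16), about 2.77 nats per pixel, which is a useful baseline when reading training curves.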

Classification Fine-tuning

Pre-trained models can be fine-tuned by passing the path to the pre-trained checkpoint to --pretrained, along with the --classify argument. I have found that a small reduction in the learning rate is necessary.

python src/run.py train \
    --name fmnist_clf \
    --pretrained models/fmnist_gen.ckpt \
    --classify \
    --learning_rate 3e-3

Sampling

Figures like those seen above can be created using random images from the test set:

# outputs to figure.png
python src/sample.py models/fmnist_gen.ckpt
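Conceptually, completing a half-image means keeping the top half of the quantized pixels as context and sampling the remaining pixels one at a time. The sketch below shows that loop with NumPy. It assumes raster-order tokens and a hypothetical model callable that returns next-token logits for a given context; it is not the code in src/sample.py, and the stand-in model simply returns uniform logits.

```python
import numpy as np


def complete_half(model, image, num_vocab=16, rng=None):
    """Keep the top half of a quantized image and sample the rest."""
    if rng is None:
        rng = np.random.default_rng()
    height, width = image.shape
    # Raster-order context: every pixel of the top half, row by row.
    tokens = [int(t) for t in image.reshape(-1)[: (height // 2) * width]]
    while len(tokens) < height * width:
        # Hypothetical interface: logits of shape (num_vocab,) for the next pixel.
        logits = model(np.array(tokens))
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()
        tokens.append(int(rng.choice(num_vocab, p=probs)))
    return np.array(tokens).reshape(height, width)


def uniform_model(context):
    # Stand-in for a trained network: every token is equally likely.
    return np.zeros(16)


half_input = np.zeros((4, 4), dtype=np.int64)
completed = complete_half(uniform_model, half_input, rng=np.random.default_rng(0))
print(completed)
```

With a trained model in place of uniform_model, the sampled bottom half would follow the learned pixel distribution instead of noise.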

GIFs like the one seen in my tweet can be made like so:

# outputs to out.gif
python src/gif.py models/fmnist_gen.ckpt