/MaskGIT-pytorch

Pytorch implementation of MaskGIT: Masked Generative Image Transformer (https://arxiv.org/pdf/2202.04200.pdf)

Primary LanguagePythonMIT LicenseMIT

MaskGIT-pytorch

Pytorch implementation of MaskGIT: Masked Generative Image Transformer (https://arxiv.org/pdf/2202.04200.pdf)

results

Note: this is work in progress + the official implementation can be found at https://github.com/google-research/maskgit

MaskGIT is an extension to the VQGAN paper which improves the second stage transformer part (and leaves the first stage untouched). It switches the unidirectional transformer for a bidirectional transformer. The (second stage) training is pretty similar to BERT by randomly masking out tokens and trying to predict these using the bidirectional transformer (the original work used a GPT architecture randomly replaced tokens by other tokens). Different from BERT, the percentage for the masking is not fixed and uniformly distributed between 0 and 1 for each batch. Furhtermore, a new inference algorithm is suggested in which we start off by a completely masked-out image and then iteratively sample vectors where the model has a high confidence.

If you are only interested in the part of the code that comes from this paper check out transformer.py.

Results

(Training on https://www.kaggle.com/arnaud58/landscape-pictures, epochs=1000, bs=100)

Note: The training only encompasses about 3% of data of what the original paper trained on. (8.000 * 1.000 / 1.000.000 * 300 = 0.026) So longer training probably results in better outcomes. See also Issue #6

Run the code

The code is ready for training both the VQGAN and the Bidirectional Transformer and can also be used for inference

python training_vqgan.py

python training_transformer.py

(Make sure to edit the path for the dataset etc.)

TODO

  • Implement the gamma functions
  • Implement functions for image editing tasks: inpainting, extrapolation, image manipulation
  • Tune hyperparameters
  • Provide visual results
  • Pre-Norm instead of Post-Norm