/dalle-lightning

Refactoring dalle-pytorch and taming-transformers for TPU VM

Text-to-Image Translation (DALL-E) for TPU in PyTorch

Refactoring Taming Transformers and DALLE-pytorch for TPU VM with PyTorch Lightning

Requirements

pip install -r requirements.txt

Data Preparation

Place any image dataset with an ImageNet-style directory structure (at least one subfolder containing images) under a root directory so that it can be loaded with PyTorch's ImageFolder.
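As a rough illustration of the expected layout, the sketch below builds a tiny ImageNet-style tree in a temporary directory and counts the image files per subfolder using only the standard library. The helper name `summarize_image_folder`, the subfolder names, and the extension list are assumptions made for this example and are not part of this repository.

```python
import tempfile
from pathlib import Path

# Assumption: a common subset of image extensions, chosen for this example.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def summarize_image_folder(root):
    """Return {subfolder_name: image_count} for an ImageFolder-style root.

    Hypothetical helper, not part of this repository.
    """
    root = Path(root)
    counts = {}
    for sub in sorted(p for p in root.iterdir() if p.is_dir()):
        # Count image files anywhere below the subfolder.
        counts[sub.name] = sum(
            1
            for f in sub.rglob("*")
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    if not counts:
        raise ValueError(f"{root} needs at least one subfolder containing images")
    return counts


# Build a toy layout: train/cats/{a.jpg,b.png} and train/dogs/c.jpeg
with tempfile.TemporaryDirectory() as tmp:
    train_dir = Path(tmp) / "train"
    for cls, names in {"cats": ["a.jpg", "b.png"], "dogs": ["c.jpeg"]}.items():
        (train_dir / cls).mkdir(parents=True)
        for name in names:
            (train_dir / cls / name).touch()  # empty placeholder files
    demo_counts = summarize_image_folder(train_dir)
    print(demo_counts)
```

A directory that passes this check is the kind of root that would be given as train_dir or val_dir.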

Training VQVAEs

You can quickly test train_vae.py with randomly generated fake data:

python train_vae.py --use_tpus --fake_data

For actual training, provide the directories for train_dir, val_dir, and log_dir:

python train_vae.py --use_tpus --train_dir [training_set] --val_dir [val_set] --log_dir [where to save results]

Training DALL-E

python train_dalle.py --use_tpus --train_dir [training_set] --val_dir [val_set] --log_dir [where to save results] --vae_path [pretrained vae] --bpe_path [pretrained bpe(optional)]
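If the training scripts are launched from another Python program, the commands above can be assembled programmatically. The following is a minimal sketch that uses only the flags shown in this README; the helper name `build_train_command` and all paths are hypothetical placeholders.

```python
import shlex


def build_train_command(script, train_dir, val_dir, log_dir,
                        vae_path=None, bpe_path=None, use_tpus=True):
    """Assemble the argument list for train_vae.py or train_dalle.py.

    Hypothetical helper; it only uses flags that appear in this README.
    """
    cmd = ["python", script]
    if use_tpus:
        cmd.append("--use_tpus")
    cmd += ["--train_dir", train_dir, "--val_dir", val_dir, "--log_dir", log_dir]
    if vae_path is not None:
        cmd += ["--vae_path", vae_path]  # pretrained VAE, used by train_dalle.py
    if bpe_path is not None:
        cmd += ["--bpe_path", bpe_path]  # optional pretrained BPE
    return cmd


# Placeholder paths, chosen only for illustration.
dalle_cmd = build_train_command(
    "train_dalle.py",
    train_dir="data/train",
    val_dir="data/val",
    log_dir="logs/dalle",
    vae_path="logs/vae/last.ckpt",
)
print(shlex.join(dalle_cmd))
```

The resulting list can be passed to `subprocess.run(dalle_cmd, check=True)` to start training.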

TODO

  • Refactor Encoder and Decoder modules for better readability
  • Refactor VQVAE2
  • Add Net2Net Conditional Transformer for conditional image generation
  • Refactor, optimize, and merge DALL-E with Net2Net Conditional Transformer
  • Add Guided Diffusion + CLIP for image refinement
  • Add VAE converter for JAX to support dalle-mini
  • Add DALL-E colab notebook
  • Add RBGumbelQuantizer
  • Add HiT

ON-GOING

  • Test large dataset loading on TPU Pods
  • Change current DALL-E code to fully support latest updates from DALLE-pytorch

DONE

  • Add VQVAE, VQGAN, Gumbel VQVAE (Discrete VAE), and Gumbel VQGAN
  • Add VQVAE2
  • Add EMA update for Vector Quantization
  • Debug VAEs (Single TPU Node, TPU Pods, GPUs)
  • Resolve SIGSEGV issue with large TPU Pods pytorch-xla #3028
  • Add DALL-E
  • Debug DALL-E (Single TPU Node, TPU Pods, GPUs)
  • Add WebDataset support
  • Add VAE Image Logger by modifying pl_bolts TensorboardGenerativeModelImageSampler()
  • Add DALLE Image Logger by modifying pl_bolts TensorboardGenerativeModelImageSampler()
  • Add automatic checkpoint saving and resuming to handle sudden TPU restarts (which happen a lot)
  • Reimplement EMA VectorQuantizer with nn.Embedding
  • Add DALL-E colab notebook by afiaka87
  • Add Normed Vector Quantizer by GallagherCommaJack
  • Resolve SIGSEGV issue with large TPU Pods pytorch-xla #3068
  • Debug WebDataset functionality

BibTeX

@misc{oord2018neural,
      title={Neural Discrete Representation Learning}, 
      author={Aaron van den Oord and Oriol Vinyals and Koray Kavukcuoglu},
      year={2018},
      eprint={1711.00937},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@misc{razavi2019generating,
      title={Generating Diverse High-Fidelity Images with VQ-VAE-2}, 
      author={Ali Razavi and Aaron van den Oord and Oriol Vinyals},
      year={2019},
      eprint={1906.00446},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@misc{esser2020taming,
      title={Taming Transformers for High-Resolution Image Synthesis}, 
      author={Patrick Esser and Robin Rombach and Björn Ommer},
      year={2020},
      eprint={2012.09841},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
@misc{ramesh2021zeroshot,
      title={Zero-Shot Text-to-Image Generation}, 
      author={Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
      year={2021},
      eprint={2102.12092},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}