Rayhane-mamah/Efficient-VDVAE

Discussion starter

This umbrella issue tracks our work's current state and discusses the priority of potential TODOs. It is also a good place to ask any questions about the work.

Goals

  • Enable very deep VAE-based models to train faster and with less compute, while applying only the simplest modifications.
  • Provide all 22 pre-trained models in both PyTorch and JAX (remaining checkpoints to be added soon).

Potential TODOs (based on need)

  • Make data loaders Out-Of-Core (OOC) in PyTorch, for RAM efficiency on large datasets (see the loader sketch after this list).
  • Make data loaders Out-Of-Core (OOC) in JAX, for RAM efficiency on large datasets.
  • Add Fréchet Inception Distance (FID) and Inception Score (IS) as measures of sample quality (see the FID sketch after this list).
  • Improve the format of the encoded dataset used in downstream tasks (the output of encoding mode), if there is a need (see the latent-format sketch after this list).
  • Write a decoding mode API (if needed).
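
For the out-of-core PyTorch loader, a minimal sketch of the idea, assuming images are stored as individual files on disk (the `OutOfCoreImageDataset` name, file layout, and transform are illustrative, not the repository's actual loader):

```python
import os

from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class OutOfCoreImageDataset(Dataset):
    """Reads one image from disk per __getitem__ so only file paths,
    never pixel data, are held in RAM (hypothetical sketch)."""

    def __init__(self, root, transform=None):
        # Only the file paths live in memory; pixels stay on disk.
        self.paths = sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.lower().endswith((".png", ".jpg", ".jpeg"))
        )
        self.transform = transform or transforms.ToTensor()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        image = Image.open(self.paths[idx]).convert("RGB")
        return self.transform(image)


# Worker processes stream batches from disk, keeping RAM usage flat
# regardless of dataset size.
loader = DataLoader(OutOfCoreImageDataset("data/train"),
                    batch_size=32, shuffle=True, num_workers=4)
```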
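
For FID, recall that it is the Fréchet distance between Gaussian fits of Inception-v3 features of real and generated images: FID = ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^{1/2}). A minimal sketch using `torchmetrics` (an assumption on our side, not a dependency of this repository; its image extras must be installed):

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance

# FID compares Inception-v3 feature statistics of real vs. generated
# images; lower is better.
fid = FrechetInceptionDistance(feature=2048)

# torchmetrics expects uint8 image tensors of shape (N, 3, H, W);
# random tensors stand in for real model samples here.
real_batch = torch.randint(0, 256, (16, 3, 299, 299), dtype=torch.uint8)
fake_batch = torch.randint(0, 256, (16, 3, 299, 299), dtype=torch.uint8)

fid.update(real_batch, real=True)   # accumulate real-image statistics
fid.update(fake_batch, real=False)  # accumulate generated-image statistics
print(fid.compute())
```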
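
For the encoded-dataset format, one lightweight candidate (purely illustrative; the current encoding-mode output may differ) is a compressed `.npz` archive keyed by filename, which downstream tasks can read one latent at a time:

```python
import numpy as np

# Hypothetical layout: one latent vector per image, keyed by filename.
# Shapes and names are illustrative, not the repository's actual format.
latents = {
    "img_0001.png": np.random.randn(64).astype(np.float32),
    "img_0002.png": np.random.randn(64).astype(np.float32),
}
np.savez_compressed("encodings.npz", **latents)

# np.load on an .npz reads member arrays lazily, so a downstream task
# can pull single latents without loading the whole archive into RAM.
with np.load("encodings.npz") as archive:
    z = archive["img_0001.png"]
```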

Notes:

  • Any feedback or questions on the code, documentation, or paper are most welcome.
  • Any suggestions to improve this repository and any requests for additional useful features are also welcome.
  • There are no plans to implement Efficient-VDVAE in TensorFlow (TF): since the models are very deep, we ran into TF graph scalability limits.
  • We have heavily tested the robustness and stability of our approach, so changing the model or optimization hyper-parameters to reduce memory load should not introduce instabilities drastic enough to make the model untrainable, as long as the changes don't negate the stability points we describe in the paper.

Thank you for considering our work. Please feel free to reach out! :)

Checkpoint status update (12/22; 11/11 datasets):

  • We uploaded at least one model for each dataset (in either JAX or PyTorch). Missing checkpoints in either library will be added soon.