/stable-diffusion-jax

Primary language: Python. License: MIT.

TODOs:

  • Finish implementing the UNet2D model in modeling_unet2d.py. Port the weights of any existing LDM UNet from diffusers and verify equivalence. I've added a skeleton of the modules we need to implement in that file.
  • Adapt the PNDMScheduler from diffusers for JAX: Use jnp arrays and make it stateless.
  • Add the KL module from the CompVis repository (https://github.dev/CompVis/stable-diffusion) to modeling_vae.py. We don't really need it for inference, but it would be nice to have for completeness. Port the weights of any existing KL VAE and verify equivalence.
  • Add an inference loop in pipeline_stable_diffusion. We should be able to jit/pmap the loop to deploy on TPUs.
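Porting UNet weights from diffusers mostly comes down to renaming parameters and transposing kernels, because PyTorch and Flax lay them out differently: PyTorch `Conv2d` stores weights as (out, in, kh, kw) and `Linear` as (out, in), while Flax's `nn.Conv` and `nn.Dense` expect a `kernel` of shape (kh, kw, in, out) and (in, out), and Flax norm layers call their weight `scale`. The sketch below shows a hypothetical converter, `convert_state_dict`, applying these rules to a flat dict of NumPy arrays (for example `{k: v.numpy() for k, v in unet.state_dict().items()}`), plus an equivalence check for a single 3x3 convolution computed with `jax.lax.conv_general_dilated` in both layouts. It assumes the Flax UNet2D module names mirror the diffusers key paths; any mismatch would need extra renaming rules.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax


def convert_state_dict(pt_state):
    """Hypothetical converter: flat PyTorch state dict -> flat Flax params keyed by tuples.

    Assumes Flax module names match the PyTorch key paths (e.g. "conv_in.weight").
    """
    flax_params = {}
    for key, value in pt_state.items():
        w = np.asarray(value)
        *path, name = key.split(".")
        if name == "weight" and w.ndim == 4:    # Conv2d: (out, in, kh, kw) -> (kh, kw, in, out)
            name, w = "kernel", w.transpose(2, 3, 1, 0)
        elif name == "weight" and w.ndim == 2:  # Linear: (out, in) -> (in, out)
            name, w = "kernel", w.T
        elif name == "weight" and w.ndim == 1:  # GroupNorm / LayerNorm: weight -> scale
            name = "scale"
        flax_params[tuple(path) + (name,)] = w
    return flax_params


def conv_pytorch_layout(x_nchw, w_oihw, b):
    # Reference 3x3 conv (stride 1, padding 1) in PyTorch's NCHW / OIHW layout.
    y = lax.conv_general_dilated(x_nchw, w_oihw, (1, 1), ((1, 1), (1, 1)),
                                 dimension_numbers=("NCHW", "OIHW", "NCHW"))
    return y + b[None, :, None, None]


def conv_flax_layout(x_nhwc, kernel_hwio, b):
    # The same conv in the NHWC / HWIO layout used by Flax's nn.Conv.
    y = lax.conv_general_dilated(x_nhwc, kernel_hwio, (1, 1), ((1, 1), (1, 1)),
                                 dimension_numbers=("NHWC", "HWIO", "NHWC"))
    return y + b


rng = np.random.default_rng(0)
pt_state = {
    "conv_in.weight": rng.standard_normal((8, 4, 3, 3)).astype(np.float32),
    "conv_in.bias": rng.standard_normal(8).astype(np.float32),
}
params = convert_state_dict(pt_state)
x = rng.standard_normal((2, 4, 16, 16)).astype(np.float32)

ref = conv_pytorch_layout(jnp.asarray(x),
                          jnp.asarray(pt_state["conv_in.weight"]),
                          jnp.asarray(pt_state["conv_in.bias"]))
out = conv_flax_layout(jnp.asarray(x.transpose(0, 2, 3, 1)),
                       jnp.asarray(params[("conv_in", "kernel")]),
                       jnp.asarray(params[("conv_in", "bias")]))
# Move channels back to axis 1 before comparing.
max_diff = float(jnp.max(jnp.abs(ref - out.transpose(0, 3, 1, 2))))
print(f"max abs diff: {max_diff:.2e}")
```

For the full model, the same idea applies end to end: feed identical inputs (transposed to NHWC) to the diffusers UNet and the Flax port and compare outputs with `np.testing.assert_allclose`. `flax.traverse_util.unflatten_dict` can turn the tuple-keyed dict into the nested params tree Flax expects.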
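Making a scheduler stateless in JAX usually means moving everything it mutates (for PNDM, the buffer of recent model outputs and the step counter) into an explicit state object that the step function takes and returns, so each step is a pure function that `jax.jit` can trace. The sketch below illustrates that pattern. It is not a port of diffusers' PNDMScheduler: it swaps PNDM's exact PLMS warm-up for a simplified linear multistep update (Adams-Bashforth coefficients of increasing order, up to 4), combined with the PNDM transfer formula from a noise estimate to the previous sample. `SchedulerState`, `init_state`, `step`, and the scaled-linear beta schedule values are assumptions for illustration.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SchedulerState(NamedTuple):
    # Hypothetical state container: everything the scheduler would otherwise mutate.
    ets: jnp.ndarray      # last 4 model outputs, newest first: (4, *sample_shape)
    counter: jnp.ndarray  # number of steps taken so far (int32 scalar)


def make_alphas_cumprod(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012):
    # "scaled_linear" beta schedule; the values here are assumptions.
    betas = jnp.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps) ** 2
    return jnp.cumprod(1.0 - betas)


def init_state(sample_shape):
    return SchedulerState(ets=jnp.zeros((4, *sample_shape)), counter=jnp.array(0, jnp.int32))


# Adams-Bashforth weights for orders 1..4, newest output first, zero-padded.
AB_COEFFS = jnp.array([
    [1.0, 0.0, 0.0, 0.0],
    [3 / 2, -1 / 2, 0.0, 0.0],
    [23 / 12, -16 / 12, 5 / 12, 0.0],
    [55 / 24, -59 / 24, 37 / 24, -9 / 24],
])


def step(state, alphas_cumprod, model_output, t, t_prev, sample):
    """Pure step: returns (prev_sample, new_state) and never mutates `state`."""
    ets = jnp.concatenate([model_output[None], state.ets[:-1]], axis=0)
    order = jnp.minimum(state.counter + 1, 4)  # simplified warm-up: order 1, 2, 3, then 4
    eps = jnp.tensordot(AB_COEFFS[order - 1], ets, axes=1)

    a_t = alphas_cumprod[t]
    # Assumption: treat alpha_cumprod as 1 once t_prev goes below 0.
    a_prev = jnp.where(t_prev >= 0, alphas_cumprod[jnp.maximum(t_prev, 0)], 1.0)
    # PNDM transfer from x_t to x_{t_prev} given the noise estimate eps.
    sample_coeff = jnp.sqrt(a_prev / a_t)
    denom = a_t * jnp.sqrt(1.0 - a_prev) + jnp.sqrt(a_t * (1.0 - a_t) * a_prev)
    prev_sample = sample_coeff * sample - (a_prev - a_t) * eps / denom
    return prev_sample, SchedulerState(ets=ets, counter=state.counter + 1)


step_jit = jax.jit(step)

# Usage: with the exact noise as model output, one step should land on the
# noised x0 at t_prev (the transfer formula is the deterministic DDIM update).
alphas = make_alphas_cumprod()
key_x0, key_eps = jax.random.split(jax.random.PRNGKey(0))
x0 = jax.random.normal(key_x0, (1, 8, 8, 4))
noise = jax.random.normal(key_eps, (1, 8, 8, 4))
t, t_prev = 500, 480
x_t = jnp.sqrt(alphas[t]) * x0 + jnp.sqrt(1.0 - alphas[t]) * noise

state = init_state(x_t.shape)
x_prev, new_state = step_jit(state, alphas, noise, t, t_prev, x_t)
expected = jnp.sqrt(alphas[t_prev]) * x0 + jnp.sqrt(1.0 - alphas[t_prev]) * noise
```

Because the state is an ordinary pytree of arrays, it can also be threaded through `jax.lax.fori_loop` or `jax.lax.scan` alongside the latents.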
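Assuming "the KL module" refers to the KL-regularized autoencoder (AutoencoderKL) in the CompVis repository, its distinctive piece is the diagonal Gaussian posterior: the encoder outputs moments that are split into a mean and a log-variance, latents are drawn with the reparameterization trick, and the KL divergence against a standard normal is summed over the latent dimensions. The sketch below mirrors that logic in JAX. The class name `DiagonalGaussian`, the NHWC layout (channels last, as Flax convolutions use), and the log-variance clipping range of [-30, 20] are assumptions to check against the reference code before porting weights.

```python
import jax
import jax.numpy as jnp


class DiagonalGaussian:
    """Hypothetical diagonal Gaussian posterior built from encoder moments (NHWC)."""

    def __init__(self, moments):
        # moments: (batch, H, W, 2 * C); first half is the mean, second half the log-variance.
        self.mean, logvar = jnp.split(moments, 2, axis=-1)
        self.logvar = jnp.clip(logvar, -30.0, 20.0)  # assumed range; keeps exp() stable
        self.std = jnp.exp(0.5 * self.logvar)
        self.var = jnp.exp(self.logvar)

    def sample(self, rng):
        # Reparameterization trick: mean + std * standard normal noise.
        return self.mean + self.std * jax.random.normal(rng, self.mean.shape, self.mean.dtype)

    def mode(self):
        return self.mean

    def kl(self):
        # KL(q || N(0, I)) per example, summed over H, W and C.
        return 0.5 * jnp.sum(self.mean ** 2 + self.var - 1.0 - self.logvar, axis=(1, 2, 3))


# Mean 1 and log-variance 0 give KL = 0.5 * (4 * 4 * 4) = 32 per example.
moments = jnp.concatenate([jnp.ones((2, 4, 4, 4)), jnp.zeros((2, 4, 4, 4))], axis=-1)
posterior = DiagonalGaussian(moments)
z = posterior.sample(jax.random.PRNGKey(0))
print(z.shape, posterior.kl())
```

For inference only `mode()` (or `sample()`) on the encoder side and the decoder are exercised; `kl()` matters mainly for training, which fits the note that this module is for completeness.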
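One way to make the denoising loop jit-friendly is to express it with `jax.lax.fori_loop`, so the whole loop compiles into a single XLA program, and then `jax.pmap` the same function over devices with one RNG key per device. The sketch below assumes a hypothetical `denoise_fn` standing in for the UNet (it only scales its input so the example runs), and uses a deterministic DDIM-style update in place of the real scheduler. The beta schedule, latent shape, and step counts are illustrative assumptions, and text conditioning and classifier-free guidance are left out.

```python
import jax
import jax.numpy as jnp


def make_alphas_cumprod(num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012):
    # "scaled_linear" beta schedule; values are assumptions.
    betas = jnp.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_timesteps) ** 2
    return jnp.cumprod(1.0 - betas)


def denoise_fn(params, latents, t):
    # Hypothetical stand-in for the UNet's noise prediction.
    return params["scale"] * latents


def sample(params, rng, num_inference_steps, shape, num_train_timesteps=1000):
    alphas_cumprod = make_alphas_cumprod(num_train_timesteps)
    step_ratio = num_train_timesteps // num_inference_steps
    # Descending timesteps, e.g. 980, 960, ..., 0 for 50 steps out of 1000.
    timesteps = jnp.arange(num_inference_steps)[::-1] * step_ratio
    latents = jax.random.normal(rng, shape)

    def body(i, latents):
        t = timesteps[i]
        t_prev = t - step_ratio
        eps = denoise_fn(params, latents, t)
        a_t = alphas_cumprod[t]
        a_prev = jnp.where(t_prev >= 0, alphas_cumprod[jnp.maximum(t_prev, 0)], 1.0)
        # Deterministic DDIM-style update (eta = 0) standing in for the scheduler.
        x0_pred = (latents - jnp.sqrt(1.0 - a_t) * eps) / jnp.sqrt(a_t)
        return jnp.sqrt(a_prev) * x0_pred + jnp.sqrt(1.0 - a_prev) * eps

    return jax.lax.fori_loop(0, num_inference_steps, body, latents)


# Single device: step counts and shape must be static (compile-time) arguments.
sample_jit = jax.jit(sample, static_argnums=(2, 3, 4))

# Multiple devices: replicate params, split the RNG, and map over the leading axis.
sample_pmap = jax.pmap(sample, static_broadcasted_argnums=(2, 3, 4))

params = {"scale": jnp.array(0.1)}
n_dev = jax.local_device_count()
replicated = jax.tree_util.tree_map(lambda x: jnp.broadcast_to(x, (n_dev, *x.shape)), params)
rngs = jax.random.split(jax.random.PRNGKey(0), n_dev)

single = sample_jit(params, rngs[0], 50, (1, 8, 8, 4), 1000)
batched = sample_pmap(replicated, rngs, 50, (1, 8, 8, 4), 1000)
```

On a TPU host `n_dev` would typically be 8, giving one independent sample per core; on a CPU-only machine the same code runs with a single device, which makes it easy to test before deploying.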