  • Finish implementing the UNet2D model in the file; I've added the skeleton of the modules we need to implement there. Port the weights of any existing LDM UNet from diffusers and verify equivalence.
  • Adapt the PNDMScheduler from diffusers for JAX: Use jnp arrays and make it stateless.
  • Add the KL module from [here]() to the file. For inference we don't really need it, but it would be nice to have for completeness. Port the weights of any existing KL VAE and verify equivalence.
  • Add an inference loop to pipeline_stable_diffusion. We should be able to jit/pmap the loop to deploy on TPUs.
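For the two weight-porting items (UNet2D and KL VAE), most of the work is mapping diffusers' PyTorch parameter names onto the Flax module tree, but layout changes are where equivalence checks usually fail first. This is only an illustrative sketch. `torch_conv_to_flax` and `torch_linear_to_flax` are hypothetical helper names, and it assumes Flax's default NHWC activations and HWIO conv kernels. It converts a PyTorch-layout (OIHW) conv kernel, runs the same convolution in both layouts with `jax.lax.conv_general_dilated`, and reports the maximum absolute difference. That is the same kind of check we'd run on the outputs of the ported UNet/VAE.

```python
import jax
import jax.numpy as jnp
import numpy as np


def torch_conv_to_flax(kernel_oihw):
    """PyTorch nn.Conv2d weight (out, in, kh, kw) -> Flax nn.Conv kernel (kh, kw, in, out)."""
    return np.transpose(kernel_oihw, (2, 3, 1, 0))


def torch_linear_to_flax(weight_oi):
    """PyTorch nn.Linear weight (out, in) -> Flax nn.Dense kernel (in, out)."""
    return np.transpose(weight_oi, (1, 0))


def conv_torch_layout(x_nchw, k_oihw):
    # Reference: the convolution in PyTorch's layout (NCHW activations, OIHW kernel).
    return jax.lax.conv_general_dilated(
        x_nchw, k_oihw, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NCHW", "OIHW", "NCHW"))


def conv_flax_layout(x_nhwc, k_hwio):
    # The same convolution in Flax's default layout (NHWC activations, HWIO kernel).
    return jax.lax.conv_general_dilated(
        x_nhwc, k_hwio, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))


rng = np.random.default_rng(0)
x_nchw = rng.standard_normal((2, 4, 8, 8)).astype(np.float32)
k_oihw = rng.standard_normal((6, 4, 3, 3)).astype(np.float32)  # e.g. a conv weight from a state_dict

ref = conv_torch_layout(x_nchw, k_oihw)                      # (2, 6, 8, 8)
out = conv_flax_layout(np.transpose(x_nchw, (0, 2, 3, 1)),   # NCHW -> NHWC
                       torch_conv_to_flax(k_oihw))           # (2, 8, 8, 6)

# Equivalence check: move the Flax output back to NCHW and compare with a tolerance.
max_abs_diff = float(jnp.max(jnp.abs(ref - jnp.transpose(out, (0, 3, 1, 2)))))
print(f"max |diff| = {max_abs_diff:.2e}")
```

Linear layers need the transpose shown in `torch_linear_to_flax`, and Flax norm layers call the PyTorch `weight` parameter `scale`. For the full models, feeding identical inputs to the diffusers module and the Flax module and comparing with a small absolute tolerance (rather than exact equality) is probably the most practical check.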
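For the PNDMScheduler item, "stateless" could look something like this sketch. The buffer of past model outputs and the step counter live in an explicit `PLMSState` NamedTuple that is passed in and returned, so `plms_step` is a pure function that `jax.jit` can trace. `PLMSState`, `init_plms_state`, `plms_step`, and `prev_sample_pndm` are hypothetical names. `prev_sample_pndm` follows formula (9) of the PNDM paper, as diffusers' `_get_prev_sample` does, and the multistep weights are the ones used in diffusers' `step_plms`. This is a simplification. It drops the pseudo Runge-Kutta (PRK) steps and the warm-up where diffusers repeats the first step and averages the two model outputs, so it is not numerically identical to the PyTorch scheduler.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class PLMSState(NamedTuple):
    """Explicit scheduler state: passed into and returned from every step, never mutated."""
    ets: jnp.ndarray    # (4, *sample_shape) past model outputs, newest in ets[-1]
    count: jnp.ndarray  # int32 scalar: number of valid entries in `ets` (capped at 4)


def init_plms_state(sample_shape, dtype=jnp.float32):
    return PLMSState(ets=jnp.zeros((4, *sample_shape), dtype),
                     count=jnp.array(0, jnp.int32))


def prev_sample_pndm(sample, eps, alpha_prod_t, alpha_prod_t_prev):
    """x_t -> x_{t - delta} using formula (9) of the PNDM paper."""
    beta_prod_t = 1.0 - alpha_prod_t
    beta_prod_t_prev = 1.0 - alpha_prod_t_prev
    sample_coeff = jnp.sqrt(alpha_prod_t_prev / alpha_prod_t)
    denom = (alpha_prod_t * jnp.sqrt(beta_prod_t_prev)
             + jnp.sqrt(alpha_prod_t * beta_prod_t * alpha_prod_t_prev))
    return sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * eps / denom


def plms_step(state, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
    """Simplified PLMS step (no PRK / warm-up): returns (new_sample, new_state)."""
    # Shift the fixed-size buffer: drop the oldest output, append the newest.
    ets = jnp.concatenate([state.ets[1:], model_output[None]], axis=0)
    count = jnp.minimum(state.count + 1, 4)
    e1, e2, e3, e4 = ets[3], ets[2], ets[1], ets[0]  # e1 is the newest
    # Pick the highest-order linear multistep combination the buffer supports.
    # jnp.where keeps control flow static, which jit/pmap require.
    eps = jnp.where(
        count == 1, e1,
        jnp.where(count == 2, (3 * e1 - e2) / 2,
                  jnp.where(count == 3, (23 * e1 - 16 * e2 + 5 * e3) / 12,
                            (55 * e1 - 59 * e2 + 37 * e3 - 9 * e4) / 24)))
    new_sample = prev_sample_pndm(sample, eps, alpha_prod_t, alpha_prod_t_prev)
    return new_sample, PLMSState(ets=ets, count=count)
```

Under `jax.jit` all four combinations are computed and one is selected. That costs a little extra arithmetic but avoids data-dependent Python branching on the step count.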
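For the KL module, the part that gives the KL VAE its name is the diagonal Gaussian posterior over the latents. The encoder and decoder are ResNet-style conv stacks and aren't sketched here. This is a minimal JAX version assuming NHWC layout, so the mean and log-variance are split along the last axis. `DiagonalGaussian` is a hypothetical name. The log-variance clamp range and the KL formula mirror diffusers' `DiagonalGaussianDistribution`.

```python
import jax
import jax.numpy as jnp


class DiagonalGaussian:
    """Posterior q(z|x) from the KL-VAE encoder; `moments` holds [mean, logvar] on the channel axis."""

    def __init__(self, moments):
        self.mean, logvar = jnp.split(moments, 2, axis=-1)
        self.logvar = jnp.clip(logvar, -30.0, 20.0)  # clamp for numerical stability
        self.std = jnp.exp(0.5 * self.logvar)
        self.var = jnp.exp(self.logvar)

    def sample(self, key):
        # Reparameterization trick: z = mean + std * noise, noise ~ N(0, I).
        noise = jax.random.normal(key, self.mean.shape, self.mean.dtype)
        return self.mean + self.std * noise

    def mode(self):
        return self.mean

    def kl(self):
        # KL(q || N(0, I)) per example, summed over H, W, C (NHWC layout assumed).
        return 0.5 * jnp.sum(self.mean ** 2 + self.var - 1.0 - self.logvar, axis=(1, 2, 3))
```

`kl()` only matters for training. When encoding images at inference time, the posterior is reduced to `mode()` or a single `sample(key)`.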
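For the pipeline item, one way to make the loop jit/pmap-friendly is to put the entire denoising loop inside a single traced function. It can then iterate with `jax.lax.scan` over a precomputed `timesteps` array, with no Python-level loop and no mutable scheduler object. In this sketch, `toy_unet` is a hypothetical stand-in for the Flax UNet's apply function (no text conditioning or classifier-free guidance). The update is a deterministic DDIM-style step swapped in for brevity; the real pipeline would call the JAX PNDM step instead. The `scaled_linear` beta schedule (0.00085 to 0.012 over 1000 training steps) and the 50 inference steps are assumptions chosen to resemble a Stable Diffusion config.

```python
import jax
import jax.numpy as jnp


def toy_unet(params, latents, t):
    """Hypothetical stand-in for the Flax UNet2D apply fn: returns a noise prediction."""
    del t  # a real UNet would condition on the timestep (and on text embeddings)
    return params["w"] * latents


def ddim_like_step(latents, eps, alpha_t, alpha_prev):
    """Deterministic DDIM-style update (eta = 0), used here in place of the PNDM step."""
    x0 = (latents - jnp.sqrt(1.0 - alpha_t) * eps) / jnp.sqrt(alpha_t)
    return jnp.sqrt(alpha_prev) * x0 + jnp.sqrt(1.0 - alpha_prev) * eps


def sample_loop(params, latents, timesteps, alphas_cumprod):
    """The whole denoising loop as one pure function, so it can be wrapped in jit or pmap."""
    alpha_t = alphas_cumprod[timesteps]
    # Alpha of the next (less noisy) step; the final step targets alpha = 1 (a clean sample).
    alpha_prev = jnp.concatenate([alpha_t[1:], jnp.ones((1,), alpha_t.dtype)])

    def body(x, step_inputs):
        t, a_t, a_prev = step_inputs
        eps = toy_unet(params, x, t)
        return ddim_like_step(x, eps, a_t, a_prev), None

    # scan traces `body` once and runs it len(timesteps) times, without Python unrolling.
    latents, _ = jax.lax.scan(body, latents, (timesteps, alpha_t, alpha_prev))
    return latents


# Assumed schedule: "scaled_linear" betas, 1000 training steps, 50 inference steps.
num_train_timesteps, num_inference_steps = 1000, 50
betas = jnp.linspace(0.00085 ** 0.5, 0.012 ** 0.5, num_train_timesteps) ** 2
alphas_cumprod = jnp.cumprod(1.0 - betas)
step = num_train_timesteps // num_inference_steps
timesteps = jnp.arange(num_train_timesteps - 1, -1, -step)  # 999, 979, ..., 19

params = {"w": jnp.array(0.1)}
latents = jax.random.normal(jax.random.PRNGKey(0), (1, 64, 64, 4))  # NHWC latents

# Single device: jit the whole loop.
out = jax.jit(sample_loop)(params, latents, timesteps, alphas_cumprod)

# Several devices (e.g. TPU cores): map over a leading device axis and broadcast the rest.
n_dev = jax.local_device_count()
p_sample_loop = jax.pmap(sample_loop, in_axes=(None, 0, None, None))
out_per_device = p_sample_loop(params, jnp.stack([latents] * n_dev), timesteps, alphas_cumprod)
```

Broadcasting `params` with `in_axes=None` keeps the sketch short. On TPUs the more common pattern is to replicate the params across devices once (e.g. with `flax.jax_utils.replicate`) and map over them with axis 0.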