- Finish implementing the `UNet2D` model in `modeling_unet2d.py`. Port the weights of any existing LDM UNet from `diffusers` and verify equivalence. I've added the skeleton of the modules we need to implement in the file.
- Adapt the `PNDMScheduler` from `diffusers` for JAX: use `jnp` arrays and make it stateless.
- Add the KL module from [here](https://github.dev/CompVis/stable-diffusion) in the `modeling_vae.py` file. We don't really need it for inference, but it would be nice to have for completeness. Port the weights of any existing KL VAE and verify equivalence.
- Add an inference loop in `pipeline_stable_diffusion`. We should be able to `jit`/`pmap` the loop to deploy on TPUs.
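For the weight-porting part of the first item, most of the work is moving tensors from PyTorch's layout to Flax's. PyTorch `Conv2d` stores kernels as `(out, in, kh, kw)`, while `flax.linen.Conv` expects `(kh, kw, in, out)`. Likewise, `nn.Linear` weights `(out, in)` become Flax `Dense` kernels `(in, out)` by a plain transpose. Below is a minimal sketch that runs without torch or flax. It assumes two plain `jax.lax` convolutions stand in for the PyTorch layer and the Flax layer. `torch_conv_kernel_to_flax` is a hypothetical helper, not an existing `diffusers` utility.

```python
import jax
import jax.numpy as jnp
import numpy as np


def torch_conv_kernel_to_flax(kernel_oihw):
    # Hypothetical helper: PyTorch Conv2d kernel (out, in, kh, kw)
    # -> Flax Conv kernel (kh, kw, in, out).
    return jnp.transpose(jnp.asarray(kernel_oihw), (2, 3, 1, 0))


def conv_nchw(x, kernel_oihw):
    # Reference convolution in PyTorch's layout (stands in for the torch model).
    return jax.lax.conv_general_dilated(
        x, kernel_oihw, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NCHW", "OIHW", "NCHW"))


def conv_nhwc(x, kernel_hwio):
    # Same convolution in the layout Flax uses.
    return jax.lax.conv_general_dilated(
        x, kernel_hwio, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))


rng = np.random.default_rng(0)
x_nchw = rng.standard_normal((2, 4, 8, 8)).astype(np.float32)
w_oihw = rng.standard_normal((6, 4, 3, 3)).astype(np.float32)

ref = conv_nchw(x_nchw, w_oihw)
out = conv_nhwc(jnp.transpose(x_nchw, (0, 2, 3, 1)), torch_conv_kernel_to_flax(w_oihw))

# Bring the Flax-layout output back to NCHW and compare element-wise.
max_abs_diff = float(jnp.max(jnp.abs(ref - jnp.transpose(out, (0, 3, 1, 2)))))
print(max_abs_diff)
```

For the full UNet, "verify equivalence" comes down to the same check. Feed both models the same input (transposed to each model's layout) and compare the outputs with a small tolerance.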
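For the scheduler item, one reading of "stateless" is this: the mutable bookkeeping (PNDM's history of past model outputs) lives in an explicit pytree that the step function takes in and returns. Below is a rough sketch under that assumption. `PLMSState`, `init_state`, `combine_ets` and `plms_step` are hypothetical names. `get_prev_sample` mirrors, as we read it, the `_get_prev_sample` formula in diffusers' `PNDMScheduler`. The sketch covers only the PLMS (linear multistep) part. It omits the Runge-Kutta warm-up steps and the special handling diffusers applies to the first PLMS steps.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class PLMSState(NamedTuple):
    # Hypothetical state container. Every field is an array, so the state is a
    # pytree that can be threaded through jit, pmap and lax loops.
    ets: jnp.ndarray      # last four model outputs, newest first: (4, *sample_shape)
    num_ets: jnp.ndarray  # number of valid entries in `ets`, saturating at 4


def init_state(sample_shape, dtype=jnp.float32):
    return PLMSState(jnp.zeros((4, *sample_shape), dtype), jnp.array(0, jnp.int32))


def combine_ets(ets, num_ets):
    # Linear multistep combination of past model outputs. Lower orders are
    # used while the history is still filling up.
    branches = [
        lambda e: e[0],
        lambda e: (3 * e[0] - e[1]) / 2,
        lambda e: (23 * e[0] - 16 * e[1] + 5 * e[2]) / 12,
        lambda e: (55 * e[0] - 59 * e[1] + 37 * e[2] - 9 * e[3]) / 24,
    ]
    return jax.lax.switch(jnp.clip(num_ets - 1, 0, 3), branches, ets)


def get_prev_sample(sample, eps, alpha_prod_t, alpha_prod_t_prev):
    # Deterministic update from t to t_prev given a noise estimate `eps`.
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** 0.5
    denom = alpha_prod_t * beta_prod_t_prev ** 0.5 + (
        alpha_prod_t * beta_prod_t * alpha_prod_t_prev) ** 0.5
    return sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * eps / denom


def plms_step(state, model_output, sample, alphas_cumprod, t, t_prev,
              final_alpha_cumprod=1.0):
    # Pure function: returns the new sample and a new state; nothing is mutated.
    ets = jnp.roll(state.ets, shift=1, axis=0).at[0].set(model_output)
    num_ets = jnp.minimum(state.num_ets + 1, 4)
    eps = combine_ets(ets, num_ets)
    alpha_t = alphas_cumprod[t]
    # Past the last timestep, fall back to a fixed final alpha_cumprod.
    alpha_prev = jnp.where(t_prev >= 0, alphas_cumprod[jnp.maximum(t_prev, 0)],
                           final_alpha_cumprod)
    return get_prev_sample(sample, eps, alpha_t, alpha_prev), PLMSState(ets, num_ets)
```

diffusers keeps the history in a growing Python list. This sketch uses a fixed-size buffer plus a counter instead. Array shapes then stay constant, so `plms_step` can be traced once and reused inside `jit` or `lax.fori_loop`.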
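For the KL item, one self-contained piece of the CompVis `AutoencoderKL` is its diagonal Gaussian posterior over the latents. Below is a minimal JAX sketch of that piece. It mirrors, as we understand it, CompVis' `DiagonalGaussianDistribution`: the log-variance is clamped to [-30, 20] and the KL term is summed over the non-batch axes. It assumes NHWC tensors, so the channel axis is last, unlike the NCHW PyTorch code. `DiagonalGaussian` and `posterior_from_moments` are hypothetical names, and the encoder/decoder conv stacks are not shown.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class DiagonalGaussian(NamedTuple):
    # Hypothetical container; as a NamedTuple it is a JAX pytree.
    mean: jnp.ndarray
    logvar: jnp.ndarray

    def sample(self, key):
        # Reparameterised sample: mean + std * N(0, I).
        std = jnp.exp(0.5 * self.logvar)
        return self.mean + std * jax.random.normal(key, self.mean.shape, self.mean.dtype)

    def mode(self):
        return self.mean

    def kl(self):
        # KL(q || N(0, I)), summed over the H, W and C axes (one value per batch item).
        var = jnp.exp(self.logvar)
        return 0.5 * jnp.sum(self.mean ** 2 + var - 1.0 - self.logvar, axis=(1, 2, 3))


def posterior_from_moments(moments):
    # The encoder emits 2 * latent_channels channels: split them into mean and
    # log-variance along the channel axis (last axis in NHWC).
    mean, logvar = jnp.split(moments, 2, axis=-1)
    return DiagonalGaussian(mean, jnp.clip(logvar, -30.0, 20.0))
```

For inference only `mode()` (or `sample()`) is needed to get latents. `kl()` matters only if someone wants to train or fine-tune the VAE in JAX.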
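For the last item, writing the loop with `jax.lax.fori_loop` instead of a Python `for` lets `jit` compile the whole denoising loop as one program. The same function can then be `pmap`ped over devices, with parameters broadcast and latents split along a leading device axis. Below is a minimal sketch under these assumptions. A hypothetical `toy_unet` stands in for the real `UNet2D` apply function. A placeholder update stands in for the scheduler step. The latent shape `(1, 64, 64, 4)` is only illustrative.

```python
from functools import partial

import jax
import jax.numpy as jnp


def toy_unet(params, latents, t):
    # Hypothetical stand-in for the UNet2D forward pass; replace with the
    # real model's apply function.
    return params["scale"] * latents * (t / 1000.0)


def sample_loop(params, latents, num_steps=50):
    # num_steps must be a static Python int so the timestep array has a fixed shape.
    timesteps = jnp.linspace(999, 0, num_steps)

    def body(i, latents):
        noise_pred = toy_unet(params, latents, timesteps[i])
        # Placeholder update: the real loop would call the stateless scheduler's
        # step function here and carry its state alongside the latents.
        return latents - 0.1 * noise_pred

    return jax.lax.fori_loop(0, num_steps, body, latents)


n_dev = jax.local_device_count()
params = {"scale": jnp.float32(1.0)}
latents = jax.random.normal(jax.random.PRNGKey(0), (n_dev, 1, 64, 64, 4))

# Single device: compile the whole loop once.
single = jax.jit(partial(sample_loop, num_steps=50))(params, latents[0])

# Multiple devices: broadcast params (in_axes=None), shard latents on axis 0.
p_sample = jax.pmap(partial(sample_loop, num_steps=50), in_axes=(None, 0))
out = p_sample(params, latents)
```

On a TPU host, `jax.local_device_count()` is the number of local TPU devices, so each device denoises its own slice of latents. On CPU it is usually 1, which makes the sketch easy to test locally.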