facebookresearch/multimodal

Train diffusion on MNIST

Closed this issue · 2 comments

Tried the Diffusion_Labs tutorial "Train diffusion on MNIST" after watching @pbontrager's presentation at the recent PyTorch Conference. The code works well; on Google Colab with a GPU (free tier) it takes 7 minutes per epoch. However, it would be nice if the tutorial could be extended to cover how to save the trained model and use it for inference.

Hi @sudhir2016, thanks for your interest in the tutorial. Although the tutorial doesn't show how to save the model, I think the last cell should give an example of how to do inference. This just involves sampling from a normal distribution, calling the encoder to get the embeddings of each digit, and calling the decoder on the random noise with the digit embeddings as conditional inputs for the denoising process.
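The sampling procedure described above can be sketched roughly as follows. Note this is a minimal illustration with toy stand-ins, not the tutorial's actual code: `encoder` is an nn.Embedding over the 10 digit classes, and `ToyDecoder` is a hypothetical placeholder for the trained DDPModule's denoising call.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the notebook's components: an nn.Embedding
# "encoder" for the 10 MNIST digit classes, and a toy "decoder" in place
# of the trained DDPModule. Only the call pattern matches the tutorial.
encoder = nn.Embedding(10, 32)  # digit label -> conditioning embedding

class ToyDecoder(nn.Module):
    # Placeholder for the trained denoising model: maps random noise plus
    # a digit embedding to an image-shaped output.
    def __init__(self, emb_dim=32, img_size=28):
        super().__init__()
        self.img_size = img_size
        self.proj = nn.Linear(emb_dim + img_size * img_size, img_size * img_size)

    def forward(self, noise, cond):
        x = torch.cat([noise.flatten(1), cond], dim=1)
        return self.proj(x).view(-1, 1, self.img_size, self.img_size)

decoder = ToyDecoder()

# Inference: sample Gaussian noise, embed the target digits, then call the
# decoder with the embeddings as conditional inputs.
digits = torch.tensor([0, 3, 7])
noise = torch.randn(len(digits), 1, 28, 28)
cond = encoder(digits)
with torch.no_grad():
    images = decoder(noise, cond)
print(tuple(images.shape))  # (3, 1, 28, 28)
```

With the real trained modules from the notebook, `images` would be the generated digit samples.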

To save the model, you can just run e.g.

torch.save(encoder.state_dict(), "mnist_diffusion_encoder.pt")
torch.save(decoder.state_dict(), "mnist_diffusion_decoder.pt")

Then you can load the state dicts into the nn.Embedding and DDPModule classes, constructed as in the notebook, using load_state_dict(...). We will look at adding these details to the tutorial when we get a chance; please also feel free to open a PR yourself if you're interested.
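A self-contained round-trip sketch of the save/load pattern, using an nn.Embedding as a stand-in for the notebook's digit encoder (the same steps apply to the DDPModule decoder, constructed with the same arguments as in the notebook):

```python
import torch
import torch.nn as nn

# Stand-in for the notebook's digit encoder; the embedding sizes here are
# placeholders, not necessarily those used in the tutorial.
encoder = nn.Embedding(10, 32)
torch.save(encoder.state_dict(), "mnist_diffusion_encoder.pt")

# Later (e.g. in a fresh process): rebuild the module with the same
# constructor arguments, then load the saved weights into it.
restored = nn.Embedding(10, 32)
restored.load_state_dict(torch.load("mnist_diffusion_encoder.pt"))

# The restored module reproduces the original weights exactly.
print(torch.equal(encoder.weight, restored.weight))  # True
```

Before running inference with a restored model, calling `.eval()` on it is good practice so layers like dropout behave deterministically.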

Thank you for the prompt response.