
Implementation of Medfusion, a latent diffusion model for medical image synthesis.

Primary language: Python. License: MIT.

Medfusion - Medical Denoising Diffusion Probabilistic Model

Paper

Please see: Diffusion Probabilistic Models beat GANs on Medical 2D Images

Figure: Medfusion


Figure: Eye fundus, chest X-ray and colon histology images generated with Medfusion (warning: color quality is limited by the .gif format).

Demo

Link to Streamlit app.

Install

Create a virtual environment and install the package (in editable mode) with its dependencies:

```bash
python -m venv venv
source venv/bin/activate
pip install -e .
```

Get Started

1 Prepare Data
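The README leaves this step open. As a hypothetical starting point, the sketch below indexes a folder of images into (path, label) pairs and makes a reproducible train/validation split. The folder layout `root/<class_name>/<image>`, the helper names `index_image_folder` and `split_items`, and the accepted file suffixes are all assumptions; adapt them to whatever your Dataset class expects.

```python
import random
from pathlib import Path

# Assumed (hypothetical) layout, not required by Medfusion:
#   root/<class_name>/<image file>
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".tif", ".tiff"}


def index_image_folder(root):
    """Return (items, class_to_idx); items is a list of (path, integer label) pairs."""
    root = Path(root)
    # Sorted class folder names give a stable name -> index mapping.
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
    items = []
    for name in class_names:
        for path in sorted((root / name).rglob("*")):
            if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES:
                items.append((path, class_to_idx[name]))
    return items, class_to_idx


def split_items(items, val_fraction=0.1, seed=0):
    """Shuffle with a fixed seed and split into (train, val) lists."""
    items = list(items)
    random.Random(seed).shuffle(items)
    n_val = int(round(len(items) * val_fraction))
    return items[n_val:], items[:n_val]
```

Integer labels only matter for conditional training; for unconditional training they can simply be ignored.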

2 Train Autoencoder

  • Go to scripts/train_latent_embedder_2d.py and import your Dataset.
  • Load your dataset with, e.g., SimpleDataModule.
  • Customize the VAE to your needs.
  • (Optional): Train a VAEGAN instead, or load a pre-trained VAE and set start_gan_train_step=-1 to start GAN training immediately.
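The README does not spell out the autoencoder objective. Assuming a standard VAE with a diagonal Gaussian latent, the loss combines a reconstruction term with a KL penalty, and a VAEGAN adds an adversarial term once GAN training starts. The sketch below works on plain Python floats. `autoencoder_loss`, `gan_loss_active`, `kl_weight` and `adv_weight` (and their default values) are hypothetical, and `gan_loss_active` is only one reading of how `start_gan_train_step` might gate the adversarial loss, not the repository's code.

```python
import math
import random


def reparameterize(mu, logvar, rng=random):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1), elementwise."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]


def kl_to_standard_normal(mu, logvar):
    """KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def gan_loss_active(step, start_gan_train_step):
    """Hypothetical schedule: the adversarial term is on from start_gan_train_step onward.

    With start_gan_train_step = -1 it is already on at step 0.
    """
    return step >= start_gan_train_step


def autoencoder_loss(rec_loss, mu, logvar, step, start_gan_train_step,
                     adv_loss=0.0, kl_weight=1e-6, adv_weight=0.5):
    """Reconstruction + weighted KL, plus a weighted adversarial term once GAN training starts."""
    loss = rec_loss + kl_weight * kl_to_standard_normal(mu, logvar)
    if gan_loss_active(step, start_gan_train_step):
        loss += adv_weight * adv_loss
    return loss
```

Keeping the KL weight small is a common choice for latent diffusion autoencoders, since it keeps reconstructions sharp while still regularizing the latent space; the value here is illustrative only.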

2.1 Evaluate Autoencoder
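This section is empty in the README. A common first check for an autoencoder is reconstruction fidelity, e.g. mean squared error and PSNR between inputs and their reconstructions. The sketch assumes images flattened to lists of floats scaled to [0, 1] (`data_range=1.0`); the function names are hypothetical, and whether the repository's evaluation uses these metrics is not stated here.

```python
import math


def mse(x, y):
    """Mean squared error between two equally sized flat sequences of pixel values."""
    if len(x) != len(y) or not x:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum((a - b) ** 2 for a, b in zip(x, y)) / len(x)


def psnr(x, y, data_range=1.0):
    """Peak signal-to-noise ratio in dB; infinite for identical inputs."""
    err = mse(x, y)
    if err == 0:
        return math.inf
    return 10.0 * math.log10(data_range ** 2 / err)
```

Higher PSNR means a closer reconstruction; for 8-bit images left in [0, 255], pass `data_range=255.0`.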

3 Train Diffusion

  • Go to scripts/train_diffusion.py and import/load your Dataset as before.
  • Load your pre-trained VAE or VAEGAN with latent_embedder_checkpoint=...
  • Use cond_embedder = LabelEmbedder for conditional training; otherwise, set cond_embedder = None.
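Because Medfusion is a latent diffusion model, the diffusion model is trained on latents produced by the pre-trained autoencoder rather than on pixels. The sketch below illustrates the standard DDPM forward (noising) step on such a latent, assuming a linear beta schedule (1e-4 to 0.02 over 1000 steps, the DDPM defaults) and an epsilon-prediction target; the repository's actual schedule and parameterization may differ. `label_table` is a hypothetical stand-in for a learned label embedding such as LabelEmbedder, and passing `None` mirrors `cond_embedder = None`.

```python
import math
import random


def linear_alpha_bars(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for a linearly increasing beta schedule."""
    alpha_bars, prod = [], 1.0
    for t in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * t / max(num_steps - 1, 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def q_sample(z0, t, alpha_bars, rng=random):
    """Noise latent z0 to step t: z_t = sqrt(a_bar) * z0 + sqrt(1 - a_bar) * eps.

    Returns (z_t, eps); eps is the regression target of an epsilon-predicting model.
    """
    a_bar = alpha_bars[t]
    eps = [rng.gauss(0.0, 1.0) for _ in z0]
    z_t = [math.sqrt(a_bar) * z + math.sqrt(1.0 - a_bar) * e for z, e in zip(z0, eps)]
    return z_t, eps


def training_example(z0, label, label_table, alpha_bars, rng=random):
    """Build one hypothetical training sample: (noisy latent, timestep, condition, target).

    label_table is a stand-in for a learned label embedding; None means unconditional.
    """
    t = rng.randrange(len(alpha_bars))  # uniform random timestep
    z_t, eps = q_sample(z0, t, alpha_bars, rng)
    cond = None if label_table is None else label_table[label]
    return z_t, t, cond, eps
```

In practice `z0` would be the encoder output for a training image, and the denoiser receives `(z_t, t, cond)` and is trained to predict `eps`.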

3.1 Evaluate Diffusion
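This section is also empty in the README. Generative models are commonly evaluated with FID and with k-nearest-neighbour precision and recall computed on features from a pre-trained network; which metrics the repository reports is not stated here. The sketch implements improved precision/recall (Kynkäänniemi et al., 2019) on already-extracted feature vectors; the feature extractor, the helper names and the default `k=3` are assumptions.

```python
import math


def knn_radii(points, k):
    """Distance from each point to its k-th nearest neighbour within the same set."""
    radii = []
    for i, p in enumerate(points):
        dists = sorted(math.dist(p, q) for j, q in enumerate(points) if j != i)
        radii.append(dists[k - 1])
    return radii


def coverage_fraction(queries, support, k):
    """Fraction of queries inside at least one k-NN ball centred on a support point."""
    radii = knn_radii(support, k)
    inside = sum(
        any(math.dist(q, s) <= r for s, r in zip(support, radii)) for q in queries
    )
    return inside / len(queries)


def precision_recall(real_feats, fake_feats, k=3):
    """Precision: fakes on the real manifold. Recall: reals on the fake manifold."""
    precision = coverage_fraction(fake_feats, real_feats, k)
    recall = coverage_fraction(real_feats, fake_feats, k)
    return precision, recall
```

Precision reflects fidelity (generated samples look like real ones), recall reflects diversity (generated samples cover the real distribution). Each feature set needs more than `k` points.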

Acknowledgment