Harness the capabilities of PyTorch Lightning and Weights & Biases (W&B) to implement and train a variety of deep generative models. Seamlessly log and visualize your experiments for an efficient and impactful machine learning development process!
Lightning Generative Models is designed to provide an intuitive and robust framework for working with different types of generative models. It leverages the simplicity of PyTorch Lightning and the comprehensive tracking capabilities of Weights & Biases.
Figure modified from https://lilianweng.github.io/
| Category | Model | Status | Paper Link |
|---|---|---|---|
| GANs | GAN | ✅ | Generative Adversarial Networks |
| | CGAN | ✅ | Conditional Generative Adversarial Nets |
| | InfoGAN | ✅ | InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets |
| | DCGAN | ✅ | Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks |
| | LSGAN | ✅ | Least Squares Generative Adversarial Networks |
| | WGAN | ✅ | Wasserstein GAN |
| | WGAN-GP | ✅ | Improved Training of Wasserstein GANs |
| | R1GAN | ✅ | Which Training Methods for GANs do actually Converge? |
| | CycleGAN | - | Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks |
| VAEs | VAE | ✅ | Auto-Encoding Variational Bayes |
| | VQVAE | ✅ | Neural Discrete Representation Learning |
| Autoregressive Models | PixelCNN | - | Conditional Image Generation with PixelCNN Decoders |
| Normalizing Flows | NICE | - | |
| | RealNVP | - | |
| | Glow | - | |
| Diffusion Models | DDPM | ✅ | Denoising Diffusion Probabilistic Models |
| | DDIM | ✅ | Denoising Diffusion Implicit Models |
Weights & Biases provides tools to track experiments, visualize model training and metrics, and optimize machine learning models. Below is an example of the experiment-tracking interface on the W&B platform.

Visit the W&B experiment page for more details.
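If you want to wire W&B logging into your own Lightning code, a minimal sketch looks like this (the project and run names below are assumptions; `train.py` may configure its logger differently):

```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# Log metrics, losses, and sample grids to a W&B project.
logger = WandbLogger(
    project="lightning-generative-models",  # assumed project name
    name="gan",                             # run name, e.g. the --experiment_name value
)

trainer = pl.Trainer(logger=logger, max_epochs=100)
# trainer.fit(model, datamodule)  # model and datamodule defined elsewhere
```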
Tested on both Apple Silicon (M1 Max) and NVIDIA GPUs on Ubuntu, with support for GPU acceleration and Distributed Data Parallel (DDP) training.
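As a rough sketch, the corresponding Lightning `Trainer` flags look like this (illustrative values only; the repo's configs may set these elsewhere):

```python
import pytorch_lightning as pl

# Apple Silicon: use the MPS backend.
trainer = pl.Trainer(accelerator="mps", devices=1)

# NVIDIA GPUs: enable DDP across multiple devices.
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp")
```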
```bash
# Clone the repository
git clone https://github.com/seungjunlee96/lightning-generative-models.git
cd lightning-generative-models

# Set up a conda environment
conda create -n lightning-generative-models python=3.11 -y
conda activate lightning-generative-models
pip install -r environments/requirements.txt

# For contributors
pre-commit install
```

Alternatively, set up and run with Docker:

```bash
cd environments
chmod +x ./install_and_run_docker.sh
./install_and_run_docker.sh
```
Easily train different generative models by passing a JSON config to `train.py`. Examples include the commands below; a sketch of what a config file might look like follows them.
```bash
# Train GAN
python train.py --config configs/gan/gan.json --experiment_name gan

# Train LSGAN
python train.py --config configs/gan/lsgan.json --experiment_name lsgan

# Train CGAN
python train.py --config configs/gan/cgan.json --experiment_name cgan

# Train InfoGAN
python train.py --config configs/gan/infogan.json --experiment_name infogan

# Train DCGAN
python train.py --config configs/gan/dcgan.json --experiment_name dcgan

# Train WGAN with weight clipping
python train.py --config configs/gan/wgan_cp.json --experiment_name wgan_cp

# Train WGAN with gradient penalty
python train.py --config configs/gan/wgan_gp.json --experiment_name wgan_gp

# Train GAN with R1 penalty
python train.py --config configs/gan/r1gan.json --experiment_name r1gan

# Train VAE
python train.py --config configs/vae/vae.json --experiment_name vae

# Train VQVAE
python train.py --config configs/vae/vqvae.json --experiment_name vqvae

# Train VQVAE with EMA (Exponential Moving Average)
python train.py --config configs/vae/vqvae_ema.json --experiment_name vqvae_ema

# Train DDPM
python train.py --config configs/diffusion/ddpm.json --experiment_name ddpm

# Train DDIM
python train.py --config configs/diffusion/ddim.json --experiment_name ddim

# ... and many more
```
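For reference, each config is a plain JSON document. The sketch below is purely hypothetical (field names and values are assumptions; see the actual files under `configs/` for the real schema):

```json
{
  "model": {"name": "gan", "latent_dim": 100},
  "data": {"dataset": "mnist", "batch_size": 128, "num_workers": 4},
  "trainer": {"max_epochs": 100, "accelerator": "gpu"}
}
```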
Assess the quality and performance of your generative models with:
- **Inception Score (IS)**: Measures the diversity and quality of images generated by a model. Higher scores indicate better image quality and variety, suggesting the model's effectiveness in producing diverse, high-fidelity images.

  $$\text{IS}(G) = \exp\left( \mathbb{E}_{\mathbf{x} \sim p_g} \left[ KL\left( p(y \mid \mathbf{x}) \,\|\, p(y) \right) \right] \right)$$

  - $G$: Generative model
  - $\mathbf{x}$: Data samples generated by $G$
  - $p_g$: Probability distribution of generated samples
  - $p(y \mid \mathbf{x})$: Conditional probability distribution of labels given sample $\mathbf{x}$
  - $p(y)$: Marginal distribution of labels over the dataset
  - $KL$: Kullback-Leibler divergence
- **Fréchet Inception Distance (FID)**: Evaluates the quality of generated images by comparing the feature distribution of generated images to that of real images. Lower scores indicate that the generated images are more similar to real images, implying higher quality.

  $$\text{FID} = \left\| \mu_x - \mu_g \right\|_2^2 + \mathrm{Tr}\left( \Sigma_x + \Sigma_g - 2 \left( \Sigma_x \Sigma_g \right)^{1/2} \right)$$

  - $\mu_x$, $\Sigma_x$: Mean and covariance of the feature vectors of real images
  - $\mu_g$, $\Sigma_g$: Mean and covariance of the feature vectors of generated images
- **Kernel Inception Distance (KID)**: Computes the squared Maximum Mean Discrepancy (MMD) between the feature representations of real and generated images. Lower KID scores suggest a higher similarity between the generated and real images, indicating better generative model performance.

  $$\text{KID} = \frac{1}{m(m-1)} \sum_{i \neq j} k(x_i, x_j) + \frac{1}{n(n-1)} \sum_{i \neq j} k(y_i, y_j) - \frac{2}{mn} \sum_{i=1}^{m} \sum_{j=1}^{n} k(x_i, y_j)$$

  - $k$: Kernel function measuring similarity between image features
  - $x_i$, $x_j$: Feature vectors of real images
  - $y_i$, $y_j$: Feature vectors of generated images
  - $m$, $n$: Number of real and generated images, respectively
- **Precision and Recall for Distributions (PRD)**: Offers a novel way to assess generative models by disentangling the evaluation of sample quality from the coverage of the target distribution. Unlike one-dimensional scores, PRD provides a two-dimensional evaluation that separately quantifies the precision and recall of a distribution, offering a more nuanced understanding of a model's performance.
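Of these, IS, FID, and KID can be computed with `torchmetrics` (an assumed dependency here; the repo's evaluation code may differ). A minimal sketch on batches of uint8 images:

```python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.kid import KernelInceptionDistance

# Placeholder uint8 image batches in [0, 255]; replace with real/generated samples.
real = torch.randint(0, 256, (64, 3, 299, 299), dtype=torch.uint8)
fake = torch.randint(0, 256, (64, 3, 299, 299), dtype=torch.uint8)

fid = FrechetInceptionDistance(feature=2048)
fid.update(real, real=True)
fid.update(fake, real=False)
print("FID:", fid.compute().item())

inception = InceptionScore()
inception.update(fake)
is_mean, is_std = inception.compute()
print(f"IS: {is_mean.item():.3f} ± {is_std.item():.3f}")

kid = KernelInceptionDistance(subset_size=32)  # subset_size must not exceed the sample count
kid.update(real, real=True)
kid.update(fake, real=False)
kid_mean, kid_std = kid.compute()
print(f"KID: {kid_mean.item():.5f} ± {kid_std.item():.5f}")
```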
All contributions are welcome! Open an issue for discussion or submit a pull request directly.
For queries or feedback, email me at lsjj096@gmail.com.
This repo is highly inspired by the following amazing works: