pytorch-gan-collections

PyTorch implementation of DCGAN, WGAN-GP and SNGAN.


Collections of GANs

PyTorch implementation of unsupervised GANs.

More details about calculating the Inception Score and FID in PyTorch can be found in pytorch-inception-score-fid.
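For readers who want the formula without leaving this page, here is a minimal NumPy sketch of the Inception Score, IS = exp(E_x[KL(p(y|x) || p(y))]). It assumes the class probabilities p(y|x) have already been produced by an Inception network (that step is omitted), and it follows the common convention of splitting the samples into groups and reporting the mean and standard deviation across splits. The values in parentheses in the results table are presumably such a standard deviation, although this README does not say so. The function name inception_score is illustrative and is not this repository's API.

```python
import numpy as np


def inception_score(probs, splits=10, eps=1e-12):
    """Inception Score from an (N, C) array of softmax outputs p(y|x).

    Returns (mean, std) of the score across `splits` groups of samples.
    Hypothetical helper: the Inception forward pass is assumed to be done already.
    """
    probs = np.asarray(probs, dtype=np.float64)
    scores = []
    for part in np.array_split(probs, splits):
        # Marginal class distribution p(y), estimated within this split.
        p_y = part.mean(axis=0, keepdims=True)
        # Per-sample KL(p(y|x) || p(y)); eps avoids log(0) for zero probabilities.
        kl = part * (np.log(part + eps) - np.log(p_y + eps))
        # IS for this split = exp of the mean KL over its samples.
        scores.append(np.exp(kl.sum(axis=1).mean()))
    return float(np.mean(scores)), float(np.std(scores))
```

A classifier that is equally unsure about every image scores 1, while confident predictions spread evenly over C classes approach C, which is why higher is better.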

Models

  • DCGAN
  • WGAN
  • WGAN-GP
  • SN-GAN

Requirements

  • Initialize the metrics submodule
    git submodule update --init
    
  • Install python packages
    pip install -U pip setuptools
    pip install -r requirements.txt

Results

Model             Dataset   Inception Score   FID
DCGAN             CIFAR10   6.04(0.05)        47.90
WGAN(CNN)         CIFAR10   6.64(0.6)         33.27
WGAN-GP(CNN)      CIFAR10   7.47(0.06)        24.00
WGAN-GP(ResNet)   CIFAR10   7.74(0.10)        21.89
SNGAN(CNN)        CIFAR10   7.44(0.11)        24.94
SNGAN(ResNet)     CIFAR10   8.22(0.13)        16.24
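The FID column is the Fréchet distance between Gaussians fitted to Inception features of generated and real images, so lower is better (whereas a higher Inception Score is better). Written out, FID = ||μ1 − μ2||² + Tr(Σ1 + Σ2 − 2(Σ1Σ2)^(1/2)). Below is a minimal NumPy sketch of that formula, not the code this repository uses: it computes the trace of the matrix square root from the eigenvalues of Σ1Σ2 instead of the scipy.linalg.sqrtm call found in common implementations. The load_reference_stats helper is hypothetical and assumes the npz files in the stats folder store the mean and covariance under the keys mu and sigma, which this README does not confirm.

```python
import numpy as np


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """FID between N(mu1, sigma1) and N(mu2, sigma2)."""
    mu1 = np.atleast_1d(mu1).astype(np.float64)
    mu2 = np.atleast_1d(mu2).astype(np.float64)
    sigma1 = np.atleast_2d(sigma1).astype(np.float64)
    sigma2 = np.atleast_2d(sigma2).astype(np.float64)
    diff = mu1 - mu2
    # For PSD sigma1, sigma2 the eigenvalues of sigma1 @ sigma2 are real and
    # non-negative, so Tr((sigma1 sigma2)^(1/2)) is the sum of their square roots.
    # Round-off can add tiny imaginary/negative parts, hence .real and clip.
    eig = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eig.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


def load_reference_stats(path):
    # Hypothetical helper; ASSUMES the npz keys are "mu" and "sigma".
    with np.load(path) as f:
        return f["mu"], f["sigma"]
```

With statistics of generated images in hand, a call such as frechet_distance(mu_gen, sigma_gen, *load_reference_stats("stats/cifar10.test.npz")) would give a score comparable in spirit to the table, provided the key assumption holds.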

Examples

  • DCGAN

    [images: dcgan_gif, dcgan_png]

  • WGAN(CNN)

    [images: wgan_gif, wgan_png]

  • WGAN-GP(CNN)

    [images: wgangp_cnn_gif, wgangp_cnn_png]

  • WGAN-GP(ResNet)

    [images: wgangp_res_gif, wgangp_res_png]

  • SNGAN(CNN)

    [images: sngan_cnn_gif, sngan_cnn_png]

  • SNGAN(ResNet)

    [images: sngan_res_gif, sngan_res_png]

Reproduce

  • Download cifar10.test.npz, which is needed to calculate the FID score, and put the npz files in a folder named stats:

    stats
    ├── cifar10.test.npz
    ├── cifar10.train.npz
    └── stl10.unlabeled.48.npz
    
  • Train from scratch

    # DCGAN
    python dcgan.py --flagfile ./config/DCGAN_CIFAR10.txt
    # WGAN(CNN)
    python wgan.py --flagfile ./config/WGAN_CIFAR10_CNN.txt
    # WGAN-GP(CNN)
    python wgangp.py --flagfile ./config/WGANGP_CIFAR10_CNN.txt
    # WGAN-GP(ResNet)
    python wgangp.py --flagfile ./config/WGANGP_CIFAR10_RES.txt
    # SNGAN(CNN)
    python sngan.py --flagfile ./config/SNGAN_CIFAR10_CNN.txt
    # SNGAN(ResNet)
    python sngan.py --flagfile ./config/SNGAN_CIFAR10_RES.txt

    Although the training procedures of the different GANs are almost identical, I keep each method in a separate file so that the code is easier to read.
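A quick check of the stats folder before launching a long training run can catch a missing or misnamed npz file early. The sketch below is a hypothetical helper, not part of the repository; it only compares the folder against the file names in the layout shown above and assumes stats sits in the current working directory.

```python
from pathlib import Path

# File names copied from the stats/ layout in this README.
EXPECTED_STATS = ("cifar10.test.npz", "cifar10.train.npz", "stl10.unlabeled.48.npz")


def missing_stats(root="./stats", expected=EXPECTED_STATS):
    """Return the expected npz files that are not present under `root`."""
    root = Path(root)
    return [name for name in expected if not (root / name).is_file()]
```

An empty list from missing_stats() means all three files are in place; pass expected=("cifar10.test.npz",) when only the CIFAR10 test statistics are required.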

Learning curve

[figures: inception_score_curve, fid_curve]

Change Log

  • 2021-04-16
    • Update PyTorch to 1.8.1.
    • Move metrics to a submodule.
    • Evaluate FID on the CIFAR10 test set instead of the training set.
    • Fix cifar10.test.npz download link and sample images.