# wgan-cifar10

An unofficial PyTorch implementation of (improved) Wasserstein GAN (WGAN) training on CIFAR-10 image data.
## Requirements
- A modern NVIDIA GPU
- Docker
- Docker Compose
- nvidia-docker
- nvidia-docker-compose
## Usage

```shell
nvidia-docker-compose run --rm pytorch bin/train.py
```

Outputs from training are written to files in `out/`.
There are a number of command-line options which can be used to configure the training process:

- `--epochs N`: number of epochs to train (default: 1000)
- `--gen-iters N`: generator iterations per epoch (default: 100)
- `--disc-iters N`: discriminator iterations per generator iteration (default: 5)
- `--batch-size N`: input batch size (default: 64)
- `--disc-lr LR`: discriminator learning rate (default: 2e-4)
- `--gen-lr LR`: generator learning rate (default: 2e-4)
- `--unimproved`: disable the gradient penalty and use weight clipping instead
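Taken together, the iteration options imply a nested schedule: each epoch runs `--gen-iters` generator updates, and each generator update is preceded by `--disc-iters` critic updates (WGAN trains the critic several times per generator step). A minimal sketch of that loop, where the step functions are placeholders rather than this repo's actual code:

```python
def train(epochs, gen_iters, disc_iters, disc_step, gen_step):
    """Nested training schedule implied by --epochs/--gen-iters/--disc-iters.

    disc_step and gen_step are hypothetical callbacks standing in for one
    discriminator (critic) update and one generator update, respectively.
    """
    for _ in range(epochs):
        for _ in range(gen_iters):
            # WGAN updates the critic several times per generator update
            for _ in range(disc_iters):
                disc_step()
            gen_step()
```

With the defaults (`--epochs 1000 --gen-iters 100 --disc-iters 5`), this comes to 100,000 generator updates and 500,000 critic updates in total.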
## References

- *Wasserstein GAN*, Arjovsky et al. (arXiv:1701.07875)
- *Improved Training of Wasserstein GANs*, Gulrajani et al. (arXiv:1704.00028)
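The gradient penalty from Gulrajani et al., which the `--unimproved` flag swaps for the original paper's weight clipping, can be sketched in PyTorch roughly as follows. This is a generic illustration of the two techniques, not this repo's actual code; the function names and tensor shapes are assumptions.

```python
import torch

def gradient_penalty(critic, real, fake):
    """WGAN-GP penalty: pushes the critic's gradient norm toward 1 at
    points interpolated between real and fake samples (a sketch)."""
    batch_size = real.size(0)
    # One random interpolation coefficient per example, broadcast over C, H, W
    eps = torch.rand(batch_size, 1, 1, 1)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(interp)
    grads, = torch.autograd.grad(
        outputs=scores, inputs=interp,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # penalty itself must be differentiable
    )
    grad_norm = grads.reshape(batch_size, -1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()

def clip_weights(critic, c=0.01):
    """Original WGAN alternative (the --unimproved path): clamp every
    critic parameter to [-c, c] after each critic update."""
    for p in critic.parameters():
        p.data.clamp_(-c, c)
```

In the improved scheme the penalty (scaled by a coefficient, 10 in the paper) is added to the critic loss each critic step; in the unimproved scheme `clip_weights` is called after each critic update instead.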
## Related repositories

- `martinarjovsky/WassersteinGAN`: official repository for the original WGAN paper (PyTorch).
- `igul222/improved_wgan_training`: official repository for the improved WGAN training paper (TensorFlow).
- `caogang/wgan-gp`: unofficial partial port of the improved WGAN code (PyTorch).