catGAN
PyTorch implementation of "Unsupervised and Semi-supervised Learning with Categorical Generative Adversarial Networks", originally proposed by Jost Tobias Springenberg.
Results on CIFAR10
Note that only the unsupervised version is implemented in this repo for now. I replaced the original architecture with DCGAN, and the generated samples are more colorful than those of the original.
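For reference, below is a minimal sketch of the unsupervised CatGAN objectives as described in Springenberg's paper. The discriminator should be confident on real images, spread real images evenly across classes, and be uncertain on generated images. The generator should make the discriminator confident on its samples while covering all classes. The sketch works on plain lists of softmax probabilities rather than PyTorch tensors so that it runs without dependencies. The function names are hypothetical, estimating the marginal entropy per minibatch is an assumption, and this is not the code in catgan_cifar10.py.

```python
import math
from typing import Sequence, Tuple

Probs = Sequence[float]  # one softmax output p(y|x) over K classes


def entropy(p: Probs, eps: float = 1e-12) -> float:
    """Shannon entropy (in nats) of one categorical distribution."""
    return -sum(pi * math.log(pi + eps) for pi in p)


def mean_conditional_entropy(batch: Sequence[Probs]) -> float:
    """E_x[H[p(y|x)]]: average per-sample entropy over a minibatch."""
    return sum(entropy(p) for p in batch) / len(batch)


def marginal_entropy(batch: Sequence[Probs]) -> float:
    """H[p(y)]: entropy of the minibatch-averaged class distribution (assumed estimator)."""
    k = len(batch[0])
    marginal = [sum(p[i] for p in batch) / len(batch) for i in range(k)]
    return entropy(marginal)


def catgan_losses(real: Sequence[Probs], fake: Sequence[Probs]) -> Tuple[float, float]:
    # Discriminator loss (minimized): balanced marginal on real data,
    # confident predictions on real data, uncertain predictions on fake data.
    loss_d = (-marginal_entropy(real)
              + mean_conditional_entropy(real)
              - mean_conditional_entropy(fake))
    # Generator loss (minimized): balanced marginal on fake data,
    # confident discriminator predictions on fake data.
    loss_g = -marginal_entropy(fake) + mean_conditional_entropy(fake)
    return loss_d, loss_g
```

For example, with perfectly confident and balanced predictions on real data (`[[1.0, 0.0], [0.0, 1.0]]`) and uniform predictions on fake data (`[[0.5, 0.5], [0.5, 0.5]]`), the discriminator loss reaches its minimum of about -2 ln 2. At the same point, the generator's two terms cancel.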
Generated samples from epoch 0 to epoch 100:
Prerequisites
- Python 2.7
- PyTorch v0.2.0
- NumPy
- SciPy
- Matplotlib
Getting Started
Installation
- Install PyTorch and the other dependencies
- Clone this repo:
git clone https://github.com/xinario/catgan_pytorch.git
cd catgan_pytorch
Train
- Download the CIFAR-10 dataset in .png format (from Kaggle)
- Create a dataset folder to hold the images
mkdir -p ./datasets/cifar10/images
- Move the extracted images into the newly created folder
- Train a model:
python catgan_cifar10.py --data_dir ./datasets/cifar10 --name cifar10
All generated plots and samples are saved inside ./results/cifar10
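The "move the extracted images" step can also be scripted. Below is a minimal standard-library sketch. It assumes the Kaggle archive has already been extracted to a folder of .png files, and the source path you pass in is hypothetical. The sketch moves those files into ./datasets/cifar10/images, the folder created above. The extra images/ level is presumably there because the training script loads data with torchvision's ImageFolder, which expects one subfolder per class under --data_dir. That is an assumption about the script, not something this README states.

```python
import shutil
from pathlib import Path


def stage_images(src_dir: str, data_dir: str = "./datasets/cifar10") -> int:
    """Move every .png in src_dir into <data_dir>/images and return how many were moved.

    src_dir is a hypothetical path to wherever the Kaggle archive was extracted.
    """
    dest = Path(data_dir) / "images"
    dest.mkdir(parents=True, exist_ok=True)  # equivalent to `mkdir -p`
    moved = 0
    for png in sorted(Path(src_dir).glob("*.png")):
        shutil.move(str(png), str(dest / png.name))
        moved += 1
    return moved
```

Usage, assuming the images were extracted to a hypothetical ./train folder: `stage_images("./train")`. Then run the training command with `--data_dir ./datasets/cifar10` as shown above.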
Training options
optional arguments:
--continue_train to continue training from the latest checkpoints if --netG and --netD are not specified
--netG NETG path to netG (to continue training)
--netD NETD path to netD (to continue training)
--workers WORKERS number of data loading workers
--num_epochs EPOCHS number of epochs to train for
More options can be found inside the training script.
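To illustrate how `--continue_train` interacts with `--netG` and `--netD`, below is a minimal argparse sketch. The flag names come from the list above. Every default value, the helper names, and the `netG_*.pth` / `netD_*.pth` checkpoint naming are hypothetical placeholders, not taken from catgan_cifar10.py. An explicitly given path always wins. Otherwise, `--continue_train` picks the most recently modified matching checkpoint.

```python
import argparse
from pathlib import Path
from typing import Optional, Tuple


def build_parser() -> argparse.ArgumentParser:
    # Flag names follow the README; all defaults are placeholder assumptions.
    p = argparse.ArgumentParser(description="catGAN training options (sketch)")
    p.add_argument("--data_dir", default="./datasets/cifar10")
    p.add_argument("--name", default="cifar10")
    p.add_argument("--continue_train", action="store_true",
                   help="resume from the latest checkpoints if --netG/--netD are not given")
    p.add_argument("--netG", default="", help="path to netG (to continue training)")
    p.add_argument("--netD", default="", help="path to netD (to continue training)")
    p.add_argument("--workers", type=int, default=2, help="number of data loading workers")
    p.add_argument("--num_epochs", type=int, default=100, metavar="EPOCHS",
                   help="number of epochs to train for")
    return p


def _latest(ckpt_dir: Path, pattern: str) -> Optional[Path]:
    # Newest file (by modification time) matching a hypothetical naming pattern.
    files = sorted(ckpt_dir.glob(pattern), key=lambda f: f.stat().st_mtime)
    return files[-1] if files else None


def resolve_checkpoints(args: argparse.Namespace,
                        ckpt_dir: Path) -> Tuple[Optional[Path], Optional[Path]]:
    """Explicit --netG/--netD take priority; otherwise --continue_train picks the newest."""
    net_g = Path(args.netG) if args.netG else None
    net_d = Path(args.netD) if args.netD else None
    if args.continue_train:
        net_g = net_g or _latest(ckpt_dir, "netG_*.pth")
        net_d = net_d or _latest(ckpt_dir, "netD_*.pth")
    return net_g, net_d
```

Resolving each network independently means that giving only `--netG` together with `--continue_train` still falls back to the newest netD checkpoint. This is one reasonable reading of the option description above.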
Acknowledgments
Some of the code is inspired by or borrowed from the wgan-gp, DCGAN, and catGAN (Chainer) repositories.