
PyTorch WGAN-GP

This is a PyTorch implementation of Improved Training of Wasserstein GANs (WGAN-GP). Most of the code was inspired by this repository by EmilienDupont.

Training

To train on the MNIST dataset, run

python main.py --dataset mnist --epochs 200

For the FashionMNIST dataset, run

python main.py --dataset fashion --epochs 200

You can also set up a generator and discriminator pair and use the WGANGP class:

wgan = WGANGP(generator, discriminator,
              g_optimizer, d_optimizer,
              latent_shape, dataset_name)
wgan.train(data_loader, n_epochs)

The argument latent_shape is the shape of whatever the generator's forward function accepts as input.
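
For illustration, here is a minimal end-to-end sketch for MNIST. The network architectures, optimizer settings, and the WGANGP import path are assumptions made for the example, not the ones used in this repository:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# WGANGP is the class from this repository; the exact import path is an assumption.
from wgan_gp import WGANGP

latent_dim = 100
latent_shape = (latent_dim,)  # per-sample shape the generator's forward expects

# Placeholder fully connected networks, for illustration only.
generator = nn.Sequential(
    nn.Linear(latent_dim, 256), nn.ReLU(),
    nn.Linear(256, 28 * 28), nn.Tanh(),
    nn.Unflatten(1, (1, 28, 28)),  # reshape to image format
)
discriminator = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1),
)

# Illustrative optimizer hyperparameters.
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.9))
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.9))

data_loader = DataLoader(
    datasets.MNIST("data", train=True, download=True, transform=transforms.ToTensor()),
    batch_size=64, shuffle=True,
)

wgan = WGANGP(generator, discriminator,
              g_optimizer, d_optimizer,
              latent_shape, "mnist")
wgan.train(data_loader, 200)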

The training process is monitored by tensorboardX.
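
If the logs are written to tensorboardX's default runs directory (an assumption; check the code for the actual log path), the training curves can be viewed with:

tensorboard --logdir runs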

Results

Here is the training history for both datasets:

MNIST losses fashion losses

Two GIFs of the training process:

MNIST training gif fashion training gif

Interpolation in latent space

We can generate samples that transition smoothly from one class to another by interpolating points in the latent space (done in this notebook):

MNIST interpolation fashion interpolation
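
A minimal sketch of this kind of interpolation, reusing the trained generator and latent_shape from the example above (plain linear interpolation here; the notebook may do it differently):

import torch

# Two random endpoints in latent space.
z_start = torch.randn(1, *latent_shape)
z_end = torch.randn(1, *latent_shape)

with torch.no_grad():
    for alpha in torch.linspace(0.0, 1.0, steps=10):
        z = (1 - alpha) * z_start + alpha * z_end  # point on the line between the endpoints
        image = generator(z)                       # decode it to an image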

The weights of the trained models are in the saved_models folder.
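
For example, a generator checkpoint saved as a state dict could be restored like this (the file name is hypothetical; check saved_models for the actual names):

import torch

# Hypothetical file name; see saved_models for the actual checkpoints.
state_dict = torch.load("saved_models/mnist_generator.pt", map_location="cpu")
generator.load_state_dict(state_dict)
generator.eval()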