A simple implementation of WWGAN and the Wasserstein image-space gradient penalty (Wasserstein ground metric).
Based on Marvin Cao's wgan-gp and Nathan Inkawhich's example.
The Wasserstein of Wasserstein loss for generative models uses an optimal transport metric as the distance measure between images; this loss is then cast in the WGAN-GP formulation.
Details can be found in our paper Wasserstein of Wasserstein Loss for Learning Generative Models.
Clone the repository using
$ git clone https://github.com/dukleryoni/WWGAN.git
The following packages are required to run the repo: PyTorch, torchvision, SciPy, Pillow, TensorFlow, and TensorBoard.
For your convenience, you can create a suitable conda environment from wwgan_env.yml
by running
$ conda env create --name name_of_wwgan_env --file wwgan_env.yml
Download and extract the CelebA image dataset from Google Drive. See gdown for downloading from Google Drive on the command line.
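As a sketch, the download could look like the following. FILE_ID is a placeholder for the Google Drive id of the CelebA archive, not the real id; the archive and directory names are assumptions.

```shell
pip install gdown
# FILE_ID is a placeholder for the Google Drive id of the CelebA archive
gdown "https://drive.google.com/uc?id=FILE_ID" -O img_align_celeba.zip
unzip img_align_celeba.zip -d celeba/
```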
In train.py, line 34, change dataroot to the correct path to the downloaded CelebA directory.
For evaluating the network using the Fréchet Inception Distance (FID) score, we use the TTUR repository.
In the WWGAN repository, we clone TTUR and download the pre-computed FID statistics for CelebA:
$ mkdir Frechet_Inception_Distance
$ cd Frechet_Inception_Distance
$ git clone https://github.com/bioinf-jku/TTUR.git
$ wget http://bioinf.jku.at/research/ttur/ttur_stats/fid_stats_celeba.npz # get pre-computed stats for FID
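With the statistics in place, FID can then be computed against a folder of generated samples using TTUR's fid.py; the exact flags may vary by TTUR version, so treat this invocation as a sketch and check `python TTUR/fid.py --help`. The path to the generated samples is a placeholder.

```shell
# Compare a folder of generated images (placeholder path) against the
# pre-computed CelebA statistics downloaded above.
python TTUR/fid.py path/to/generated_samples fid_stats_celeba.npz
```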
Now, in the WWGAN repo, train the model by simply running python train.py
The user can specify hyperparameters for different runs (e.g. --ngpu 2
to set the number of GPUs used for training).
After training, the user can inspect simple properties of the generator using the Jupyter notebook Analyze_generated_images.ipynb.
The code for computing the Wasserstein gradient penalty is given in wwgan_utils.py
and contains most of the new implementation details. One can compute the Wasserstein ground metric penalty by calling the calc_wgp()
function.
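For reference, here is a minimal sketch of the standard WGAN-GP gradient penalty (Euclidean ground metric) in PyTorch; calc_wgp() in wwgan_utils.py replaces this Euclidean gradient norm with the Wasserstein ground metric described in the paper. The discriminator and tensor shapes below are illustrative assumptions.

```python
import torch

def gradient_penalty(D, real, fake):
    # Standard WGAN-GP penalty (Euclidean ground metric): penalize the
    # discriminator's gradient norm at points interpolated between real
    # and fake samples. WWGAN's calc_wgp() swaps the Euclidean norm for
    # a Wasserstein ground metric on image space.
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    d_out = D(interp)
    grads = torch.autograd.grad(d_out.sum(), interp, create_graph=True)[0]
    grad_norm = grads.flatten(1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()

# Tiny illustrative discriminator and batch (shapes are assumptions).
D = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 4 * 4, 1))
real = torch.randn(2, 3, 4, 4)
fake = torch.randn(2, 3, 4, 4)
gp = gradient_penalty(D, real, fake)
print(float(gp))  # a non-negative scalar penalty
```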