
PyTorch implementation of "Wasserstein-2 Generative Networks" (ICLR 2021)


Wasserstein-2 Generative Networks

This is the official Python implementation of the ICLR 2021 paper Wasserstein-2 Generative Networks (paper on OpenReview) by Alexander Korotin, Vage Egiazarian, Arip Asadulaev, Alexander Safin and Evgeny Burnaev.

The repository contains reproducible PyTorch source code for computing optimal transport maps (and distances) in high dimensions via the end-to-end non-minimax method (proposed in the paper) by using input convex neural networks. Examples are provided for various real-world problems: color transfer, latent space mass transport, domain adaptation, style transfer.


Citation

@inproceedings{
  korotin2021wasserstein,
  title={Wasserstein-2 Generative Networks},
  author={Alexander Korotin and Vage Egiazarian and Arip Asadulaev and Alexander Safin and Evgeny Burnaev},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=bEoxzW_EXsa}
}

Prerequisites

The implementation is GPU-based. A single GPU (e.g. a GTX 1080 Ti) is enough to run each experiment. Tested with

torch==1.3.0 torchvision==0.4.1

The code might not run as intended in newer torch versions. Newer torchvision might conflict with FID score evaluation.
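Since the notebooks are only reported as tested with the versions above, it can help to compare the local environment against them before running anything. The following is a minimal sketch using only the standard library; the helper names (`release_tuple`, `compare_with_tested`, `installed_versions`) are hypothetical and not part of this repository.

```python
from importlib import metadata

# Versions this README reports as tested.
TESTED = {"torch": "1.3.0", "torchvision": "0.4.1"}


def release_tuple(version):
    """Turn '1.3.0' or '1.3.0+cu100' into (1, 3, 0), ignoring local build tags."""
    core = version.split("+", 1)[0]
    numbers = []
    for piece in core.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        numbers.append(int(digits) if digits else 0)
    return tuple(numbers)


def compare_with_tested(installed):
    """Return {package: message} for packages that differ from the tested versions."""
    report = {}
    for name, tested in TESTED.items():
        found = installed.get(name)
        if found is None:
            report[name] = "not installed (tested with %s)" % tested
        elif release_tuple(found) != release_tuple(tested):
            report[name] = "found %s, tested with %s" % (found, tested)
    return report


def installed_versions():
    """Look up the installed version of every package listed in TESTED."""
    found = {}
    for name in TESTED:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, message in compare_with_tested(installed_versions()).items():
        print("warning: %s: %s" % (name, message))
```

A mismatch does not necessarily mean the code will fail; it only flags that the environment differs from the tested one.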


Repository structure

All experiments are provided as self-explanatory Jupyter notebooks (notebooks/). For convenience, most of the evaluation output is preserved. Auxiliary source code is moved to .py modules (src/).

Experiments

  • notebooks/W2GN_toy_experiments.ipynb -- toy experiments (2D: Swiss Roll, 100 Gaussians, ...);
  • notebooks/W2GN_gaussians_high_dimensions.ipynb -- optimal maps between Gaussians in high dimensions;
  • notebooks/W2GN_latent_space_optimal_transport.ipynb -- latent space optimal transport for generating CelebA 64x64 aligned images;
  • notebooks/W2GN_domain_adaptation.ipynb -- domain adaptation for MNIST-USPS digits datasets;
  • notebooks/W2GN_color_transfer.ipynb -- cycle monotone pixel-wise image-to-image color transfer (example images are provided in data/color_transfer/);
  • notebooks/W2GN_style_transfer.ipynb -- cycle monotone image dataset-to-dataset style transfer (the datasets used are publicly available at the official CycleGAN repo).

Input convex neural networks

  • src/icnn.py -- modules for Input Convex Neural Network architectures (DenseICNN, ConvICNN);
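For readers new to input convex neural networks, the sketch below illustrates the general construction: weights applied to previous hidden layers are kept non-negative and the activation is convex and non-decreasing, which makes the output convex in the input. It is an assumption-laden toy written in NumPy to stay dependency-light; `TinyICNN` is a hypothetical class and does not reproduce the DenseICNN or ConvICNN modules in src/icnn.py.

```python
import numpy as np


def softplus(z):
    # Convex and non-decreasing, so it preserves convexity of its argument.
    return np.logaddexp(0.0, z)


class TinyICNN:
    """Hypothetical two-hidden-layer ICNN; not the repository's DenseICNN."""

    def __init__(self, dim, hidden=16, seed=0):
        rng = np.random.default_rng(seed)
        # Weights acting directly on the input x may have any sign.
        self.A0 = rng.normal(size=(hidden, dim))
        self.A1 = rng.normal(size=(hidden, dim))
        self.A2 = rng.normal(size=(1, dim))
        self.b0 = rng.normal(size=hidden)
        self.b1 = rng.normal(size=hidden)
        # Weights acting on previous hidden layers must be non-negative.
        self.W1 = np.abs(rng.normal(size=(hidden, hidden)))
        self.W2 = np.abs(rng.normal(size=(1, hidden)))

    def __call__(self, x):
        # x has shape (batch, dim); the output has shape (batch,).
        z1 = softplus(x @ self.A0.T + self.b0)
        z2 = softplus(z1 @ self.W1.T + x @ self.A1.T + self.b1)
        out = z2 @ self.W2.T + x @ self.A2.T
        return out[:, 0]


# Numerical sanity check of midpoint convexity on random pairs of points.
potential = TinyICNN(dim=4)
rng = np.random.default_rng(1)
x = rng.normal(size=(256, 4))
y = rng.normal(size=(256, 4))
midpoint_ok = np.all(potential(0.5 * (x + y)) <= 0.5 * (potential(x) + potential(y)) + 1e-8)
```

In a trained model the non-negativity would have to be maintained during optimization (for example by clipping or reparameterizing the weights); how the repository does this is not shown here.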

Poster

  • poster/W2GN_poster.png -- poster (landscape format)
  • poster/W2GN_poster.svg -- source file for the poster

Results

Toy Experiments

Transforming a single Gaussian into a mixture of 100 Gaussians without mode dropping/collapse (and some other toy cases).

Optimal Transport Maps between High Dimensional Gaussians

Assessing the quality of fitted optimal transport maps between two high-dimensional Gaussians (tested in dimensions up to 4096). The metric is Unexplained Variance Percentage (UVP, %).

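| Method \ Dim | 2 | 4 | 8 | 16 | 32 | 64 | 128 | 256 | 512 | 1024 | 2048 | 4096 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Large-scale OT | <1 | 3.7 | 7.5 | 14.3 | 23 | 34.7 | 46.9 | >50 | >50 | >50 | >50 | >50 |
| Wasserstein-2 GN | <1 | <1 | <1 | <1 | <1 | <1 | 1 | 1.1 | 1.3 | 1.7 | 1.8 | 1.5 |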
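The Gaussian case is a convenient benchmark because the optimal map for the quadratic cost is known in closed form: T(x) = mu_Q + A (x - mu_P) with A = S_P^{-1/2} (S_P^{1/2} S_Q S_P^{1/2})^{1/2} S_P^{-1/2}. The sketch below computes this map and an unexplained-variance style score, assuming the metric is 100 * E||T_fitted(x) - T_optimal(x)||^2 / Var(Q) with Var(Q) taken as the trace of the target covariance; see the paper for the exact definition. All function names are hypothetical and not taken from the notebooks.

```python
import numpy as np


def sqrtm_psd(mat):
    """Symmetric square root of a symmetric positive semi-definite matrix."""
    vals, vecs = np.linalg.eigh(mat)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T


def gaussian_ot_matrix(cov_p, cov_q):
    """Linear part A of the optimal map T(x) = mu_q + A (x - mu_p)."""
    root_p = sqrtm_psd(cov_p)
    inv_root_p = np.linalg.inv(root_p)
    return inv_root_p @ sqrtm_psd(root_p @ cov_q @ root_p) @ inv_root_p


def l2_uvp(fitted, optimal, samples_p, cov_q):
    """Assumed form: 100 * E||fitted(x) - optimal(x)||^2 / trace(cov_q)."""
    diff = fitted(samples_p) - optimal(samples_p)
    return 100.0 * np.mean(np.sum(diff ** 2, axis=1)) / np.trace(cov_q)


def random_cov(dim, rng):
    """Random well-conditioned covariance matrix (illustrative data only)."""
    m = rng.normal(size=(dim, dim))
    return m @ m.T / dim + 0.1 * np.eye(dim)


rng = np.random.default_rng(0)
dim = 8
cov_p, cov_q = random_cov(dim, rng), random_cov(dim, rng)
mu_p, mu_q = np.zeros(dim), np.ones(dim)
A = gaussian_ot_matrix(cov_p, cov_q)


def optimal_map(x):
    return mu_q + (x - mu_p) @ A.T


def identity_map(x):
    # A deliberately poor "fitted" map used as a baseline.
    return x


samples = rng.multivariate_normal(mu_p, cov_p, size=10000)
uvp_optimal = l2_uvp(optimal_map, optimal_map, samples, cov_q)
uvp_identity = l2_uvp(identity_map, optimal_map, samples, cov_q)
```

Here `uvp_optimal` is zero by construction, while `uvp_identity` is strictly positive; a fitted map would be scored by passing it in place of `identity_map`.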

Latent Space Optimal Transport

CelebA 64x64 generated faces. The quality of the model highly depends on the quality of the autoencoder. Use notebooks/AE_Celeba.ipynb to train MSE or perceptual AE (on VGG features, to improve AE visual quality).
Pre-trained autoencoders: MSE-AE [Google Drive, Yandex Disk], VGG-AE [Google Drive, Yandex Disk].

Combining a simple pre-trained MSE autoencoder with W2GN is enough to surpass the Wasserstein GAN model in Fréchet Inception Distance (FID).

| | AE Reconstruct | AE Raw Decode | AE + W2GN | WGAN |
|---|---|---|---|---|
| FID Score | 23.35 | 86.66 | 43.35 | 45.23 |

A perceptual VGG autoencoder combined with W2GN provides nearly state-of-the-art FID (compared to Wasserstein GAN with Quadratic Cost).

| | AE Reconstruct | AE Raw Decode | AE + W2GN | WGAN-QC |
|---|---|---|---|---|
| FID Score | 7.5 | 31.81 | 17.21 | 14.41 |
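For context on the numbers above: FID fits a Gaussian to feature vectors of real and generated images and reports the Fréchet distance between the two Gaussians, ||mu_1 - mu_2||^2 + Tr(S_1 + S_2 - 2 (S_1 S_2)^{1/2}). The sketch below implements only this final formula in NumPy and assumes the feature arrays are already available; extracting Inception features, and the exact evaluation code used in the notebooks, are outside its scope. Function names are hypothetical.

```python
import numpy as np


def sqrtm_psd(mat):
    """Symmetric square root of a symmetric positive semi-definite matrix."""
    vals, vecs = np.linalg.eigh(mat)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T


def frechet_distance(mu1, cov1, mu2, cov2):
    """Frechet distance between N(mu1, cov1) and N(mu2, cov2)."""
    root1 = sqrtm_psd(cov1)
    # Tr((cov1 cov2)^{1/2}) equals Tr((root1 cov2 root1)^{1/2}); the second
    # form only needs the square root of a symmetric matrix.
    cross = sqrtm_psd(root1 @ cov2 @ root1)
    mean_term = np.sum((mu1 - mu2) ** 2)
    return float(mean_term + np.trace(cov1) + np.trace(cov2) - 2.0 * np.trace(cross))


def fid_from_features(feats_a, feats_b):
    """Fit Gaussians to two (n_samples, n_features) arrays and compare them."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    return frechet_distance(mu_a, cov_a, mu_b, cov_b)
```

As a usage example, `fid_from_features(real_feats, fake_feats)` would take two arrays of precomputed features (both names are placeholders). Lower values mean the two feature distributions are closer.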

Image-to-Image Color Transfer

Cycle monotone color transfer is applicable even to gigapixel images!
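Because the transfer is pixel-wise, a learned map on colors can in principle be applied to an image one block of pixels at a time, which keeps memory use bounded regardless of image size. The following sketch shows that pattern in NumPy under the assumption that `color_map` is any function taking an (N, 3) array of colors and returning an (N, 3) array; `apply_pixelwise_map` and the affine placeholder map are hypothetical and are not the notebook's code.

```python
import numpy as np


def apply_pixelwise_map(image, color_map, chunk_size=1_000_000):
    """Apply a map on colors to every pixel, one chunk of pixels at a time."""
    height, width, channels = image.shape
    pixels = image.reshape(-1, channels)
    out = np.empty(pixels.shape, dtype=np.float64)
    for start in range(0, pixels.shape[0], chunk_size):
        stop = start + chunk_size
        # Only one chunk of pixels is passed through the map at a time.
        out[start:stop] = color_map(pixels[start:stop].astype(np.float64))
    return out.reshape(height, width, channels)


def placeholder_map(colors):
    # Stand-in for a trained transport map; a simple affine recoloring.
    return 0.5 * colors + 0.1


rng = np.random.default_rng(0)
image = rng.random((7, 5, 3))
recolored = apply_pixelwise_map(image, placeholder_map, chunk_size=4)
```

The chunk size only trades memory for the number of calls; the result does not depend on it because each pixel is mapped independently.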

Domain Adaptation

MNIST-USPS domain adaptation. PCA visualization of feature spaces (see the paper for metrics).

Unpaired Image-to-Image Style Transfer

Optimal transport map in the space of images. Photo2Cezanne and Winter2Summer datasets are used.
