This is the official Python implementation of the NeurIPS 2022 (Datasets and Benchmarks track) paper "Kantorovich Strikes Back! Wasserstein GANs are not Optimal Transport?" (paper on OpenReview) by Alexander Korotin, Alexander Kolesov and Evgeny Burnaev.
The repository contains a set of continuous benchmark distributions for testing optimal transport (OT) solvers with the distance cost (the Wasserstein-1 distance, W1), along with the code for a dozen WGAN dual OT solvers and their evaluation.
- Lightning talk by Alexander Kolesov at NeurIPS 2022 (December 2022, EN);
- Talk by Alexander Korotin at AIRI workshop (15 December 2022, RU).
The implementation is GPU-based. A single GTX 1080 Ti GPU is enough to run each experiment. We tested the code with torch==1.10.2. The code might not run as intended with older or newer torch versions.
- Repository for the "Do Neural Optimal Transport Solvers Work? A Continuous Wasserstein-2 Benchmark" paper.
- notebooks/ - jupyter notebooks with a preview of the benchmark pairs and the evaluation of OT solvers;
- src/ - auxiliary source code for the OT solvers and the benchmark pairs;
- metrics/ - results of the evaluation (cosine, L2, W1);
- benchmarks/ - .pt checkpoints for the continuous benchmark pairs;
- checkpoints/ - .pt fitted dual potentials of OT solvers for the 2-dimensional experiment with 4 funnels.
Our benchmark pairs of distributions are designed to test how well OT solvers recover the OT cost (W1) and the OT gradient. Use the following code to load the pairs and sample from them (assuming you work in notebooks/):
import sys
sys.path.append("..")
from src.map_benchmark import MixToOneBenchmark, Celeba64Benchmark, Cifar32Benchmark
# Load the high-dimensional benchmark for dimension 16 (available: 2, 4, ..., 128)
# and 64 funnels (available: 4, 16, 64, 256)
benchmark = MixToOneBenchmark(16, 64)
# OR load the CelebA images benchmark pair
# with 16 funnels (available: 1, 16) and degree 10 (available: 10, 100)
# benchmark = Celeba64Benchmark(16, 10)
# OR load the CIFAR-10 images benchmark pair
# with 16 funnels (available: 1, 16) and degree 10 (available: 10, 100)
# benchmark = Cifar32Benchmark(16, 10)
# Sample 32 random points from the benchmark distributions (unpaired, for training)
X = benchmark.input_sampler.sample(32)
Y = benchmark.output_sampler.sample(32)
# Sample 32 random points from the OT plan (paired, for testing only)
X, Y = benchmark.input_sampler.sample(32, flag_plan=True)
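Paired samples from the OT plan give a direct Monte Carlo estimate of the ground-truth W1, since the OT cost is the expected distance under the optimal plan. The sketch below is a minimal illustration and not the repository's evaluation code: it replaces the benchmark samplers with a synthetic pair (a Gaussian and its translate, for which the translation is an optimal plan), and the helper names `w1_from_plan_samples` and `w1_from_potential` are hypothetical. The sign convention of the dual estimate is also an assumption and may differ from the one used in src/.

```python
import torch


def w1_from_plan_samples(x, y):
    """Mean Euclidean distance between paired samples (x_i, y_i)."""
    return (x - y).flatten(start_dim=1).norm(dim=1).mean()


def w1_from_potential(f, x, y):
    """Dual estimate E f(X) - E f(Y) for a 1-Lipschitz potential f."""
    return f(x).mean() - f(y).mean()


torch.manual_seed(0)

# Synthetic stand-in for a benchmark pair: Y is X shifted by a fixed vector.
# For a pure translation the plan (X, X + shift) is optimal and W1 = ||shift||.
shift = torch.tensor([3.0, 4.0])
x = torch.randn(1024, 2)
y = x + shift

# A 1-Lipschitz linear potential that attains the dual optimum for this pair.
unit = shift / shift.norm()
f = lambda z: -(z @ unit)

primal = w1_from_plan_samples(x, y).item()
dual = w1_from_potential(f, x, y).item()
print(primal, dual)  # both should be close to ||shift|| = 5
```

With the real benchmark, `x` and `y` would come from the paired sampling call above, and `f` would be the potential fitted by a solver.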
Examples of testing OT solvers are in the notebooks/ folder, see below.
We provide all the code to evaluate existing dual WGAN OT solvers on our benchmark pairs. The qualitative results are shown below. For quantitative results, see the paper or the metrics/ folder.
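As a rough illustration of gradient-based metrics such as cosine and L2, the sketch below compares the gradient of a fitted potential with a reference gradient at the same points. It assumes that these metrics compare potential gradients; the exact definitions and normalizations are given in the paper and may differ. The helper names `potential_gradient` and `gradient_metrics` and the toy potentials are hypothetical.

```python
import torch
import torch.nn.functional as F


def potential_gradient(f, x):
    """Gradient of a scalar-valued potential f with respect to each input point."""
    x = x.clone().requires_grad_(True)
    (grad,) = torch.autograd.grad(f(x).sum(), x)
    return grad


def gradient_metrics(grad_learned, grad_true):
    """Mean cosine similarity and mean squared L2 error between two gradient fields."""
    cos = F.cosine_similarity(grad_learned, grad_true, dim=1).mean()
    l2 = (grad_learned - grad_true).pow(2).sum(dim=1).mean()
    return cos.item(), l2.item()


torch.manual_seed(0)
x = torch.randn(256, 2)

# Toy potentials: the "learned" one has the right direction but half the scale.
true_f = lambda z: z.norm(dim=1)
learned_f = lambda z: 0.5 * z.norm(dim=1)

cos, l2 = gradient_metrics(
    potential_gradient(learned_f, x),
    potential_gradient(true_f, x),
)
print(cos, l2)
```

In this toy case the cosine metric is insensitive to the wrong scale of the gradient while the L2 metric is not, which is one reason to report both.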
- notebooks/get_benchmark_nd.ipynb -- previewing benchmark pairs;
- notebooks/test_nd.ipynb -- testing dual WGAN OT solvers.
The scheme of the proposed method for constructing benchmark pairs.
The surfaces of the potential learned by the OT solvers in the 2-dimensional experiment with 4 funnels.
- notebooks/get_benchmark_celeba.ipynb -- previewing benchmark pairs;
- notebooks/test_celeba.ipynb -- testing dual WGAN OT solvers.
- notebooks/get_benchmark_cifar.ipynb -- previewing benchmark pairs;
- notebooks/test_cifar.ipynb -- testing dual WGAN OT solvers.
- CelebA page with faces dataset;
- CIFAR-10 page with images dataset.
- UNet architecture for the mover;
- ResNet architectures for the generator;
- DC-GAN architecture for the discriminator;
- Geotorch package for weight orthogonalization in SO solver;
- Gradient Penalty for the regularization in GP solver;
- Spectral Normalization for the parametrization of SN solver;
- Sorting Out for the group sort activations in SO solver.
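For readers unfamiliar with the regularizers credited above, here is a minimal sketch of the standard WGAN-GP gradient penalty, which pushes the gradient norm of the critic towards 1 at random interpolates between the two batches. It is a generic textbook version and not the code of the GP solver in src/; the function name `gradient_penalty` and the toy critic are hypothetical, and inputs are assumed to be flat tensors of shape (batch, dim).

```python
import torch


def gradient_penalty(critic, real, fake):
    """Mean of (||grad critic(x_hat)|| - 1)^2 over random interpolates x_hat."""
    # One interpolation weight per sample, broadcast over the feature dimension.
    eps = torch.rand(real.size(0), 1, device=real.device)
    interp = (eps * real + (1 - eps) * fake).detach().requires_grad_(True)
    scores = critic(interp)
    # create_graph=True keeps the penalty differentiable w.r.t. critic parameters.
    (grads,) = torch.autograd.grad(scores.sum(), interp, create_graph=True)
    return ((grads.norm(dim=1) - 1) ** 2).mean()


torch.manual_seed(0)
real = torch.randn(8, 4)
fake = torch.randn(8, 4)

# Toy linear critic with constant gradient (2, 2, 2, 2), i.e. gradient norm 4.
critic = lambda z: 2.0 * z.sum(dim=1)

penalty = gradient_penalty(critic, real, fake)
print(penalty.item())  # (4 - 1)^2 = 9 for this toy critic
```

In training, this term is typically multiplied by a coefficient and added to the critic loss; the value of that coefficient is left to the caller here.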