Wasserstein1Benchmark

A set of tests for evaluating large-scale algorithms for Wasserstein-1 transport computation (NeurIPS'22).

Primary language: Jupyter Notebook. License: MIT.

Continuous Wasserstein-1 Benchmark

This is the official Python implementation of the NeurIPS 2022 (Datasets and Benchmarks track) paper "Kantorovich Strikes Back! Wasserstein GANs are not Optimal Transport?" (paper on OpenReview) by Alexander Korotin, Alexander Kolesov and Evgeny Burnaev.

The repository contains a set of continuous benchmark pairs of distributions for testing optimal transport (OT) solvers with the distance cost (i.e., solvers for the Wasserstein-1 distance, W1), the code for about a dozen dual WGAN OT solvers, and the code for evaluating them.
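For intuition about the quantity the benchmark targets, here is a minimal sketch (not part of this repository) that computes W1 exactly between two one-dimensional empirical distributions with the same number of samples. `w1_1d` is a hypothetical helper. It relies on the fact that in 1D the sorted (monotone) coupling is optimal for the distance cost |x - y|; in higher dimensions there is no comparably simple closed form, so this only illustrates what is being estimated.

```python
def w1_1d(xs, ys):
    """Exact W1 between two 1D empirical measures with equal sample counts.

    Hypothetical helper: pairs the i-th smallest x with the i-th smallest y
    (an optimal coupling in 1D) and averages the distances |x - y|.
    """
    if len(xs) != len(ys) or not xs:
        raise ValueError("expected two non-empty samples of equal size")
    xs_sorted = sorted(xs)
    ys_sorted = sorted(ys)
    return sum(abs(x - y) for x, y in zip(xs_sorted, ys_sorted)) / len(xs)


# Sorted ys are [0.5, 1.5, 2.5], so every matched pair is 0.5 apart.
print(w1_1d([0.0, 1.0, 2.0], [2.5, 0.5, 1.5]))  # prints 0.5
```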


Pre-requisites

The implementation is GPU-based. A single GTX 1080 Ti GPU is enough to run each experiment. We tested the code with torch==1.10.2; it might not run as intended with older or newer versions of torch.
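Since the code was tested only with torch==1.10.2, a quick check of the installed version can save debugging time. The following is a minimal sketch using only the standard library; `installed_matches` and `TESTED_TORCH` are hypothetical names, not part of the repository, and the check simply strips local build tags such as "+cu113" before comparing.

```python
from importlib.metadata import PackageNotFoundError, version

TESTED_TORCH = "1.10.2"  # the torch version this README reports testing with


def installed_matches(package="torch", expected=TESTED_TORCH):
    """Compare an installed package's version with `expected`.

    Returns True or False, or None when the package is not installed.
    Local build tags such as "+cu113" are ignored before comparing.
    """
    try:
        installed = version(package)
    except PackageNotFoundError:
        return None
    return installed.split("+")[0] == expected


status = installed_matches()
if status is None:
    print("torch is not installed")
elif status:
    print(f"torch {TESTED_TORCH} found")
else:
    print(f"warning: torch {version('torch')} differs from the tested {TESTED_TORCH}")
```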


Repository structure

  • notebooks/ - Jupyter notebooks for previewing the benchmark pairs and evaluating OT solvers;
  • src/ - auxiliary source code for the OT solvers and the benchmark pairs;
  • metrics/ - evaluation results (cosine, L2, W1);
  • benchmarks/ - .pt checkpoints of the continuous benchmark pairs;
  • checkpoints/ - .pt checkpoints of the dual potentials fitted by the OT solvers in the 2-dimensional experiment with 4 funnels.

Intended usage

Our benchmark pairs of distributions are designed to test how well OT solvers recover the OT cost (W1) and the OT gradient. Use the following code to load the pairs and sample from them (assuming your working directory is notebooks/):

import sys
sys.path.append("..")

from src.map_benchmark import MixToOneBenchmark, Celeba64Benchmark, Cifar32Benchmark

# Load the high-dimensional benchmark pair in dimension 16 (options: 2, 4, ..., 128)
# with 64 funnels (options: 4, 16, 64, 256)
benchmark = MixToOneBenchmark(16, 64)

# OR load the CelebA images benchmark pair
# with 16 funnels (options: 1, 16) and degree 10 (options: 10, 100)
# benchmark = Celeba64Benchmark(16, 10)

# OR load the CIFAR-10 images benchmark pair
# with 16 funnels (options: 1, 16) and degree 10 (options: 10, 100)
# benchmark = Cifar32Benchmark(16, 10)

# Sample 32 random points from the benchmark distributions (unpaired, for training)
X = benchmark.input_sampler.sample(32)
Y = benchmark.output_sampler.sample(32)

# Sample 32 random points from the OT plan (paired, for testing only)
X, Y = benchmark.input_sampler.sample(32, flag_plan=True)

Examples of testing OT solvers are in the notebooks/ folder; see below.
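Pairs sampled with flag_plan=True come from the OT plan, which gives a ground-truth gradient direction: for an optimal potential f of W1 (maximizing E_P[f] - E_Q[f] over 1-Lipschitz f) and a plan pair (x, y) with x != y, f is differentiable at x with gradient (x - y)/||x - y||, the sign flipping under the opposite convention. The sketch below scores a solver's gradients by their average cosine similarity to these directions. It is an illustration in plain Python under these assumptions; all helper names are hypothetical, and the repository's own cosine metric (see metrics/) may be defined or computed differently, e.g. on torch tensors.

```python
import math


def unit_direction(x, y):
    """(x - y) / ||x - y||: assumed gradient direction of the optimal potential at x
    for a plan pair (x, y) under the E_P[f] - E_Q[f] convention."""
    diff = [a - b for a, b in zip(x, y)]
    norm = math.sqrt(sum(d * d for d in diff))
    if norm == 0.0:
        raise ValueError("x and y coincide; the direction is undefined")
    return [d / norm for d in diff]


def cosine(u, v):
    """Cosine similarity; returns 0.0 if either vector is zero."""
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    if nu == 0.0 or nv == 0.0:
        return 0.0
    return sum(a * b for a, b in zip(u, v)) / (nu * nv)


def mean_gradient_cosine(xs, ys, grads):
    """Average cosine between solver gradients grads[i] (at xs[i]) and the
    ground-truth directions given by paired plan samples (xs[i], ys[i])."""
    scores = [cosine(g, unit_direction(x, y)) for x, y, g in zip(xs, ys, grads)]
    return sum(scores) / len(scores)


xs = [[1.0, 0.0], [0.0, 2.0]]
ys = [[0.0, 0.0], [0.0, 0.0]]
print(mean_gradient_cosine(xs, ys, [[2.0, 0.0], [0.0, 0.5]]))  # prints 1.0
```

A perfectly aligned solver scores 1.0 regardless of gradient magnitude, while a gradient pointing the opposite way scores -1.0.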

Evaluation of Existing WGAN OT Solvers

We provide all the code to evaluate existing dual WGAN OT solvers on our benchmark pairs. The qualitative results are shown below. For quantitative results, see the paper or the metrics/ folder.
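Dual WGAN solvers estimate W1 through the Kantorovich-Rubinstein duality, W1(P, Q) = sup over 1-Lipschitz f of E_P[f] - E_Q[f]. A minimal sketch of the resulting sample estimate follows; `dual_w1_estimate` is a hypothetical helper, and how the repository turns such estimates into its W1 metric (e.g., as an error relative to the known benchmark value) is not shown here. For any 1-Lipschitz f, the value never exceeds W1 between the two empirical samples, with equality for an optimal potential.

```python
def dual_w1_estimate(potential, xs, ys):
    """Kantorovich-Rubinstein dual objective on samples: mean f(X) - mean f(Y).

    Assumes `potential` is 1-Lipschitz; otherwise the value is not a valid
    lower bound on W1 between the empirical samples.
    """
    mean_x = sum(potential(x) for x in xs) / len(xs)
    mean_y = sum(potential(y) for y in ys) / len(ys)
    return mean_x - mean_y


xs = [1.0, 2.0, 3.0]
ys = [0.0, 1.0, 2.0]  # xs shifted by 1, so W1 = 1
print(dual_w1_estimate(lambda t: t, xs, ys))        # optimal f(t) = t: prints 1.0
print(dual_w1_estimate(lambda t: 0.5 * t, xs, ys))  # suboptimal f: prints 0.5
```

The suboptimal potential underestimates the cost, which is why a solver's W1 estimate alone says little about whether it recovered the true potential or its gradient.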

High-Dimensional Benchmarks

  • notebooks/get_benchmark_nd.ipynb -- previewing benchmark pairs;
  • notebooks/test_nd.ipynb -- testing dual WGAN OT solvers.

The scheme of the proposed method for constructing benchmark pairs.

The surfaces of the potential learned by the OT solvers in the 2-dimensional experiment with 4 funnels.

CelebA Images Benchmark Pairs

  • notebooks/get_benchmark_celeba.ipynb -- previewing benchmark pairs;
  • notebooks/test_celeba.ipynb -- testing dual WGAN OT solvers.

CIFAR-10 Images Benchmark Pairs

  • notebooks/get_benchmark_cifar.ipynb -- previewing benchmark pairs;
  • notebooks/test_cifar.ipynb -- testing dual WGAN OT solvers.

Credits