Group project for the Optimal Transport course at ENSAE taught by Marco Cuturi (spring 2023).
- Augustin Combes
- Gabriel Watkinson
The results of our experiments are presented in the main.ipynb
notebook. It is self-contained and can be run to replicate all of our experiments.
You can run the notebook interactively using Google Colab or Binder by clicking on the badges above.
The goal of this notebook is to explore whether Wasserstein-GAN (WGAN) can effectively approximate Wasserstein distances. WGAN, introduced in the 2017 paper by Arjovsky et al. [1], proposes a neural network-based proxy for the 1-Wasserstein distance, but it is unclear how well this approximation holds up in practice.
To investigate this question, we will implement the WGAN approach to solve Optimal Transport and compare it with other approaches, such as Sinkhorn divergence. Our aim is to determine if WGAN can compute a quantity that is truly similar to “true” optimal transport.
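As a rough illustration of the Sinkhorn side of this comparison, here is a minimal NumPy sketch of the entropy-regularized transport cost between two empirical point clouds (the function name and defaults are ours, not the notebook's, which may rely on an existing OT toolbox):

```python
import numpy as np

def sinkhorn_cost(x, y, reg=1.0, n_iters=200):
    """Entropy-regularized OT cost between two point clouds (uniform weights).

    Plain Sinkhorn fixed-point iterations on the squared-distance cost matrix.
    """
    n, m = len(x), len(y)
    # pairwise squared Euclidean cost matrix, shape (n, m)
    C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    K = np.exp(-C / reg)                      # Gibbs kernel
    a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
    u, v = np.ones(n), np.ones(m)
    for _ in range(n_iters):                  # alternating scaling updates
        u = a / (K @ v)
        v = b / (K.T @ u)
    P = u[:, None] * K * v[None, :]           # regularized transport plan
    return (P * C).sum()                      # cost, larger for distant clouds
```

This converges quickly for moderate regularization; stabilized (log-domain) variants are needed when `reg` is small relative to the costs.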
The WGAN paper proposes a new approach to learning a probability distribution by leveraging Optimal Transport theory. Traditionally, learning a probability distribution involves maximizing the likelihood of the data across a family of parametric densities, denoted as $(P_\theta)_{\theta \in \mathbb{R}^d}$. This is equivalent to minimizing the Kullback-Leibler divergence between the real distribution $\mathbb{P}_r$ and the model distribution $\mathbb{P}_\theta$: $KL(\mathbb{P}_r \| \mathbb{P}_\theta)$. However, in many cases, the model density $P_\theta$ does not exist, and the Kullback-Leibler divergence is undefined.
To remedy this problem, one can instead sample directly from the target distribution: define a random variable $Z$ with a fixed distribution and pass it through a parametric function $g_\theta$, so that the generated samples follow a distribution $\mathbb{P}_\theta$ without requiring a density. This is the strategy used by VAEs [2] and GANs [3].
Arjovsky et al. propose a new approach based on Optimal Transport theory, expanding upon the idea of GANs. Traditional GANs are notoriously difficult to train, as a delicate equilibrium between the generator and the discriminator must be maintained, and their results may suffer from mode collapse, in which the generator only produces a few very similar samples. WGAN is more stable and easier to tune than traditional GANs, as it uses a neural-network proxy for the Wasserstein distance. Our aim in this project is to investigate whether WGAN is a promising approach to solving Optimal Transport problems.
[1]: Arjovsky, Martin, Soumith Chintala, and Léon Bottou. "Wasserstein generative adversarial networks." International conference on machine learning. PMLR, 2017. https://arxiv.org/abs/1701.07875
[2]: Kingma, Diederik P., and Max Welling. "Auto-encoding variational bayes." arXiv preprint arXiv:1312.6114 (2013). https://arxiv.org/abs/1312.6114
[3]: Goodfellow, Ian, et al. "Generative adversarial networks." Communications of the ACM 63.11 (2020): 139-144. https://arxiv.org/abs/1406.2661
Unlike in GANs, where the generator's loss is the binary cross-entropy between the discriminator's output and a target value indicating whether the generated sample is real or fake, WGANs use the Wasserstein-1 distance (also known as the Earth Mover's Distance) to measure the difference between the real and generated distributions:

$$W(\mathbb{P}_r, \mathbb{P}_\theta) = \inf_{\gamma \in \Pi(\mathbb{P}_r, \mathbb{P}_\theta)} \mathbb{E}_{(x, y) \sim \gamma} \left[ \| x - y \| \right]$$

where $\Pi(\mathbb{P}_r, \mathbb{P}_\theta)$ denotes the set of all joint distributions $\gamma(x, y)$ whose marginals are $\mathbb{P}_r$ and $\mathbb{P}_\theta$.
However, this formulation is intractable: the infimum over all couplings cannot be computed directly. Instead, WGAN uses the equivalent Kantorovich-Rubinstein dual formulation:

$$W(\mathbb{P}_r, \mathbb{P}_\theta) = \sup_{\|f\|_L \leq 1} \mathbb{E}_{x \sim \mathbb{P}_r}[f(x)] - \mathbb{E}_{x \sim \mathbb{P}_\theta}[f(x)]$$

where the supremum is taken over all 1-Lipschitz functions $f$.
More precisely, we will search over a parameterized family of functions $\{f_w\}_{w \in \mathcal{W}}$ that are all $K$-Lipschitz for some $K$. Therefore, to approximate the Wasserstein distance, we will train a neural network to maximize the following objective function:

$$\max_{w \in \mathcal{W}} \ \mathbb{E}_{x \sim \mathbb{P}_r}[f_w(x)] - \mathbb{E}_{x \sim \mathbb{P}_\theta}[f_w(x)]$$
The resulting value should be close to the "true" Wasserstein distance between the two distributions.
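As a quick sanity check of what "true" means here: in one dimension, with two empirical measures of $n$ uniform atoms each, the Wasserstein-1 distance has a closed form, the mean absolute difference of the sorted samples. A small NumPy illustration (ours, not code from the notebook):

```python
import numpy as np

def empirical_w1_1d(x, y):
    """W1 between two 1-D empirical measures with n uniform atoms each.

    In 1-D the optimal coupling matches sorted samples, so
    W1 = (1/n) * sum_i |x_(i) - y_(i)|.
    """
    x, y = np.sort(x), np.sort(y)
    return np.abs(x - y).mean()

rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=1000)
y = rng.normal(2.0, 1.0, size=1000)
print(empirical_w1_1d(x, y))  # close to 2, the W1 distance between N(0,1) and N(2,1)
```

Such 1-D closed forms make convenient ground truth against which a learned critic's estimate can be checked.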
In practice, denoting with $x_1, \dots, x_n$ samples from $\mathbb{P}_r$ and $y_1, \dots, y_n$ samples from $\mathbb{P}_\theta$, we maximize the empirical counterpart of this objective:

$$\max_{w \in \mathcal{W}} \ \frac{1}{n} \sum_{i=1}^{n} f_w(x_i) - \frac{1}{n} \sum_{i=1}^{n} f_w(y_i)$$

We will estimate the function $f$ with a neural network $f_w$, which requires enforcing the Lipschitz constraint on the network.
Enforcing the 1-Lipschitz constraint can be done using various techniques:
As proposed in the original WGAN paper, clipping the weights of the network is a simple way to ensure the constraint.
Denoting $c > 0$ the clipping threshold and $W$ the weight matrix of a linear layer, clamping every entry of $W$ to $[-c, c]$ after each gradient update bounds the operator norm of the layer. Then, a sufficient way of ensuring the whole network is $K$-Lipschitz (for some $K$ depending only on $c$ and the architecture) is to clip each linear layer this way and use 1-Lipschitz activations.
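A framework-free sketch of the clipping step (in the notebook this would act on the network's weight tensors after each optimizer update). For an $n \times m$ matrix with all entries in $[-c, c]$, the spectral norm, i.e. the layer's Lipschitz constant, is at most $c\sqrt{nm}$:

```python
import numpy as np

def clip_weights(params, c=0.01):
    """WGAN-style weight clipping: clamp every parameter entry to [-c, c]."""
    return [np.clip(W, -c, c) for W in params]

rng = np.random.default_rng(0)
params = [rng.normal(size=(64, 32)), rng.normal(size=(32, 1))]
clipped = clip_weights(params, c=0.01)

for W in clipped:
    n, m = W.shape
    lip = np.linalg.norm(W, 2)           # spectral norm = Lipschitz constant of x -> W x
    assert lip <= 0.01 * np.sqrt(n * m)  # crude bound implied by the clipped entries
```

The bound is loose, which is one reason clipping in practice either saturates the critic (large $c$) or makes gradients vanish (small $c$).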
As proposed in follow-up papers, another way to enforce the constraint is to penalize violations of it directly in the loss.
That is, we then optimize the following penalized program:

$$\max_{w \in \mathcal{W}} \ \mathbb{E}_{x \sim \mathbb{P}_r}[f_w(x)] - \mathbb{E}_{y \sim \mathbb{P}_\theta}[f_w(y)] - \lambda \, \mathbb{E}_{\hat{x}} \left[ \left( \| \nabla_{\hat{x}} f_w(\hat{x}) \|_2 - 1 \right)^2 \right]$$

where $\lambda > 0$ is a penalty coefficient and $\hat{x}$ is sampled uniformly along segments between pairs of real and generated points, as in the gradient-penalty variant of WGAN.
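To make the penalty term concrete, here is a toy instance with a linear critic $f_w(x) = \langle w, x \rangle$, whose gradient is $w$ everywhere, so the penalty reduces to $\lambda (\|w\|_2 - 1)^2$ at every interpolation point (an illustration under that simplifying assumption, not the notebook's implementation; a real neural critic needs automatic differentiation):

```python
import numpy as np

def gradient_penalty_linear(w, x_real, y_fake, lam=10.0, rng=None):
    """Gradient penalty for a linear critic f_w(x) = <w, x>.

    hat_x is sampled on segments between real and generated points;
    for a linear critic, grad f_w(hat_x) = w regardless of hat_x,
    so the penalty equals lam * (||w||_2 - 1)^2.
    """
    rng = rng or np.random.default_rng(0)
    eps = rng.uniform(size=(len(x_real), 1))
    hat_x = eps * x_real + (1.0 - eps) * y_fake  # interpolated points
    grads = np.tile(w, (len(hat_x), 1))          # gradient of <w, x> is w everywhere
    norms = np.linalg.norm(grads, axis=1)
    return lam * ((norms - 1.0) ** 2).mean()
```

The penalty vanishes exactly when the critic's gradient has unit norm along the interpolation points, which is where the 1-Lipschitz constraint binds for an optimal critic.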
We explore both of these methods in this notebook.
If you want to run the notebook locally, feel free to use poetry to create and install the environment. You can do so by running the following commands:
```bash
git clone https://github.com/AugustinCombes/DeepWasserstein.git
cd DeepWasserstein
# curl -sSL https://install.python-poetry.org | python3 -  # if you need to install poetry, see https://python-poetry.org/docs/ for details
poetry install  # create the environment from the poetry.lock file
# poetry shell  # spawn a shell in the environment
# pre-commit install  # if you want to use pre-commit hooks
# poe update_jax_12  # update jaxlib for CUDA 12
```
Alternatively, you can use the `requirements.txt` file to install the dependencies with pip:
```bash
pip install -r requirements.txt
```