ott-jax/ott

Reproducing results from paper

zwei-beiner opened this issue · 2 comments

Hi, I'm trying to reproduce Figure 1 (right subfigure for d=7) from this paper: https://arxiv.org/abs/1810.02733

However, I am getting different results: The W2 distance is much larger when computed with OTT than in the paper, and larger epsilon gives larger W2, which is opposite to the figure in the paper. (Note that the color coding is opposite between the two figures.)

Figure produced with the attached script:
results

Figure in the paper:
paper

import functools
import jax
import jax.numpy as jnp
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

import matplotlib.pyplot as plt
import numpy as np

# Samples from the hypercube
def make_samples(key, ndims, nsamples):
    """Draw ``nsamples`` points of dimension ``ndims`` with ``jax.random.uniform``.

    Args:
        key: JAX PRNG key used for the draw.
        ndims: Number of coordinates per point (second axis of the result).
        nsamples: Number of points (first axis of the result).

    Returns:
        Array of shape ``(nsamples, ndims)`` sampled with the default bounds
        of ``jax.random.uniform``.
    """
    sample_shape = (nsamples, ndims)
    return jax.random.uniform(key, shape=sample_shape)

@jax.jit
def W(x, y, epsilon):
    """Return the primal cost of the regularised OT problem between ``x`` and ``y``.

    A ``PointCloud`` geometry is built from the two point sets with the given
    ``epsilon``, wrapped in a ``LinearProblem`` and solved with a
    default-configured ``Sinkhorn`` solver.

    NOTE(review): this is the cost of the epsilon-regularised problem itself;
    it is not a Sinkhorn divergence. Confirm this is the quantity wanted.

    Args:
        x: First point set, passed as the first argument of ``PointCloud``.
        y: Second point set, passed as the second argument of ``PointCloud``.
        epsilon: Regularisation strength forwarded to ``PointCloud``.

    Returns:
        The ``primal_cost`` attribute of the Sinkhorn solver output.
    """
    cloud = pointcloud.PointCloud(x, y, epsilon=epsilon)
    problem = linear_problem.LinearProblem(cloud)
    solution = sinkhorn.Sinkhorn()(problem)
    return solution.primal_cost

def run():
    """Run the experiment and save an error-bar plot to ``results.png``.

    For every epsilon and every sample size, ``numiter`` independent pairs of
    point sets are drawn with ``make_samples`` and scored with ``W``. The log
    of each score is stored; the plot shows, per epsilon, the mean and the
    standard deviation of those logs over the trials, against the log of the
    sample size.
    """
    ndims = 7
    numiter = 300
    root_key = jax.random.PRNGKey(0)
    epsilons = 10. ** jnp.arange(-3, 3)
    # NOTE(review): exp(linspace(1.0, 2.5)) truncated to int64 only yields the
    # integers 2..12, with many repeats among the 200 entries. Confirm this is
    # the intended grid of sample sizes.
    sample_sizes = np.int64(np.exp(np.linspace(1.0, 2.5, 200)))

    @functools.partial(jax.jit, static_argnums=(2, 3))
    def calc_W(key, epsilon, nsamples, ndims):
        # Two chained splits so that x and y are drawn from different subkeys.
        key, x_key = jax.random.split(key)
        _, y_key = jax.random.split(key)
        x = make_samples(x_key, ndims, nsamples)
        y = make_samples(y_key, ndims, nsamples)
        return W(x, y, epsilon)

    # Key hierarchy: root key -> one key per epsilon -> one key per sample
    # size -> one key per trial. The trials are evaluated with vmap.
    trial_ids = jnp.arange(numiter)
    per_epsilon = []
    eps_keys = jax.random.split(root_key, len(epsilons))
    for eps_key, epsilon in zip(eps_keys, epsilons):
        per_size = []
        size_keys = jax.random.split(eps_key, len(sample_sizes))
        for size_key, nsamples in zip(size_keys, sample_sizes):
            trial_keys = jax.random.split(size_key, numiter)
            batched = jax.vmap(
                lambda k, _i: calc_W(k, epsilon, nsamples, ndims)
            )
            per_size.append(batched(trial_keys, trial_ids))
        per_epsilon.append(per_size)

    # Axes: (epsilon, sample size, trial).
    results = jnp.log(jnp.asarray(per_epsilon))

    fig, ax = plt.subplots()
    log_sizes = np.log(sample_sizes)
    for eps_results, epsilon in zip(results, epsilons):
        ax.errorbar(
            log_sizes,
            jnp.mean(eps_results, axis=-1),
            yerr=jnp.std(eps_results, axis=-1),
            label=f"epsilon={epsilon:3f}",
        )
    ax.legend()
    fig.tight_layout()
    fig.savefig("results.png")
# Only launch the experiment when executed as a script, so that importing
# this module does not trigger the full computation and plot.
if __name__ == "__main__":
    run()
  1. The definition of `Nsamples` seems to be wrong; consider using jnp.logspace(1.0, 2.5, 200):
image
  2. You're computing $OT_\epsilon$, not the Sinkhorn divergence (SD) $\tilde{W}_\epsilon$; please see here for the docs.

You can compute SD as:

from ott.tools import sinkhorn_divergence

@jax.jit
def W(x, y, epsilon):
    """Return the Sinkhorn divergence between the point sets ``x`` and ``y``.

    Calls ``sinkhorn_divergence.sinkhorn_divergence`` with the
    ``pointcloud.PointCloud`` geometry class, so the result is the divergence
    rather than the raw regularised OT cost.

    NOTE(review): reading ``.divergence`` assumes an OTT version in which
    ``sinkhorn_divergence`` returns an output object exposing that attribute;
    confirm against the installed version.

    Args:
        x: First point set.
        y: Second point set.
        epsilon: Regularisation strength forwarded to the geometry.

    Returns:
        The ``divergence`` attribute of the Sinkhorn divergence output.
    """
    # The value must be returned explicitly; without `return` the function
    # always yields None.
    return sinkhorn_divergence.sinkhorn_divergence(
        pointcloud.PointCloud, x, y, epsilon=epsilon
    ).divergence

Thanks a lot, I was not aware of this distinction. Minor correction: I'm guessing the function should be sinkhorn_divergence instead of segment_sinkhorn_divergence. If so, you can close the issue.