Reproducing results from paper
zwei-beiner opened this issue · 2 comments
zwei-beiner commented
Hi, I'm trying to reproduce Figure 1 (right subfigure for d=7) from this paper: https://arxiv.org/abs/1810.02733
However, I am getting different results: The W2 distance is much larger when computed with OTT than in the paper, and larger epsilon gives larger W2, which is opposite to the figure in the paper. (Note that the color coding is opposite between the two figures.)
Figure produced with the attached script:
import functools
import jax
import jax.numpy as jnp
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
import matplotlib.pyplot as plt
import numpy as np
# Samples from the hypercube
def make_samples(key, ndims, nsamples):
return jax.random.uniform(key, shape=(nsamples, ndims))
@jax.jit
def W(x, y, epsilon):
geom = pointcloud.PointCloud(x, y, epsilon=epsilon)
ot_prob = linear_problem.LinearProblem(geom)
solver = sinkhorn.Sinkhorn()
ot = solver(ot_prob)
return ot.primal_cost
def run():
key = jax.random.PRNGKey(0)
Epsilons = 10. ** jnp.arange(-3, 3)
Nsamples = np.int64(np.exp(np.linspace(1.0, 2.5, 200)))
ndims = 7
numiter = 300
@functools.partial(jax.jit, static_argnums=(2, 3))
def calc_W(key, epsilon, nsamples, ndims):
key, subkey = jax.random.split(key)
x = make_samples(subkey, ndims, nsamples)
key, subkey = jax.random.split(key)
y = make_samples(subkey, ndims, nsamples)
return W(x, y, epsilon)
# Calculate log of W2 distances for all options in the paper
results = jnp.log(jnp.asarray([[
jax.vmap(
lambda key, i: calc_W(key, epsilon, nsamples, ndims)
)(jax.random.split(key, numiter), jnp.arange(numiter))
for key, nsamples in zip(jax.random.split(key, len(Nsamples)), Nsamples)
] for key, epsilon in zip(jax.random.split(key, len(Epsilons)), Epsilons)
]))
fig, ax = plt.subplots()
for i, epsilon in enumerate(Epsilons):
ax.errorbar(np.log(Nsamples), jnp.mean(results[i], axis=-1), yerr=jnp.std(results[i], axis=-1), label=f"epsilon={epsilon:3f}")
ax.legend()
fig.tight_layout()
fig.savefig("results.png")
run()
michalk8 commented
- This seems to be wrong, consider using
jnp.logspace(1.0, 2.5, 200)
:
- You're compute the
$OT_\epsilon$ , not the Sinkhorn Divergence (SD)$\tilde{W}_\epsilon$ , please see here for the docs.
You can compute SD as:
from ott.tools import sinkhorn_divergence
@jax.jit
def W(x, y, epsilon):
sinkhorn_divergence.segment_sinkhorn_divergence(
pointcloud.PointCloud, x, y, epsilon=epsilon
).divergence
zwei-beiner commented
Thanks a lot, I was not aware of this distinction. Minor correction: I'm guessing the function should be sinkhorn_divergence
instead of segment_sinkhorn_divergence
. If so, you can close the issue.