Self distances with ot.bregman.empirical_sinkhorn2 higher than expected
I am computing matrices of W_2 distances between point clouds centered along a curve, using ot.bregman.empirical_sinkhorn2. I expect the distance from a point cloud to itself to be zero, or close to it. Instead, the self-distances are higher than the distances to some neighboring point clouds. This is unexpected, and I am wondering whether an underlying issue is causing it. The code snippet below reproduces the problem. Any input would be greatly appreciated.
To Reproduce
import ot
import numpy as np
import matplotlib.pyplot as plt
T = 40 # Number of point clouds
N = 25 # Number of points in each cloud
max_x = 4*np.pi # Cosine is evaluated on [0, max_x]
variance = 0.4 # The variance of the normal at each point cloud
epsilon = 0.5
def generate_cos(T=100, N=50, max_x=2*np.pi, variance=1):
    """
    Generates a dataset which follows a cosine wave. Point clouds are 2-D gaussians which
    are centered at a point on the cosine wave.
    :param T: Number of point clouds to generate
    :param N: Number of samples to take at each timepoint
    :param max_x: The right end of the cosine wave
    :param variance: The variance of the gaussian distributions
    :return: The data matrix of shape (T, N, 2)
    """
    # Form the matrix [[x, cos(x)], ...]
    span = np.linspace(0, max_x, T, endpoint=True)
    y_vals = np.cos(span)
    means = np.zeros((T, 2))
    means[:, 0] = span
    means[:, 1] = y_vals
    # Sample a normal dist centered at each point
    x = np.zeros((T, N, 2))
    for i, mean in enumerate(means):
        dist = np.random.multivariate_normal(mean, np.eye(2)*variance, N)
        x[i, :, :] = dist
    return x
x = generate_cos(T=T, N=N, max_x=max_x, variance=variance)
dists = np.zeros(shape=(T, T))
for i in range(T):
    for j in range(i, T):
        d = np.sqrt(ot.bregman.empirical_sinkhorn2(x[i], x[j], epsilon,
                                                   a=ot.unif(x[i].shape[0]),
                                                   b=ot.unif(x[j].shape[0])))
        dists[i, j] = d
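# Mirror the upper triangle; subtract the diagonal once so it is not counted twice
dists = dists + dists.T - np.diag(np.diag(dists))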
plt.figure(figsize=(10, 10))
plt.matshow(dists[10:20, 10:20], fignum=1)
plt.title('$W_2$ distance matrix')
plt.colorbar()
plt.show()
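A detail worth checking whenever a symmetric matrix is built from its upper triangle: if the triangle includes the diagonal, adding the transpose counts each diagonal entry twice. The sketch below uses a made-up 3x3 matrix (the values are illustrative, not taken from the run above) to show the effect and two ways to avoid it.

```python
import numpy as np

# Hypothetical upper-triangular matrix of pairwise values, diagonal included
upper = np.array([[1.0, 2.0, 3.0],
                  [0.0, 4.0, 5.0],
                  [0.0, 0.0, 6.0]])

# Adding the transpose mirrors the off-diagonal entries but doubles the diagonal
naive = upper + upper.T

# Option 1: subtract the diagonal once after mirroring
fixed = upper + upper.T - np.diag(np.diag(upper))

# Option 2: mirror only the strictly upper part (k=1 excludes the diagonal)
alt = upper + np.triu(upper, k=1).T

print(np.diag(naive))  # diagonal is twice the original
print(np.diag(fixed))  # diagonal is unchanged
```

Both options give the same symmetric matrix; which one reads better is a matter of taste.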
Screenshots
Expected Behavior
I expect the diagonal to be close to zero.
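One possible contributor, offered as an assumption rather than a diagnosis of POT: with entropic regularization the transport plan is blurred, so the transport cost between a cloud and itself is strictly positive whenever the regularization is nonzero. The sketch below is a small NumPy-only log-domain Sinkhorn written for illustration (the helper names `logsumexp` and `sinkhorn_cost` are hypothetical, and this is not POT's implementation). It shows the nonzero self-cost and a debiased quantity that vanishes on the diagonal by construction.

```python
import numpy as np

def logsumexp(z, axis):
    """Numerically stable log(sum(exp(z))) along one axis."""
    zmax = z.max(axis=axis, keepdims=True)
    return np.squeeze(zmax, axis=axis) + np.log(np.exp(z - zmax).sum(axis=axis))

def sinkhorn_cost(x, y, reg, n_iter=500):
    """Transport cost <P, M> of the entropic plan between two uniform clouds.

    Illustrative stand-in only; squared Euclidean ground cost is assumed.
    """
    n, m = len(x), len(y)
    log_a = np.full(n, -np.log(n))  # uniform source weights
    log_b = np.full(m, -np.log(m))  # uniform target weights
    M = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
    f = np.zeros(n)
    g = np.zeros(m)
    for _ in range(n_iter):
        # Alternating dual updates in the log domain
        f = -reg * logsumexp((g[None, :] - M) / reg + log_b[None, :], axis=1)
        g = -reg * logsumexp((f[:, None] - M) / reg + log_a[:, None], axis=0)
    log_P = (f[:, None] + g[None, :] - M) / reg + log_a[:, None] + log_b[None, :]
    return float((np.exp(log_P) * M).sum())

rng = np.random.default_rng(0)
x = rng.normal(scale=np.sqrt(0.4), size=(25, 2))  # one cloud, variance 0.4
y = x + np.array([1.0, 0.0])                      # a shifted copy
reg = 0.5

self_cost = sinkhorn_cost(x, x, reg)
cross_cost = sinkhorn_cost(x, y, reg)

# Debiasing: subtract the average of the two self-costs
debiased_cross = cross_cost - 0.5 * (self_cost + sinkhorn_cost(y, y, reg))
debiased_self = self_cost - 0.5 * (self_cost + self_cost)

print("self cost:", self_cost)
print("debiased self cost:", debiased_self)
```

POT also provides ot.bregman.empirical_sinkhorn_divergence, which applies a debiasing of this kind; whether it is the right replacement for this use case is a separate question.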
Environment (please complete the following information):
- OS: Linux
- Python version: 3.11.5
- How was POT installed (source, pip, conda): pip
- POT version: 0.9.4
Output of code snippet:
Linux-4.18.0-513.5.1.el8_9.x86_64-x86_64-with-glibc2.28
Python 3.11.5 (main, Sep 22 2023, 15:34:29) [GCC 8.5.0 20210514 (Red Hat 8.5.0-20)]
NumPy 1.26.0
SciPy 1.11.3
[KeOps] Warning : Cuda libraries were not detected on the system or could not be loaded ; using cpu only mode
POT 0.9.4