PythonOT/POT

Self-distances computed with ot.bregman.empirical_sinkhorn2 are higher than expected


I am computing matrices of W_2 distances with ot.bregman.empirical_sinkhorn2 between point clouds centered along a curve. I expect the distance from a point cloud to itself to be zero or close to zero. However, this is not the case: the self-distances are in fact higher than the distances to some neighboring point clouds. This seems like unexpected behavior, and I am wondering whether an underlying issue is causing it. The code snippet below reproduces the issue. Any input would be greatly appreciated.
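A plausible explanation, not confirmed in this issue: `empirical_sinkhorn2` solves the entropy-regularized OT problem, and the regularized plan puts positive mass on every pair of points. Even when a cloud is compared with itself, some mass therefore travels between distinct points, so the transport cost ⟨G, M⟩ is strictly positive. The NumPy-only sketch below reproduces the effect on one cloud with plain Sinkhorn iterations. `sinkhorn_plan` is a hypothetical helper written for illustration, not POT's implementation, and the sketch assumes (from our reading of the POT docs) that `empirical_sinkhorn2` reports the linear term ⟨G, M⟩ with a squared Euclidean cost.

```python
import numpy as np

def sinkhorn_plan(a, b, M, eps, n_iter=2000):
    """Hypothetical helper: plain Sinkhorn iterations for entropic OT.

    Returns the entropic plan G = diag(u) @ K @ diag(v) with K = exp(-M / eps).
    """
    K = np.exp(-M / eps)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)    # rescale rows toward the marginal a
        v = b / (K.T @ u)  # rescale columns to match the marginal b
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
x = rng.normal(scale=np.sqrt(0.4), size=(25, 2))    # one cloud, variance 0.4 per axis
a = np.full(25, 1 / 25)                              # uniform weights
M = ((x[:, None, :] - x[None, :, :]) ** 2).sum(-1)  # squared Euclidean costs, zero diagonal

G = sinkhorn_plan(a, a, M, eps=0.5)
entropic_self_cost = np.sum(G * M)        # off-diagonal mass makes this strictly positive
exact_self_cost = np.sum(np.diag(a) * M)  # identity plan (optimal without regularization) costs 0
print(entropic_self_cost, exact_self_cost)
```

Shrinking `eps` concentrates the plan on the diagonal and drives the self-cost toward 0, but plain Sinkhorn then converges more slowly and `np.exp(-M / eps)` starts to underflow.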

To Reproduce

```python
import ot
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(0)  # fix the seed so runs are reproducible

T = 40  # Number of point clouds
N = 25  # Number of points in each cloud
max_x = 4 * np.pi  # Cosine is evaluated on [0, max_x]
variance = 0.4  # Variance of the Gaussian at each point cloud
epsilon = 0.5  # Entropic regularization strength passed to empirical_sinkhorn2

def generate_cos(T=100, N=50, max_x=2 * np.pi, variance=1):
    """
    Generates a dataset that follows a cosine wave. Point clouds are 2-D Gaussians
    centered at points on the cosine wave.
    :param T: Number of point clouds to generate
    :param N: Number of samples to take at each time point
    :param max_x: The right end of the cosine wave
    :param variance: The variance of the Gaussian distributions
    :return: The data matrix of shape (T, N, 2)
    """
    # Form the matrix [[x, cos(x)], ...]
    span = np.linspace(0, max_x, T, endpoint=True)
    means = np.zeros((T, 2))
    means[:, 0] = span
    means[:, 1] = np.cos(span)

    # Sample an isotropic normal distribution centered at each point
    x = np.zeros((T, N, 2))
    for i, mean in enumerate(means):
        x[i] = np.random.multivariate_normal(mean, np.eye(2) * variance, N)

    return x

x = generate_cos(T=T, N=N, max_x=max_x, variance=variance)

dists = np.zeros(shape=(T, T))

for i in range(T):
    for j in range(i, T):
        d = np.sqrt(ot.bregman.empirical_sinkhorn2(x[i], x[j], epsilon,
                                                   a=ot.unif(x[i].shape[0]),
                                                   b=ot.unif(x[j].shape[0])))
        # Fill both triangles directly: symmetrizing afterwards with
        # `dists + dists.T` would count the diagonal twice
        dists[i, j] = d
        dists[j, i] = d

plt.figure(figsize=(10, 10))
plt.matshow(dists[10:20, 10:20], fignum=1)
plt.title('$W_2$ distance matrix')
plt.colorbar()
plt.show()
```

Screenshots

[Image: dist_mat, plot of the distance matrix]

Expected Behavior

I expect the diagonal to be close to zero.

Environment:

  • OS: Linux
  • Python version: 3.11.5
  • How was POT installed (source, pip, conda): pip
  • POT version: 0.9.4

Output of the version-information snippet:

Linux-4.18.0-513.5.1.el8_9.x86_64-x86_64-with-glibc2.28
Python 3.11.5 (main, Sep 22 2023, 15:34:29) [GCC 8.5.0 20210514 (Red Hat 8.5.0-20)]
NumPy 1.26.0
SciPy 1.11.3
[KeOps] Warning : Cuda libraries were not detected on the system or could not be loaded ; using cpu only mode
POT 0.9.4