PythonOT/POT

The distance between two of the same GMMs is not 0

GilgameshD opened this issue · 7 comments

Describe the bug

The distance between two identical GMMs is not 0. With my own data, the distance can be as large as 1e-3. Is this a numerical precision problem?

To Reproduce

```python
import numpy as np
import torch
import ot


if __name__ == "__main__":
    K = 10   # number of mixture components
    D = 300  # dimension
    # Random GMM: weights pi0, means mu0, identity covariances S0
    pi0 = np.random.rand(K)
    pi0 /= np.sum(pi0)
    mu0 = np.random.rand(K, D)
    S0 = np.eye(D)[None].repeat(K, axis=0)

    pi0 = torch.as_tensor(pi0, dtype=torch.float32)
    mu0 = torch.as_tensor(mu0, dtype=torch.float32)
    S0 = torch.as_tensor(S0, dtype=torch.float32)

    # Second GMM: exact copies of the first one
    pi1 = pi0.clone()
    mu1 = mu0.clone()
    S1 = S0.clone()

    # All three checks print tensor(True)
    print((pi0 == pi1).all())
    print((mu0 == mu1).all())
    print((S0 == S1).all())

    dist = ot.gmm.gmm_ot_loss(mu0, mu1, S0, S1, pi0, pi1)
    print(dist)
```

The output distance of the above code is 1.2001e-05.

Expected behavior

The distance between two identical GMMs should be 0 (up to numerical precision).

Environment

  • OS: Ubuntu 22.04
  • Python version: 3.10
  • POT installation: from source

This is interesting. Could it come from the dist_bures_squared function, which might not be exactly 0 on the diagonal, @eloitanguy?
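For reference, here is why the diagonal should vanish in exact arithmetic. This sketch assumes (as the rest of the thread suggests) that the cost between component i of the first GMM (mean μ_i, covariance Σ_i) and component j of the second (mean ν_j, covariance Λ_j) is the squared Euclidean distance between means plus the squared Bures distance between covariances, and that the loss is the discrete OT cost between the weight vectors; POT's implementation may differ in details.

```math
\begin{aligned}
D_{ij} &= \lVert \mu_i - \nu_j \rVert_2^2 + \operatorname{tr}\Sigma_i + \operatorname{tr}\Lambda_j - 2\operatorname{tr}\Big[\big(\Sigma_i^{1/2}\Lambda_j\Sigma_i^{1/2}\big)^{1/2}\Big],\\
D_{kk} &= 0 + 2\operatorname{tr}\Sigma_k - 2\operatorname{tr}\Big[\big(\Sigma_k^{2}\big)^{1/2}\Big] = 0 \quad \text{when } \nu_k = \mu_k,\ \Lambda_k = \Sigma_k,\\
\mathcal{L}(\pi, \pi) &= \min_{\gamma \in \Pi(\pi, \pi)} \sum_{i,j} \gamma_{ij} D_{ij} \;\le\; \sum_k \pi_k D_{kk}.
\end{aligned}
```

The last line holds because the diagonal plan γ = diag(π) is feasible. So if rounding leaves each computed D_kk around 1e-5, a loss of the same order is expected, since it is bounded by a π-weighted average of those diagonal entries.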

That is interesting, I'll look into it.

Hi, thanks for the issue, I managed to reproduce it.
The issue stems from the fact that, in this example (with np.random.seed(0)), as @rflamary suggested, torch.diag(ot.gmm.dist_bures_squared(mu0, mu1, S0, S1)) is not the zero vector it should be. This is because ot.dist(mu0, mu1) has nonzero diagonal entries (around 10^(-5), which is consistent with the final GMM distance of roughly 10^(-5) instead of numerical 0).
If you use torch.float64 instead of torch.float32, the 10^(-5) diagonal entries in ot.dist(mu0, mu1) become 10^(-14), which is acceptable. It seems that this imprecision comes from floating-point rounding in ot.dist when using torch.float32.
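A minimal NumPy sketch of where 10^(-5) can come from, assuming the squared-Euclidean cost is computed with the usual expansion ||x||² + ||y||² − 2⟨x, y⟩ (`sqeuclidean_expanded` is a hypothetical helper, not POT's code). With entries drawn uniformly in [0, 1) and D = 300, each squared norm is about 300/3 = 100, and float32 machine epsilon is about 1.2e-7, so subtracting two numbers near 100 can leave residues of order 100 × 1.2e-7 ≈ 1e-5. In float64 (epsilon ≈ 2.2e-16) the same residue drops to the 1e-14 range.

```python
import numpy as np


def sqeuclidean_expanded(X, Y):
    """Pairwise squared distances via ||x||^2 + ||y||^2 - 2<x, y>.

    Hypothetical helper: computed in the dtype of the inputs, no clipping.
    """
    a2 = np.einsum("ij,ij->i", X, X)  # squared norms of the rows of X
    b2 = np.einsum("ij,ij->i", Y, Y)  # squared norms of the rows of Y
    return a2[:, None] + b2[None, :] - 2 * X @ Y.T


rng = np.random.default_rng(0)
mu0 = rng.random((10, 300))  # K=10 means in dimension D=300, as in the issue
mu1 = mu0.copy()             # identical values, different object

for dtype in (np.float32, np.float64):
    M = sqeuclidean_expanded(mu0.astype(dtype), mu1.astype(dtype))
    # In exact arithmetic every diagonal entry would be exactly 0;
    # the printed value is the largest rounding residue on the diagonal.
    print(dtype.__name__, np.abs(np.diag(M)).max())
```

The exact residues depend on the random draw and on the BLAS in use, but the float32 diagonal will typically sit many orders of magnitude further from 0 than the float64 one.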

Thanks for identifying the problem! Is there any solution other than using torch.float64?

Hi, I don't really have other ideas, but maybe @rflamary would know?
I know that ot.dist checks whether the two data matrices are the same object, in which case it forces the diagonal to 0. This does not solve your issue, but it's closely related, so I'm bringing it up anyway.
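The same-object check can be illustrated with a small NumPy sketch. `sqdist_with_identity_check` is hypothetical and only mimics the behavior described here (forcing the diagonal to 0 when `X is Y`); it is not POT's implementation. A copy, like `mu1 = mu0.clone()` in the issue, holds equal values but is a different object, so the check does not fire and any rounding residue stays on the diagonal.

```python
import numpy as np


def sqdist_with_identity_check(X, Y):
    """Hypothetical sketch: expanded squared distances, with the diagonal
    forced to exactly 0 only when X and Y are the very same object."""
    a2 = np.einsum("ij,ij->i", X, X)
    b2 = np.einsum("ij,ij->i", Y, Y)
    M = a2[:, None] + b2[None, :] - 2 * X @ Y.T
    if X is Y:
        # Identity check, not an equality check: equal copies are not caught.
        np.fill_diagonal(M, 0)
    return M


rng = np.random.default_rng(0)
mu0 = rng.random((10, 300)).astype(np.float32)
mu1 = mu0.copy()  # equal values, distinct object (like torch's .clone())

print(np.diag(sqdist_with_identity_check(mu0, mu0)))  # same object: exact zeros
print(np.diag(sqdist_with_identity_check(mu0, mu1)))  # copy: residue may remain
```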

It is well known that matrix factorization in 32-bit floats leads to numerical errors. I'm not sure we can do anything about it. I wouldn't worry: if you optimize, you should still have the minimum at the same position (up to numerical precision). The same-object trick is of limited use, especially if you want to optimize things.

Note that we are open to numerical precision tricks if you know some and want to contribute them to POT.
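Two candidate tricks for the mean-distance term, sketched in NumPy. Both helpers are hypothetical and not part of POT. `sqdist_direct` computes squared distances from explicit differences, so identical rows give exactly 0, at the price of an (n, m, d) intermediate array. `sqdist_upcast` evaluates the expanded formula in float64 and casts the result back to the input dtype, so only this one step pays the float64 cost.

```python
import numpy as np


def sqdist_direct(X, Y):
    """Squared distances from explicit differences (hypothetical helper).
    Identical rows give exactly 0; memory is O(n * m * d)."""
    diff = X[:, None, :] - Y[None, :, :]  # shape (n, m, d)
    return np.einsum("ijk,ijk->ij", diff, diff)


def sqdist_upcast(X, Y):
    """Expanded formula evaluated in float64, then cast back to X's dtype
    (hypothetical helper)."""
    X64, Y64 = X.astype(np.float64), Y.astype(np.float64)
    a2 = np.einsum("ij,ij->i", X64, X64)
    b2 = np.einsum("ij,ij->i", Y64, Y64)
    M = a2[:, None] + b2[None, :] - 2 * X64 @ Y64.T
    return M.astype(X.dtype)


rng = np.random.default_rng(0)
mu0 = rng.random((10, 300)).astype(np.float32)
mu1 = mu0.copy()

print(np.abs(np.diag(sqdist_direct(mu0, mu1))).max())  # 0.0: x - x is exactly 0
print(np.abs(np.diag(sqdist_upcast(mu0, mu1))).max())  # float64-level residue
```

In PyTorch the upcast idea would translate to `X.double()` followed by `.to(X.dtype)`, both of which are differentiable. Note that the call shown in the issue takes means, covariances and weights rather than a precomputed cost, so using either trick would mean assembling the GMM cost yourself.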

Thanks for your help! @rflamary @eloitanguy