The distance between two of the same GMMs is not 0
GilgameshD opened this issue · 7 comments
Describe the bug
The distance between two identical GMMs is not 0. Sometimes the distance can be as large as 1e-3 when I use my own data. Is this due to a numerical precision problem?
To Reproduce
import numpy as np
import torch
import ot

if __name__ == "__main__":
    # Build a random K-component GMM in D dimensions with identity covariances
    K = 10
    D = 300
    pi0 = np.random.rand(K)
    pi0 /= np.sum(pi0)  # normalize the mixture weights
    mu0 = np.random.rand(K, D)
    S0 = np.eye(D)[None].repeat(K, axis=0)
    pi0 = torch.as_tensor(pi0, dtype=torch.float32)
    mu0 = torch.as_tensor(mu0, dtype=torch.float32)
    S0 = torch.as_tensor(S0, dtype=torch.float32)

    # The second GMM is an exact copy of the first
    pi1 = pi0.clone()
    mu1 = mu0.clone()
    S1 = S0.clone()
    print((pi0 == pi1).all())
    print((mu0 == mu1).all())
    print((S0 == S1).all())

    # The OT loss between a GMM and an identical copy should be 0
    dist = ot.gmm.gmm_ot_loss(mu0, mu1, S0, S1, pi0, pi1)
    print(dist)
The output distance of the above code is 1.2001e-05.
Expected behavior
The distance between two identical GMMs should be (numerically) 0.
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): Ubuntu 22.04
- Python version: 3.10
- How was POT installed (source, pip, conda): source
This is interesting. Could it come from the dist_bures_squared function, which might not be exactly 0 on the diagonal @eloitanguy ?
That is interesting, I'll look into it.
Hi, thanks for your issue, I managed to reproduce it.
The problem stems from the fact that in this example (with np.random.seed(0)), as @rflamary suggested, torch.diag(ot.gmm.dist_bures_squared(mu0, mu1, S0, S1)) is not the zero vector it should be, and it turns out this is because ot.dist(mu0, mu1) has nonzero diagonal entries (around 10^(-5), which is consistent with the final GMM distance of roughly 10^(-5) instead of a numerical 0).
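For illustration, a minimal diagnostic sketch (reusing mu0, mu1, S0, S1 from the reproduction script above, with np.random.seed(0) set before generating the data):

# Inspect the diagonals of the two cost matrices involved. For
# identical inputs both should be exactly 0, but in float32 the
# Euclidean part is only ~1e-5.
print(torch.diag(ot.dist(mu0, mu1)).abs().max())
print(torch.diag(ot.gmm.dist_bures_squared(mu0, mu1, S0, S1)).abs().max())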
If you take torch.float64 instead of torch.float32, the 10^(-5) diagonal entries in ot.dist(mu0, mu1) drop to 10^(-14), which is acceptable. This imprecision seems to come from numerical round-off in ot.dist when using torch.float32.
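As a concrete sketch of that workaround (reusing pi0, mu0, S0 from the reproduction script):

# Same computation in float64: the spurious ~1e-5 diagonal entries of
# ot.dist drop to ~1e-14 and the GMM loss becomes numerically 0.
pi0_64, mu0_64, S0_64 = pi0.double(), mu0.double(), S0.double()
dist64 = ot.gmm.gmm_ot_loss(mu0_64, mu0_64.clone(), S0_64, S0_64.clone(),
                            pi0_64, pi0_64.clone())
print(dist64)  # expected around 1e-14 instead of 1e-5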
Thanks for identifying the problem! Is there any solution other than using torch.float64?
Hi, I don't really have other ideas, but maybe @rflamary would know?
I know that ot.dist performs a check to verify whether the two data matrices are the same object, in which case it enforces the diagonal to be 0. This does not solve your issue, but it's closely related, so I'm bringing it up anyway.
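A quick sketch of what that check implies in practice (assuming it behaves as described, and reusing mu0 from the script above):

# Same object passed twice: the check described above forces the
# diagonal to be exactly 0.
print(torch.diag(ot.dist(mu0, mu0)).abs().max())
# A clone has equal values but is a different object, so the check does
# not trigger and float32 round-off leaves ~1e-5 on the diagonal.
print(torch.diag(ot.dist(mu0, mu0.clone())).abs().max())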
It is well known that matrix factorization in 32-bit floats leads to numerical errors, so I'm not sure we can do anything. I wouldn't worry: if you optimize, you should still find the minimum at the same position (up to numerical precision). The same-object trick is of limited use, especially if you want to optimize things.
Note that we are open to numerical precision tricks if you know some and want to contribute them to POT.
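One such trick, as a hedged sketch (not part of POT): the round-off comes from cancellation in the expanded formula ||x||^2 + ||y||^2 - 2<x, y>, so one can compute pairwise differences directly via torch.cdist instead:

# Hypothetical helper: compute squared Euclidean costs without the
# ||x||^2 + ||y||^2 - 2<x, y> expansion, which suffers from
# catastrophic cancellation when x is close to y.
def stable_sq_dist(x, y):
    # 'donot_use_mm_for_euclid_dist' forces elementwise (x - y)
    # differences instead of the matrix-multiplication expansion
    d = torch.cdist(x, y, compute_mode='donot_use_mm_for_euclid_dist')
    return d ** 2

print(torch.diag(stable_sq_dist(mu0, mu0.clone())).abs().max())  # ~0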
Thanks for your help! @rflamary @eloitanguy