Unexpected behavior of `Backend.sqrtm` and a minor bug in `ot.gaussian.bures_wasserstein_distance`
framunoz opened this issue · 2 comments
Describe the bug
The implementation of the method Backend.sqrtm may not be correct. I ran the following check:
import torch
import ot
torch.manual_seed(42)
z = torch.randn(100000, 128)
C = torch.cov(z.T)
nx = ot.backend.get_backend(C)
C12 = nx.sqrtm(C)
torch.allclose(C12 @ C12, C) # It should be True
>>> False
Furthermore, the method sometimes returns an array filled with NaN (see the code sample below).
This affects the function ot.gaussian.bures_wasserstein_distance. By the way, I realised that there is a bug in that function: B must be squared, as stated in the documentation and the reference (Cuturi, Computational Optimal Transport).
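To make the point about the squared B concrete, here is a minimal sketch of the standard closed form for two Gaussians, W^2 = ||m_s - m_t||^2 + B^2 with B^2 = Tr(Σ_s + Σ_t - 2 (Σ_s^{1/2} Σ_t Σ_s^{1/2})^{1/2}). It is not POT's code: the name `bures_wasserstein_sketch` is hypothetical, and the sketch assumes symmetric positive semi-definite covariances. It uses the fact that Σ_s Σ_t has the same eigenvalues as Σ_s^{1/2} Σ_t Σ_s^{1/2}, so no matrix square root is needed and the result can serve as an independent cross-check.

```python
import torch


def bures_wasserstein_sketch(ms, mt, Cs, Ct):
    """W_2 distance between N(ms, Cs) and N(mt, Ct) (illustrative helper)."""
    ms, mt = ms.double(), mt.double()
    Cs, Ct = Cs.double(), Ct.double()
    # Cs @ Ct shares its eigenvalues with Cs^{1/2} Ct Cs^{1/2}, which are
    # real and non-negative; clamp to remove round-off negatives.
    lam = torch.linalg.eigvals(Cs @ Ct).real.clamp(min=0.0)
    # B^2 = Tr(Cs) + Tr(Ct) - 2 * Tr((Cs^{1/2} Ct Cs^{1/2})^{1/2})
    b_squared = torch.trace(Cs) + torch.trace(Ct) - 2.0 * lam.sqrt().sum()
    mean_term = torch.sum((ms - mt) ** 2)
    # The squared distance adds B^2 (not B) to the mean term.
    return torch.sqrt(torch.clamp(mean_term + b_squared, min=0.0))


# Isotropic check: Cs = 4 I, Ct = 9 I in 3 dimensions gives
# B^2 = 3 * (2 - 3)^2 = 3, and ||ms - mt||^2 = 1, so W = 2.
W = bures_wasserstein_sketch(
    torch.zeros(3),
    torch.tensor([1.0, 0.0, 0.0]),
    4.0 * torch.eye(3),
    9.0 * torch.eye(3),
)
print(W.item())
```

Comparing this value with the output of ot.gaussian.bures_wasserstein_distance on the same inputs is one way to tell whether the library adds B or B^2 to the mean term.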
To Reproduce
Steps to reproduce the behavior:
- Run the script above
Code sample
import torch
import ot
torch.manual_seed(3137)
z = torch.randn(129, 128)
C = torch.cov(z.T)
nx = ot.backend.get_backend(C)
C12 = nx.sqrtm(C)
print(C12)
>>> tensor([[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]])
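One plausible explanation, not confirmed in this thread, is that a sample covariance built from only 129 points in 128 dimensions is close to singular, so float32 round-off can push its smallest eigenvalues to zero or below. The sketch below only inspects the spectrum in both precisions; it calls nothing from POT, and the variable names `eig32` and `eig64` are illustrative.

```python
import torch

torch.manual_seed(3137)
z = torch.randn(129, 128)
C = torch.cov(z.T)

# Eigenvalues of the (symmetric) covariance, in single and double precision.
eig32 = torch.linalg.eigvalsh(C)
eig64 = torch.linalg.eigvalsh(C.double())

print("min eigenvalue (float32):", eig32.min().item())
print("min eigenvalue (float64):", eig64.min().item())
# A very large ratio means the matrix is ill-conditioned.
print("max/min ratio (float64):", (eig64.max() / eig64.min()).item())
```

If the smallest eigenvalue is tiny or negative, any method that takes square roots or inverses of the spectrum in float32 can be expected to struggle on this input.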
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): WSL
- Python version: 3.12.2
- How was POT installed (source, pip, conda): pip
- Build command you used (if compiling from source):
- Only for GPU related bugs:
- CUDA version:
- GPU models and configuration:
- Any other relevant information:
Output of the following code snippet:
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Linux-5.15.167.4-microsoft-standard-WSL2-x86_64-with-glibc2.35
Python 3.12.2 (main, Feb 25 2024, 16:35:05) [GCC 11.4.0]
NumPy 1.26.4
SciPy 1.13.0
POT 0.9.3
Additional context
I see that there is an implementation of sqrtm here: photosynthesis-team/piq#190
With that implementation, the two bugs mentioned above disappear: every time I run the first code block it returns True, and the second at least gives an array of floats. If you want, I can make a PR.
Yep, I see that using double precision gives good results.
But, regarding the other problem:
I realised that there is a bug in that function: B must be squared, as stated in the documentation and the reference (Cuturi, Computational Optimal Transport).
It might be good to take a look at it.
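Regarding the double-precision observation above, the following is a minimal sketch of a matrix square root for symmetric positive semi-definite inputs, computed by eigendecomposition in float64. It is neither POT's Backend.sqrtm nor the implementation linked from piq: `sqrtm_psd` is a hypothetical helper, and the sample size is smaller than in the report only to keep the run short.

```python
import torch


def sqrtm_psd(C):
    """Square root of a symmetric PSD matrix via eigh (illustrative helper)."""
    C64 = C.double()
    # Symmetrise to remove any round-off asymmetry before calling eigh.
    C64 = (C64 + C64.T) / 2
    w, V = torch.linalg.eigh(C64)
    # Clamp tiny negative eigenvalues caused by round-off.
    w = w.clamp(min=0.0)
    # V diag(sqrt(w)) V^T
    return (V * w.sqrt()) @ V.T


torch.manual_seed(42)
z = torch.randn(10000, 128)
C = torch.cov(z.T)

C12 = sqrtm_psd(C)
# Largest entrywise deviation between C12 @ C12 and C.
residual = (C12 @ C12 - C.double()).abs().max().item()
print("max |C12 @ C12 - C| =", residual)
```

The result is returned in float64; casting it back to the input dtype is left to the caller, since the cast reintroduces single-precision error.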