gromov_barycenters function always returns the zero matrix
youssef62 opened this issue · 3 comments
Describe the bug
The gromov_barycenters function always returns the zero matrix as the cost matrix.
I executed the code below multiple times and it always returns 0.
Code sample
import networkx as nx
from ot.gromov import gromov_barycenters

n = 5
# Two random graphs; their adjacency matrices are integer-typed by default.
g1 = nx.erdos_renyi_graph(n, .8)
g2 = nx.erdos_renyi_graph(n, .8)
c1 = nx.adjacency_matrix(g1).toarray()
c2 = nx.adjacency_matrix(g2).toarray()
print(c1)
print(c2)
b, log = gromov_barycenters(n, [c1, c2], log=True)
print(b)
print(log)
Expected behavior
A nonzero matrix.
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): macOS, also tried on Colab
- Python version: Python 3.11.5
- How was POT installed (source, pip, conda): pip
Output of the following code snippet:
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
macOS-12.5.1-arm64-arm-64bit
Python 3.11.5 (main, Aug 24 2023, 15:09:32) [Clang 14.0.0 (clang-1400.0.29.202)]
NumPy 1.26.0
SciPy 1.11.3
POT 0.9.1
Hello @youssef62 ,
The problem comes from the fact that 'c1 = nx.adjacency_matrix(g1).toarray()' has 'int64' as its default type. The Gromov-Wasserstein solver, which is called several times in 'gromov_barycenters', therefore always converts the output transport plan to 'int64', i.e. a zero matrix.
You can fix the code on your side by doing e.g. 'c1 = nx.adjacency_matrix(g1).toarray().astype(np.float64)'.
On our side, we could fix this bug by requiring at least float inputs, no matter the backend.
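The truncation described above can be reproduced with NumPy alone. The following is a minimal sketch, not POT's internal code: it assumes only that a float-valued plan ends up cast to the dtype of an integer input, and the names 'c_int', 'c_float' and 'plan' are illustrative.

```python
import numpy as np

# Integer adjacency matrix, similar to nx.adjacency_matrix(g).toarray().
c_int = np.array([[0, 1, 1],
                  [1, 0, 1],
                  [1, 1, 0]])

# A valid coupling between two uniform 3-point distributions:
# every entry is 1/9, so the entries sum to 1.
plan = np.full((3, 3), 1.0 / 9.0)

# Casting the plan to the integer dtype of the input truncates
# every entry (all below 1) to 0.
plan_as_int = plan.astype(c_int.dtype)

# Casting the input to float64 first keeps the plan intact.
c_float = c_int.astype(np.float64)
plan_as_float = plan.astype(c_float.dtype)

print(plan_as_int.sum())    # total mass after the integer cast
print(plan_as_float.sum())  # total mass after the float cast
```

Since every entry of a coupling between probability vectors is at most 1, an integer cast wipes out almost any plan, which is why the barycenter built from it is zero.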
I think it makes sense to use the type given in the input for the plan. Handling this across backends is a nightmare otherwise, because which float type would you use (float32, float64)? Better to let the user choose through the input. But we definitely need to state this clearly in the documentation, and maybe raise a warning when the type is integer. I thought we did that for emd, but we should check and do the same for gromov.
Got it, thanks Rémi. Indeed, for ot.emd it is stated in the documentation, plus there is a warning in the function. I will open a PR to handle that for the gw solvers.
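Until such a warning exists in the solvers, a user-side check along the lines discussed above might look like the sketch below. The helper 'warn_if_integer' is hypothetical and is not part of POT; it only shows one way to flag integer-typed cost matrices before they are passed to 'gromov_barycenters'.

```python
import warnings
import numpy as np


def warn_if_integer(C, name="C"):
    """Hypothetical helper: warn if a cost matrix has an integer dtype.

    The array is returned unchanged, so the caller keeps control of
    which float type to cast to.
    """
    C = np.asarray(C)
    if np.issubdtype(C.dtype, np.integer):
        warnings.warn(
            f"{name} has integer dtype {C.dtype}; results computed in this "
            "dtype may be truncated. Consider casting to a float type.",
            stacklevel=2,
        )
    return C


# Usage: an integer matrix triggers the warning; the caller then casts it.
c1 = warn_if_integer(np.array([[0, 1], [1, 0]]), name="c1")
c1 = c1.astype(np.float64)
```

Leaving the cast to the caller follows the preference stated above: the user picks the float type, and the check only makes the integer case visible.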