Unbalanced Stabilized Does Not Agree with Standard Unbalanced
tomaszkacprzak opened this issue · 1 comment
Describe the bug
In the unbalanced module, sinkhorn_stabilized_unbalanced gives different results from sinkhorn_knopp_unbalanced when absorption is triggered. If absorption is not triggered, the results agree. The example in the documentation does not trigger absorption.
Code sample
import ot
import numpy as np
a=[.3, .7]
b=[.7, .3]
M=[[0., 1.], [1., 0.]]
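reg = 0.01   # entropic regularization strength
reg_m = 100  # marginal relaxation strength
# Default method='sinkhorn', i.e. sinkhorn_knopp_unbalanced
Q1 = ot.sinkhorn_unbalanced(a, b, M, reg, reg_m, reg_type='entropy')
# tau: threshold on max(u) or max(v) that triggers absorption (default 1e5)
Q2 = ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type='entropy', tau=1000)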
print()
print(np.round(Q1,6))
print(np.round(Q2,6))
Expected behavior
Q1 and Q2 should be the same.
Environment
- OS: macOS
- Python version: 3.11.6
- How was POT installed: pip
Output of the following code snippet:
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
macOS-15.2-arm64-arm-64bit
Python 3.11.6 | packaged by conda-forge | (main, Oct 3 2023, 10:37:07) [Clang 15.0.7 ]
NumPy 1.26.0
SciPy 1.11.3
POT 0.9.5
Additional context
Changing the following lines in sinkhorn_stabilized_unbalanced:
alpha = alpha + reg * nx.log(nx.max(u))
beta = beta + reg * nx.log(nx.max(v))
to:
alpha = alpha + reg * nx.log(u)
beta = beta + reg * nx.log(v)
seems to solve the problem.
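The effect of this change can be checked in isolation with a small NumPy sketch. It is not POT's code: it assumes that the stabilized solver represents the plan as T = diag(u) K diag(v) with K = exp((alpha_i + beta_j - M_ij) / reg), and that after absorption it resets u and v to ones. Under that assumption, absorbing the full vectors u and v into alpha and beta leaves T unchanged, whereas absorbing only max(u) and max(v) multiplies K by a single scalar and discards the relative scaling between entries. The helpers plan, absorb_full and absorb_max and the values of u and v are hypothetical, chosen for illustration.

```python
import numpy as np

def plan(alpha, beta, u, v, M, reg):
    # Assumed stabilized parametrization (hypothetical helper):
    # T = diag(u) K diag(v), with K_ij = exp((alpha_i + beta_j - M_ij) / reg)
    K = np.exp((alpha[:, None] + beta[None, :] - M) / reg)
    return u[:, None] * K * v[None, :]

def absorb_full(alpha, beta, u, v, reg):
    # Proposed fix: move all of u and v into the potentials, then reset to ones
    return (alpha + reg * np.log(u), beta + reg * np.log(v),
            np.ones_like(u), np.ones_like(v))

def absorb_max(alpha, beta, u, v, reg):
    # Original lines: absorb only the largest entry (a scalar shift), then reset to ones
    return (alpha + reg * np.log(np.max(u)), beta + reg * np.log(np.max(v)),
            np.ones_like(u), np.ones_like(v))

M = np.array([[0.0, 1.0], [1.0, 0.0]])
reg = 0.01
alpha, beta = np.zeros(2), np.zeros(2)
u = np.array([2.0, 5.0])  # hypothetical scalings with unequal entries
v = np.array([3.0, 0.5])

T_before = plan(alpha, beta, u, v, M, reg)
T_full = plan(*absorb_full(alpha, beta, u, v, reg), M, reg)
T_max = plan(*absorb_max(alpha, beta, u, v, reg), M, reg)

print(np.allclose(T_before, T_full))  # True: the plan is preserved
print(np.allclose(T_before, T_max))   # False: the relative scalings are lost
```

With the max-only variant, every entry of K is multiplied by max(u) * max(v) = 15 while u and v become ones, so the plan is only preserved when u and v are each constant vectors.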
Hi @tomaszkacprzak,
Thank you for opening an issue. This behavior does not seem abnormal; it essentially comes from numerical instability. Could you check whether you recover the same result with ot.sinkhorn_unbalanced and method="sinkhorn_stabilized" as with sinkhorn_stabilized_unbalanced?
Also, tau=1e5 is the default value of the latter function, and the wrapper uses it unless you pass tau=1000 as an argument to ot.sinkhorn_unbalanced.