entropic_partial_wasserstein not stable
wzm2256 commented
Describe the bug
The `entropic_partial_wasserstein` function produces NaN when the regularization parameter `reg` (eps) is small.
To Reproduce

```python
import ot
import torch
import numpy as np

def compute_OT(M, alpha, beta, epsilon):
    s1, s2 = M.shape[0], M.shape[1]
    assert s1 == s2
    unif_vec = ot.unif(s1)
    a, b = unif_vec / beta, unif_vec
    pi_1_np = ot.partial.entropic_partial_wasserstein(a, b, M, m=alpha, reg=epsilon)
    print(f"Original: sum(pi) = {pi_1_np.sum():.4f}, alpha = {alpha:.4f}")

beta = 0.35
alpha = 0.01

M_1 = torch.load('M_1.pt')
print(f"M_1 norm = {np.linalg.norm(M_1):.2f}\n")

epsilon = 10.
compute_OT(M_1, alpha, beta, epsilon)

epsilon = 0.1
compute_OT(M_1, alpha, beta, epsilon)
```
Output

```text
Original: sum(pi) = 0.0100, alpha = 0.0100
G:\Mycode\POT\ot\partial.py:698: RuntimeWarning: divide by zero encountered in divide
  np.multiply(K, m / np.sum(K), out=K)
G:\Mycode\POT\ot\partial.py:698: RuntimeWarning: invalid value encountered in multiply
  np.multiply(K, m / np.sum(K), out=K)
Warning: numerical errors at iteration 0
Original: sum(pi) = nan, alpha = 0.0100
```
When eps = 0.1, the output is NaN.
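A minimal sketch of the likely failure mode, using hypothetical cost values rather than the attached `M_1.pt`: entropic scaling builds the Gibbs kernel `K = exp(-M / reg)`, and when the entries of `M / reg` are large, `exp` underflows to exactly 0.0 in float64, so the rescaling step `K *= m / sum(K)` divides by zero.

```python
import numpy as np

# Hypothetical strictly positive cost matrix; values chosen so that
# M / reg exceeds the float64 exponent range when reg is small.
M = np.array([[80.0, 90.0], [90.0, 85.0]])
m = 0.01  # transported mass

for reg in (10.0, 0.1):
    K = np.exp(-M / reg)  # Gibbs kernel of the entropic scaling
    print(f"reg={reg}: sum(K) = {K.sum():.3e}")

# With reg=0.1 every exponent is <= -800, far below float64's minimum
# (~exp(-745)), so sum(K) == 0.0 exactly and m / sum(K) evaluates to inf;
# multiplying K (all zeros) by inf then yields nan, matching the
# RuntimeWarnings in the traceback above.
```

This suggests the instability scales with `norm(M) / reg`, which is why the same matrix works at eps = 10 but fails at eps = 0.1.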
Expected behavior
The function should return a valid transport plan with sum(pi) = alpha and no NaN values, as it does for eps = 10.
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): Windows
- Python version: 3.10
- How was POT installed (source, pip, conda): pip
- Build command you used (if compiling from source):
- Only for GPU related bugs:
  - CUDA version:
  - GPU models and configuration:
  - Any other relevant information:
Output of the following code snippet:

```python
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
```
Additional context
wzm2256 commented
I will create a pull request soon.
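For context, a common remedy for this class of Sinkhorn underflow (not necessarily what the upcoming PR implements) is to evaluate the kernel normalization in the log domain with `logsumexp`, which stays finite even when `exp(-M/reg)` itself underflows. A hedged sketch, with `stabilized_scaled_kernel` being a hypothetical helper name:

```python
import numpy as np
from scipy.special import logsumexp

def stabilized_scaled_kernel(M, reg, m):
    # log K = -M / reg; compute log(sum(K)) via logsumexp instead of
    # summing exp(-M/reg) directly, so the normalizer never underflows.
    logK = -M / reg
    logZ = logsumexp(logK)
    # Equivalent to K * m / sum(K), but computed stably in log space.
    return np.exp(np.log(m) + logK - logZ)

# Same hypothetical cost scale that makes the naive kernel underflow:
M = np.array([[80.0, 90.0], [90.0, 85.0]])
K = stabilized_scaled_kernel(M, reg=0.1, m=0.01)
print(K.sum())  # ~0.01, with no inf/nan produced
```

The trade-off is extra `log`/`exp` work per iteration, which is why libraries often expose the stabilized path as a separate option rather than the default.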