PythonOT/POT

entropic_partial_wasserstein not stable

Describe the bug

The ot.partial.entropic_partial_wasserstein function returns NaN when the regularization parameter reg (epsilon) is small.

To Reproduce

import ot
import torch
import numpy as np

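def compute_OT(M, alpha, beta, epsilon):
    # The cost matrix must be square: both marginals have the same number of points.
    s1, s2 = M.shape[0], M.shape[1]
    assert s1 == s2
    unif_vec = ot.unif(s1)
    # Source marginal has total mass 1/beta; target marginal has total mass 1.
    a, b = unif_vec / beta, unif_vec
    # Transport a mass m = alpha with entropic regularization reg = epsilon.
    pi_1_np = ot.partial.entropic_partial_wasserstein(a, b, M, m=alpha, reg=epsilon)
    print(f"Original: sum(pi) = {pi_1_np.sum():.4f}, alpha = {alpha:.4f}")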


beta = 0.35
alpha = 0.01


M_1 = torch.load('M_1.pt')
print(f"M_1 norm = {np.linalg.norm(M_1):.2f}\n")

epsilon = 10.
compute_OT(M_1, alpha, beta, epsilon)

epsilon = 0.1
compute_OT(M_1, alpha, beta, epsilon)

Output

Original: sum(pi) = 0.0100, alpha = 0.0100
G:\Mycode\POT\ot\partial.py:698: RuntimeWarning: divide by zero encountered in divide 
  np.multiply(K, m / np.sum(K), out=K)
G:\Mycode\POT\ot\partial.py:698: RuntimeWarning: invalid value encountered in multiply
  np.multiply(K, m / np.sum(K), out=K)
Warning: numerical errors at iteration 0
Original: sum(pi) = nan, alpha = 0.0100

With epsilon = 0.1, the output is NaN.
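
The warnings point at the normalization np.multiply(K, m / np.sum(K), out=K). A plausible reading, assuming K is the Gibbs kernel exp(-M / reg), is that every entry of K underflows to 0 when M / reg is large, so np.sum(K) is 0, m / 0 is inf, and 0 * inf is nan. The sketch below reproduces that mechanism on a small synthetic cost matrix (not M_1) and compares it with a normalization done in the log domain. The function names normalized_kernel_naive and normalized_kernel_stable are hypothetical and are not part of POT.

```python
import numpy as np


def normalized_kernel_naive(M, reg, m):
    """Mimic the failing step: build exp(-M / reg), then rescale to total mass m."""
    K = np.exp(-M / reg)  # underflows to exactly 0 when M / reg is large
    with np.errstate(divide="ignore", invalid="ignore"):
        return K * (m / K.sum())  # 0 * inf -> nan when K.sum() == 0


def normalized_kernel_stable(M, reg, m):
    """Same normalization, computed in the log domain (hypothetical helper)."""
    logK = -M / reg
    shift = logK.max()  # subtract the max so the largest exponent is 0
    log_sum = shift + np.log(np.exp(logK - shift).sum())
    return m * np.exp(logK - log_sum)


# Synthetic cost matrix, chosen so that M / reg is about 1000 for reg = 0.1.
M = np.array([[100.0, 120.0], [130.0, 110.0]])
m = 0.01

naive = normalized_kernel_naive(M, reg=0.1, m=m)
stable = normalized_kernel_stable(M, reg=0.1, m=m)

print("naive has nan:", np.isnan(naive).any())
print("stable sum:", stable.sum())
```

This only addresses the initial rescaling of K. The later iterations may need the same log-domain treatment to stay finite for small reg.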

Expected behavior

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): Windows
  • Python version: 3.10
  • How was POT installed (source, pip, conda): pip

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)

Additional context

I will create a pull request soon.