ott-jax/ott

Sinkhorn iteration is not converging in A100 GPU

yexf308 opened this issue · 5 comments

I ran the following code on an A100 and on other GPUs. On the other GPUs (T4, V100, A40) the Sinkhorn iteration converges, but on the A100 the error gets stuck at some value until the maximum number of iterations is reached. Does anyone know what is happening here?

To Reproduce in A100 GPU

import jax.numpy as jnp
import numpy as np
from sklearn.datasets import make_blobs

from ott import utils
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

n = 50
𝜀 = 1
threshold = 1e-2
blob_std = 1

x, data_mem = make_blobs(n_samples=n, n_features=4, centers=1, cluster_std=blob_std, random_state=1)
mu = np.ones((n,)) / n  # size of M
nv = np.ones((n,)) / n  # size of N
w = np.random.normal(size=[4, 2])
w_int = w + 1 * np.random.normal(size=[4, 2])
y = x.dot(w) + 1 * np.random.normal(size=[n, 2])
y_pred = x.dot(w_int)
y_jx = jnp.array(y)
pred_jx = jnp.array(y_pred)

geom = pointcloud.PointCloud(pred_jx, y_jx, epsilon=𝜀)
prob = linear_problem.LinearProblem(geom, a=mu, b=nv)
solver = sinkhorn.Sinkhorn(
    threshold=threshold, inner_iterations=200, lse_mode=True, max_iterations=1000,
    norm_error=1, use_danskin=False, progress_fn=utils.default_progress_fn()
)
out = solver(prob)

The screenshot of the output in A100 is
Screenshot 2024-03-01 at 11 43 32 AM

The screenshot of the output in T4 is
Screenshot 2024-03-01 at 11 52 46 AM

Hi @yexf308, I can reproduce the above on A100 using:
image

As for the reason, I think there are two things at play:

  1. The values in the cost matrix are really big, roughly in the range [0.15, 417], which can lead to numerical instability. One way to address this is to scale the cost matrix, e.g., by its mean with scale_cost='mean'. Another solution is to use 64-bit precision.
  2. XLA seems to generate different code for CPU/GPU/other accelerators, so I assume this is where the CPU/GPU discrepancy comes in; unfortunately, this is out of our control.
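
To get a feel for how the two suggestions in point 1 (mean scaling and 64-bit precision) affect the marginal error, here is a minimal NumPy sketch. It is not OTT's implementation: `sinkhorn_log` and `l1_marginal_error` are hypothetical helpers written for this illustration, the point clouds are made-up random data, and it runs on CPU, so it does not reproduce the A100 behaviour itself. Note that dividing the cost by its mean while keeping the same epsilon also changes the effective regularization strength, so the resulting coupling is not the same as for the raw cost.

```python
import numpy as np


def logsumexp(z, axis):
    """Numerically stable log(sum(exp(z))) along one axis."""
    m = z.max(axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(z - m).sum(axis=axis))


def sinkhorn_log(cost, a, b, eps, n_iters, dtype):
    """Plain log-domain Sinkhorn carried out entirely in `dtype`."""
    cost, a, b = cost.astype(dtype), a.astype(dtype), b.astype(dtype)
    eps = dtype(eps)
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iters):
        # Update f so that the row marginals match a, then g for the columns.
        f = eps * (np.log(a) - logsumexp((g[None, :] - cost) / eps, axis=1))
        g = eps * (np.log(b) - logsumexp((f[:, None] - cost) / eps, axis=0))
    return np.exp((f[:, None] + g[None, :] - cost) / eps)


def l1_marginal_error(coupling, a, b):
    """L1 deviation of both marginals, accumulated in float64."""
    p = coupling.astype(np.float64)
    return float(np.abs(p.sum(1) - a).sum() + np.abs(p.sum(0) - b).sum())


# Made-up data (an assumption, not the data from the issue).
rng = np.random.default_rng(0)
n = 50
pred = 5.0 * rng.normal(size=(n, 2))
y = 5.0 * rng.normal(size=(n, 2))
cost = ((pred[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # squared Euclidean
a = np.full(n, 1.0 / n)
b = np.full(n, 1.0 / n)

results = {}
for label, c in [("raw", cost), ("mean-scaled", cost / cost.mean())]:
    for dtype in (np.float32, np.float64):
        coupling = sinkhorn_log(c, a, b, eps=1.0, n_iters=200, dtype=dtype)
        results[(label, dtype.__name__)] = l1_marginal_error(coupling, a, b)
        print(f"{label:12s} {dtype.__name__:8s} L1 marginal error: "
              f"{results[(label, dtype.__name__)]:.3e}")
```

Comparing the four printed errors against the chosen threshold gives a rough idea of whether the 32-bit error floor is close to the tolerance for a given cost range.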

@michalk8 I have checked out.matrix and compared it with the reported error. out.matrix has actually converged to the desired coupling matrix within the threshold, so the error calculation itself appears to be incorrect. That is why I suspect there is a bug. Thanks for the quick reply!

Here is the screenshot.

Screenshot 2024-03-01 at 12 52 35 PM
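
For anyone who wants to repeat the check described above, a minimal sketch of computing the marginal violation directly from a coupling matrix could look like this. `marginal_errors` is a hypothetical helper, not part of OTT, and the quantity it returns is not necessarily the one the solver reports internally. In the setting of this issue, `coupling` would be `np.asarray(out.matrix)` and `a`, `b` would be `mu`, `nv`; the small product coupling below is only a stand-in so that the snippet runs on its own.

```python
import numpy as np


def marginal_errors(coupling, a, b):
    """Return the L1 deviations of the row and column sums from a and b.

    Everything is cast to float64 first so that the check itself does not
    add 32-bit rounding error on top of what the solver produced.
    """
    p = np.asarray(coupling, dtype=np.float64)
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    row_err = float(np.abs(p.sum(axis=1) - a).sum())
    col_err = float(np.abs(p.sum(axis=0) - b).sum())
    return row_err, col_err


# Stand-in coupling (an assumption): the independent coupling a b^T,
# which satisfies both marginal constraints up to rounding.
n = 50
a = np.full(n, 1.0 / n)
b = np.full(n, 1.0 / n)
coupling = np.outer(a, b)

threshold = 1e-2
row_err, col_err = marginal_errors(coupling, a, b)
print(f"row error {row_err:.3e}, column error {col_err:.3e}, "
      f"within threshold: {row_err + col_err < threshold}")
```

If the errors computed this way are below the threshold while the solver's reported error is not, that points to the error evaluation inside the iteration rather than to the coupling.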

This issue is resolved if one uses double precision.