PythonOT/POT

`sinkhorn2` and its `functorch.vmap` compatibility

Opened this issue · 3 comments

🚀 Feature

Make `ot.sinkhorn2` compatible with `functorch.vmap`.

Motivation

I'm using the Python Optimal Transport (POT) library. I want to define a loss function that iterates over every sample in my batch and computes the Sinkhorn distance between that sample and its ground-truth counterpart. Previously I used a for-loop:

import ot

# Sum the Sinkhorn loss over the batch, one sample at a time
loss = 0.
for i in range(len(P_batch)):
    loss += ot.sinkhorn2(P_batch[i].view(-1, 1), Q_batch[i].view(-1, 1), C, epsilon)

However, this is far too slow for my application. From reading the functorch documentation, I expected to be able to vectorize it with vmap instead:

from functorch import vmap
losses = vmap(ot.sinkhorn2, in_dims=(0, 0, None, None))(P, Q, C, epsilon)

However, after wrapping the function in vmap, I get the following error:

File /anaconda3/envs/my_env/lib/python3.8/site-packages/ot/bregman.py:505, in sinkhorn_knopp(a, b, M, reg, numItermax, stopThr, verbose, log, warn, warmstart, **kwargs)
    502 v = b / KtransposeU
    503 u = 1. / nx.dot(Kp, v)
--> 505 if (nx.any(KtransposeU == 0)
    506         or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
    507         or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
    508     # we have reached the machine precision
    509     # come back to previous solution and quit loop
    510     warnings.warn('Warning: numerical errors at iteration %d' % ii)
    511     u = uprev

RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .

Pitch

Apparently, the data-dependent if statement in `sinkhorn_knopp` needs to be replaced with a vmap-compatible alternative. Any help is appreciated.
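To illustrate what such an alternative could look like, here is a minimal sketch (not POT code) of the pattern usually used with vmap: instead of a Python `if` on tensor values, compute the condition as a tensor and let `torch.where` choose the result per sample. The helper names `revert_if_invalid_branchy` and `revert_if_invalid_branchless` are hypothetical, and the sketch assumes PyTorch 2.x, where vmap is available as `torch.func.vmap`. It only mirrors the "come back to the previous solution" half of the check; the early `break` has no per-sample equivalent, so a vmap-friendly Sinkhorn would presumably run a fixed number of iterations and keep reverting samples that hit numerical errors.

```python
import torch
from torch.func import vmap  # on older PyTorch: from functorch import vmap


def revert_if_invalid_branchy(u_new, u_prev):
    # Hypothetical helper mirroring the check in sinkhorn_knopp: a Python `if`
    # on tensor values. Under vmap, bool() on a batched tensor raises RuntimeError.
    if torch.any(torch.isnan(u_new)) or torch.any(torch.isinf(u_new)):
        return u_prev
    return u_new


def revert_if_invalid_branchless(u_new, u_prev):
    # Hypothetical helper: the same decision computed as a 0-dim bool tensor,
    # True when this sample's update contains NaN or inf.
    bad = torch.isnan(u_new).any() | torch.isinf(u_new).any()
    # torch.where selects u_prev only for bad samples, so each sample in the
    # batch can take a different "branch" without Python control flow.
    return torch.where(bad, u_prev, u_new)


u_prev = torch.ones(3, 2)
u_new = torch.tensor([[0.5, 2.0],
                      [float("nan"), 1.0],
                      [3.0, float("inf")]])

try:
    vmap(revert_if_invalid_branchy)(u_new, u_prev)
except RuntimeError as err:
    print("branchy version fails under vmap:", err)

# Rows 1 and 2 contain NaN/inf, so they are replaced by the rows of u_prev.
print(vmap(revert_if_invalid_branchless)(u_new, u_prev))
```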

That is a good point, but POT is implemented in pure Python on top of its backends, and getting rid of the conditional flows is going to be a pain.

Note that for what you want to compute (P Sinkhorn problems solved in parallel with the same cost matrix C), you do not need a loop or vmap: the Sinkhorn iterations can be implemented with matrix products that are already parallel, with very little change to the `sinkhorn_knopp` function. We do not provide this in POT (maybe we will one day, but we need to find the proper API), but feel free to reach me on the POT Slack if you want some pointers.
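A minimal sketch of the batched approach described above, under a few stated assumptions: the function name `batched_sinkhorn2` is hypothetical and not part of POT; it uses NumPy so the example is self-contained; it runs a fixed number of iterations instead of POT's `stopThr` convergence check; it omits the numerical-error check entirely; and all histograms are assumed strictly positive. Each of the P scaling vectors is stored as one row of `U` or `V`, so a single Sinkhorn step for the whole batch is just two matrix products with the shared kernel K = exp(-M / reg).

```python
import numpy as np


def batched_sinkhorn2(A, B, M, reg, num_iter=1000):
    """Hypothetical helper: entropic OT losses for a batch sharing one cost matrix.

    A: (batch, n) source histograms, one per row (assumed strictly positive).
    B: (batch, m) target histograms, one per row (assumed strictly positive).
    M: (n, m) cost matrix shared by every problem in the batch.
    Returns an array of shape (batch,) holding <T_p, M> for each problem p.
    """
    K = np.exp(-M / reg)              # (n, m) Gibbs kernel, computed once
    U = np.ones_like(A) / A.shape[1]  # (batch, n) left scalings u_p as rows
    V = np.ones_like(B) / B.shape[1]  # (batch, m) right scalings v_p as rows
    for _ in range(num_iter):         # fixed iteration count (no stopThr check)
        # Row p of U @ K is K^T u_p, so this is v_p = b_p / (K^T u_p) for all p.
        V = B / (U @ K)
        # Row p of V @ K.T is K v_p, so this is u_p = a_p / (K v_p) for all p.
        U = A / (V @ K.T)
    # T_p = diag(u_p) K diag(v_p), so <T_p, M> = sum_ij u_p[i] K[i,j] M[i,j] v_p[j].
    return np.einsum("pi,ij,pj->p", U, K * M, V)


# Usage: four random problems on a 1-D grid with a squared-distance cost.
rng = np.random.default_rng(0)
n_problems, n = 4, 5
A = rng.random((n_problems, n))
A /= A.sum(axis=1, keepdims=True)
B = rng.random((n_problems, n))
B /= B.sum(axis=1, keepdims=True)
x = np.arange(n, dtype=float)
M = (x[:, None] - x[None, :]) ** 2
M /= M.max()

costs = batched_sinkhorn2(A, B, M, reg=0.1)
loss = costs.sum()  # one call instead of a Python loop over the batch
```

Since the loop body only uses `@`, elementwise arithmetic and `einsum`, porting it to PyTorch should mostly be a matter of swapping `np.exp`, `np.ones_like` and `np.einsum` for `torch.exp`, `torch.ones_like` and `torch.einsum`, after which autograd can differentiate through the fixed number of iterations.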

Thanks @rflamary! I wanted to join the POT Slack, but unfortunately it seems that the workspace invite link hasn't been shared. Could you send me the POT Slack invite? Thanks.