`sinkhorn2` and its `functorch.vmap` compatibility
Opened this issue · 3 comments
Feature
Make the `ot.sinkhorn2` function compatible with `functorch.vmap`.
Motivation
I'm using the Python Optimal Transport (POT) library. I want to define a loss function that iterates over every sample in my batch and calculates the Sinkhorn distance between that sample and its ground-truth value. What I was using before was a for-loop:
import ot
loss = 0.0
for i in range(len(P_batch)):
    # Accumulate the Sinkhorn loss between sample i and its ground truth.
    loss = loss + ot.sinkhorn2(P_batch[i].view(-1, 1), Q_batch[i].view(-1, 1), C, epsilon)
but this is way too slow for my application. I was reading through functorch, and apparently I should be able to use its `vmap` functionality:
losses = vmap(ot.sinkhorn2)(P, Q, C, epsilon)
But after wrapping my function in `vmap`, I get this weird error:
File /anaconda3/envs/my_env/lib/python3.8/site-packages/ot/bregman.py:505, in sinkhorn_knopp(a, b, M, reg, numItermax, stopThr, verbose, log, warn, warmstart, **kwargs)
502 v = b / KtransposeU
503 u = 1. / nx.dot(Kp, v)
--> 505 if (nx.any(KtransposeU == 0)
506 or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
507 or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
508 # we have reached the machine precision
509 # come back to previous solution and quit loop
510 warnings.warn('Warning: numerical errors at iteration %d' % ii)
511 u = uprev
RuntimeError: vmap: It looks like you're attempting to use a Tensor in some data-dependent control flow. We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .
Pitch
Apparently, the data-dependent if-statement needs to be replaced with an alternative. Any help is appreciated.
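One common way to remove this kind of branch is to compute both outcomes and select between them with `torch.where`, so that no Python `if` ever inspects tensor values. The sketch below is a toy illustration of that idea, not POT code: `guarded_update` is a hypothetical helper that mimics a single scaling update from the traceback and keeps the previous iterate when the new one is not usable. It assumes PyTorch 2.x, where `vmap` is available from `torch.func`, and falls back to `functorch` otherwise.

```python
import torch

try:
    from torch.func import vmap  # PyTorch >= 2.0
except ImportError:
    from functorch import vmap   # older standalone functorch


def guarded_update(v_prev, b, KtU):
    """Hypothetical single scaling update without a Python-level branch."""
    v_new = b / KtU  # may contain inf/nan where KtU is zero
    # 0-dim boolean tensor: True only if the whole update is usable.
    ok = torch.isfinite(v_new).all() & (KtU != 0).all()
    # Select the new or the previous iterate; this is a tensor op, not an `if`.
    return torch.where(ok, v_new, v_prev)


# Two samples; the second one has a zero in KtU and keeps v_prev.
v_prev = torch.ones(2, 3)
b = torch.ones(2, 3)
KtU = torch.tensor([[1.0, 2.0, 4.0],
                    [1.0, 0.0, 2.0]])
out = vmap(guarded_update)(v_prev, b, KtU)
```

Unlike the loop in `sinkhorn_knopp`, this does not stop early, so a fixed number of iterations would have to replace the `break`. Gradients can also still become NaN when the discarded branch contains `inf`, because `torch.where` only masks the forward values.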
That is a good point, but POT is implemented in pure Python with backends, and getting rid of conditional flows is going to be a pain.
Note that for what you want to compute (P Sinkhorn problems in parallel with the same cost C), one does not need a loop or vmap: the Sinkhorn iterations can be implemented with matrix products that are already parallel, with very little change to the sinkhorn_knopp function. We do not provide this in POT (maybe we will one day, but we need to find the proper API), but feel free to reach me on the POT Slack if you want some pointers.
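As a rough illustration of the suggestion above, here is a minimal sketch of Sinkhorn iterations batched through matrix products in plain PyTorch. It is not part of POT: `batched_sinkhorn2` and its `n_iter` parameter are hypothetical names, the inputs are assumed to be normalized histograms of shape `(B, n)` and `(B, m)` sharing one cost matrix `C` of shape `(n, m)`, and the returned quantity is assumed to be the transport cost ⟨γ, C⟩ for each pair. A fixed iteration count replaces the convergence and numerical-error checks, and there is no log-domain stabilization, so a very small regularization may underflow.

```python
import torch


def batched_sinkhorn2(P, Q, C, reg, n_iter=200):
    """Sinkhorn transport costs for B histogram pairs with a shared cost.

    P: (B, n) source histograms, Q: (B, m) target histograms,
    C: (n, m) cost matrix, reg: entropic regularization strength.
    Returns a tensor of shape (B,).
    """
    K = torch.exp(-C / reg)       # (n, m) Gibbs kernel, shared by all pairs
    u = torch.ones_like(P)        # (B, n) left scaling vectors
    v = torch.ones_like(Q)        # (B, m) right scaling vectors
    for _ in range(n_iter):       # fixed count: no data-dependent exit
        v = Q / (u @ K)           # row b holds Q_b / (K^T u_b)
        u = P / (v @ K.T)         # row b holds P_b / (K v_b)
    # Cost of the plan diag(u_b) K diag(v_b) against C, for each b.
    return torch.einsum("bn,nm,bm->b", u, K * C, v)


# Small random problem to exercise the function.
torch.manual_seed(0)
B, n, m = 4, 5, 6
P = torch.rand(B, n, dtype=torch.float64)
P = P / P.sum(dim=1, keepdim=True)
Q = torch.rand(B, m, dtype=torch.float64)
Q = Q / Q.sum(dim=1, keepdim=True)
C = torch.rand(n, m, dtype=torch.float64)
losses = batched_sinkhorn2(P, Q, C, reg=0.5)
```

For image-like batches such as those in the for-loop of the original post, `P_batch.flatten(1)` would give the `(B, n)` layout assumed here, and `losses.sum()` would play the role of the accumulated loss.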
Thanks @rflamary! I wanted to join the POT Slack, but unfortunately it seems that the workspace invite link hasn't been shared. Could you send me the POT Slack invite? Thanks.
Here is the invite link:
https://join.slack.com/t/pot-toolbox/shared_invite/zt-2se6yon8-pbj85t_QZBKqce0xkt1rbg