PythonOT/POT

CUDA out of memory when using ot.sinkhorn2 as a loss function

Closed this issue · 4 comments

Hi, I'm trying to use the EMD/Sinkhorn distance as a loss function for 2D matrices.
However, ot.sinkhorn2 causes a CUDA out of memory error when it is computed:

[screenshot: CUDA out-of-memory traceback]
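For context, a minimal sketch of the kind of setup that triggers this; the sizes, data, and device here are illustrative, not taken from the issue:

```python
import torch
import ot

# Two hypothetical point clouds on the GPU; x requires gradients
n = 5000
x = torch.randn(n, 2, device='cuda', requires_grad=True)
y = torch.randn(n, 2, device='cuda')

# Uniform marginal weights
a = torch.full((n,), 1.0 / n, device='cuda')
b = torch.full((n,), 1.0 / n, device='cuda')

# Cost matrix and entropic OT loss; backpropagating through the
# Sinkhorn iterations is what drives the memory use up
M = ot.dist(x, y)
loss = ot.sinkhorn2(a, b, M, reg=1.0)
loss.backward()
```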

ot.emd2 can also give this error when I use a larger dataset.

According to nvidia-smi, when training on the same dataset, ot.emd2 uses about 7 GB of the 12 GB of GPU memory (which is fine), while ot.sinkhorn2 uses 11 GB of 12 GB and triggers the error above.

Thank you!

Hello,

This probably comes from the fact that torch needs to keep the intermediate values of the Sinkhorn iterations in memory to allow for a backward(). I have been planning for a while to add a detach_iterations parameter that would run the whole algorithm and then plug in the gradient assuming convergence (implicit differentiation). I will get to that when I have more time.
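For reference, the detach-and-reattach trick this alludes to can be sketched as follows. This is not POT's API, just a hand-rolled illustration: the Sinkhorn loop runs under no_grad, and a single tracked update from the detached fixed point recovers the gradient w.r.t. the cost matrix at convergence:

```python
import torch

def sinkhorn_loss_detached(a, b, M, reg, n_iter=100):
    # Gibbs kernel; differentiable w.r.t. M
    K = torch.exp(-M / reg)
    # Run the Sinkhorn fixed-point iterations without tracking gradients,
    # so memory does not grow with n_iter
    u = torch.ones_like(a)
    with torch.no_grad():
        for _ in range(n_iter):
            v = b / (K.T @ u)
            u = a / (K @ v)
    # One tracked update from the detached fixed point: at convergence
    # this matches the implicit (envelope) gradient w.r.t. M
    v = b / (K.T @ u.detach())
    u = a / (K @ v)
    P = u[:, None] * K * v[None, :]  # transport plan
    return (P * M).sum()  # linear OT cost <P, M>
```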

Could you please check if sinkhorn2 explodes in memory when you give it arrays with keep_gradient=False?

Hello, I have added an option to ot.solve that forces the use of implicit gradients and limits the memory use:

you can compute the sinkhorn loss with

loss = ot.solve(M, a, b, reg=1, grad='implicit').value

All iterations are detached and the gradient is set at the end with no memory overhead, but then the result is differentiable only w.r.t. value (not value_linear or the OT plan). Could you tell me if it solves your problem? It is merged in the master branch.
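For illustration, a fuller usage sketch around that one-liner; the tensor shapes and data are made up, and only the grad='implicit' option comes from the comment above:

```python
import torch
import ot

n = 2000
# Hypothetical point clouds on the GPU; x requires gradients
x = torch.randn(n, 2, device='cuda', requires_grad=True)
y = torch.randn(n, 2, device='cuda')

# Uniform marginals
a = torch.full((n,), 1.0 / n, device='cuda')
b = torch.full((n,), 1.0 / n, device='cuda')

M = ot.dist(x, y)  # squared Euclidean cost matrix

# Sinkhorn iterations are detached; the gradient is plugged in at
# the end, so memory does not grow with the number of iterations
loss = ot.solve(M, a, b, reg=1, grad='implicit').value
loss.backward()  # gradient reaches x through M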


This new functionality is now available in ot.solve and ot.solve_sample thanks to PR #605.
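The same computation via ot.solve_sample might look like the following; this is a sketch assuming solve_sample forwards the same reg and grad arguments as ot.solve, reusing the tensors from the previous example:

```python
# Same loss computed directly from the samples, letting POT build the
# cost matrix internally (assumed: solve_sample accepts reg and grad
# like ot.solve does)
loss = ot.solve_sample(x, y, a, b, reg=1, grad='implicit').value
loss.backward()
```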