CUDA out of memory when using ot.sinkhorn2 as a loss function
Closed this issue · 4 comments
Hi, I'm trying to use the EMD/Sinkhorn distance as the loss function for 2D matrices. However, `ot.sinkhorn2` causes a CUDA out-of-memory error when it is computed. `ot.emd2` can also give this error when I use a larger dataset. Watching `nvidia-smi` while training on the same dataset, `ot.emd2` uses about 7G/12G of memory (which is fine), while `ot.sinkhorn2` uses 11G/12G and causes the error above.
Thank you!
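For context, the failing pattern can be reproduced with a small pure-torch sketch of differentiable Sinkhorn iterations (illustrative only; the shapes and names here are assumptions, not POT's actual internals):

```python
import torch

# Differentiable Sinkhorn in the style of ot.sinkhorn2: every iteration is
# recorded by autograd, so all intermediate u, v tensors stay alive in the
# graph until backward() -- on large problems this is what exhausts GPU memory.
n = 100
M = torch.rand(n, n, requires_grad=True)   # cost matrix
a = torch.full((n,), 1.0 / n)              # source histogram
b = torch.full((n,), 1.0 / n)              # target histogram
reg = 1.0

K = torch.exp(-M / reg)
u = torch.ones(n)
for _ in range(50):                        # 50 iterations -> ~50x graph size
    v = b / (K.T @ u)
    u = a / (K @ v)
P = u[:, None] * K * v[None, :]            # transport plan
loss = torch.sum(P * M)
loss.backward()                            # needs every stored intermediate
```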
Hello,
This probably comes from the fact that torch needs to keep the intermediate values of Sinkhorn in memory to allow for a backward(). I have been planning for a while to add a detach_iterations parameter that would run the whole algorithm and plug in the gradient assuming convergence (implicit differentiation). I will get to that when I have more time.
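The detached-iterations idea can be sketched in plain torch (a minimal sketch under the assumptions above, not POT's implementation; the function name is made up). The Sinkhorn loop runs under `no_grad` so no graph is stored, and the gradient is re-attached at the end: by the envelope theorem, the gradient of the regularized OT value with respect to M is the optimal plan P, so `sum(P.detach() * M)` has the correct gradient wrt M with flat memory use.

```python
import torch

def sinkhorn_value_detached(M, a, b, reg=1.0, n_iter=100):
    """Entropic OT loss with O(1) autograd memory (hypothetical sketch).

    Iterations run under no_grad, so autograd stores no intermediates.
    Returning sum(P.detach() * M) gives a scalar whose gradient wrt M
    is the (detached) transport plan P, as implicit differentiation
    at convergence prescribes.
    """
    with torch.no_grad():
        K = torch.exp(-M / reg)
        u = torch.ones_like(a)
        for _ in range(n_iter):
            v = b / (K.T @ u)
            u = a / (K @ v)
        P = u[:, None] * K * v[None, :]   # transport plan, detached
    return torch.sum(P * M)               # grad wrt M is exactly P

# Usage: backpropagates to M without storing any iteration.
M = torch.rand(64, 64, requires_grad=True)
a = torch.full((64,), 1.0 / 64)
b = torch.full((64,), 1.0 / 64)
loss = sinkhorn_value_detached(M, a, b)
loss.backward()                           # M.grad holds the detached plan
```

Note that the returned scalar is only the linear cost <P, M>, not the full regularized value; only its gradient matches the implicit-differentiation gradient of the Sinkhorn loss.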
Could you please check whether sinkhorn2 still explodes in memory when you call it on your arrays with keep_gradient=False?
Hello, I have added an option to `ot.solve` that forces the use of implicit gradients and limits memory use: you can compute the Sinkhorn loss with

loss = ot.solve(M, a, b, reg=1, grad='implicit').value

All iterations are detached and the gradient is set at the end with no memory overhead, but the result is then differentiable only with respect to value (not value_linear or the OT plan). Could you tell me if this solves your problem? It is merged in the master branch.
This new functionality is now available in `ot.solve` and `ot.solve_sample`, thanks to PR #605.