ot.solve uses GPU even though tensors are on CPU?
mathurinm opened this issue · 3 comments
Describe the bug
Running ot.solve with tensors on the CPU allocates memory on the GPU (this is documented in get_backend_list), but it also seems to use the GPU, as the power draw (in watts) increases. See the attached screencast:
Screencast from 08-03-2024 11:24:44.webm
Is this expected?
Script
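```python
import torch
import ot

n_samples = 5_000
# Two 2-D point clouds, created on the CPU (default device)
x = torch.randn(n_samples, 2)
y = torch.randn(n_samples, 2)
# Random source and target weights, each normalized to sum to 1
a = torch.rand(n_samples)
a /= a.sum()
b = torch.rand(n_samples)
b /= b.sum()
# Pairwise cost matrix (squared Euclidean distance by default)
M = ot.dist(x, y)
# Entropy-regularized optimal transport between a and b
res = ot.solve(M, a, b, reg=0.1, reg_type="entropy")
```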
Hello @mathurinm,
This might be related to the closed issue #516. Are you using POT >= 0.9.2?
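As a small, hedged illustration of the version check being asked about, the sketch below compares a version string such as '0.9.3dev' against (0, 9, 2) using only the standard library. `version_at_least` is a hypothetical helper, not part of POT. It ignores pre-release suffixes, so unlike PEP 440 it treats '0.9.3dev' as equal to 0.9.3 rather than just before it; that is harmless for a ">= 0.9.2" check.

```python
import re


def version_at_least(version, minimum):
    """Compare the leading numeric part of `version` with the tuple `minimum`.

    Hypothetical helper: suffixes such as "dev" are ignored, so "0.9.3dev"
    is treated like 0.9.3 (PEP 440 would order it just before 0.9.3).
    """
    match = re.match(r"(\d+(?:\.\d+)*)", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    numbers = tuple(int(part) for part in match.group(1).split("."))
    target = tuple(minimum)
    # Pad with zeros so that "1.0" compares equal to (1, 0, 0)
    length = max(len(numbers), len(target))
    numbers += (0,) * (length - len(numbers))
    target += (0,) * (length - len(target))
    return numbers >= target


print(version_at_least("0.9.3dev", (0, 9, 2)))  # True
```

In a live session the first argument would be `ot.__version__`; if the `packaging` package is installed, `packaging.version.Version` gives full PEP 440 ordering instead.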
Thanks for the quick reply, Cedric. This happens with the latest dev version:
```
In [1]: import ot

In [2]: ot.__version__
Out[2]: '0.9.3dev'
```
Unlike #516, the memory consumption starts when calling a function from ot, such as dist (presumably when the backend is determined). Importing ot alone does not use the GPU.
Let me know if I can provide additional info.
OK, thank you for the feedback. We will look into this and get back to you asap.