PythonOT/POT

ot.solve uses GPU even though tensors are on CPU?

mathurinm opened this issue · 3 comments

Describe the bug

Running ot.solve with tensors on the CPU allocates memory on the GPU (this is documented in get_backend_list), but it also seems to use the GPU, since the power draw (in watts) increases. See the attached screencast:
Screencast from 08-03-2024 11:24:44.webm

Is this expected?

Script

import torch
import ot

n_samples = 5_000

x = torch.randn(n_samples, 2)
y = torch.randn(n_samples, 2)

a = torch.rand(n_samples)
a /= a.sum()
b = torch.rand(n_samples)
b /= b.sum()

M = ot.dist(x, y)

res = ot.solve(M, a, b, reg=0.1, reg_type="entropy")

Hello @mathurinm,

This might be related to the closed issue #516. Are you using POT >= 0.9.2?

Thanks for the quick reply, Cedric. This happens with the latest dev version:

In [1]: import ot
ot.__version__


In [2]: ot.__version__
Out[2]: '0.9.3dev'

Unlike in #516, GPU memory consumption starts only when a function from ot is called, such as dist (presumably when the backend is determined). Importing ot by itself indeed does not use the GPU.
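The guess above, that GPU use begins when the backend is determined, matches a lazy-instantiation pattern. Below is a hypothetical, simplified model of that pattern: it is not POT's actual code, and every name in it (`FakeTorchBackend`, `get_backend`, `get_backend_list`, `init_log`) is illustrative. An append to `init_log` stands in for whatever a real backend constructor might do, such as probing or allocating on the GPU.

```python
# Hypothetical model of lazy backend dispatch; NOT POT's implementation.
init_log = []  # records constructor side effects (stand-in for GPU allocation)


class FakeTorchBackend:
    name = "torch"

    def __init__(self):
        # A real backend could probe CUDA devices here, even for CPU inputs.
        init_log.append("torch backend constructed")


class FakeNumpyBackend:
    name = "numpy"

    def __init__(self):
        init_log.append("numpy backend constructed")


# Only classes are registered, so defining (importing) this module
# runs no constructor and has no side effects.
_BACKEND_CLASSES = {"torch": FakeTorchBackend, "numpy": FakeNumpyBackend}
_instances = {}


def get_backend(kind):
    """Construct the backend on first request, then reuse the cached instance."""
    if kind not in _instances:
        _instances[kind] = _BACKEND_CLASSES[kind]()
    return _instances[kind]


def get_backend_list():
    """Constructs every registered backend, so all constructor side effects run."""
    return [get_backend(kind) for kind in _BACKEND_CLASSES]


def dist(kind):
    """Toy library function: resolving the backend is its first step."""
    backend = get_backend(kind)
    return backend.name
```

In this model, `init_log` is empty right after the module loads; the first `dist("torch")` appends one entry, and later calls reuse the cached instance. Under this pattern, the fix would belong in the constructor (avoid touching the GPU unless a GPU tensor is actually passed), not in the import path.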

Let me know if I can provide additional info.

OK, thank you for the feedback. We will look into this and get back to you ASAP.