PythonOT/POT

CUDA error: initialization error

simon-forb opened this issue · 1 comment

Describe the bug

Calling ot.bregman.sinkhorn_stabilized on CUDA tensors fails with the following error:

my_conda_env/python3.11/site-packages/ot/backend.py", line 1822, in __init__
    self.rng_cuda_ = torch.Generator("cuda")
                     ^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: initialization error

I am using pot=0.9.4=py311h14de704_0 and pytorch=2.2.2=py3.11_cuda12.1_cudnn8.9.2_0.

To Reproduce

The relevant part of the traceback (inside POT) is:

  File "lib/python3.11/site-packages/ot/bregman/_sinkhorn.py", line 1037, in sinkhorn_stabilized
    a, b, M = list_to_array(a, b, M)
              ^^^^^^^^^^^^^^^^^^^^^^
  File "python3.11/site-packages/ot/utils.py", line 68, in list_to_array
    nx = get_backend(*lst_not_empty)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python3.11/site-packages/ot/backend.py", line 222, in get_backend
    return _get_backend_instance(backend_impl)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python3.11/site-packages/ot/backend.py", line 170, in _get_backend_instance
    _BACKENDS[backend_impl.__name__] = backend_impl()
                                       ^^^^^^^^^^^^^^
  File "python3.11/site-packages/ot/backend.py", line 1822, in __init__
    self.rng_cuda_ = torch.Generator("cuda")
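To check whether the failure comes from POT or from the PyTorch/CUDA setup itself, a minimal sketch (not part of POT) can repeat the failing call from the traceback on its own. The helper name check_cuda_generator is hypothetical. If it also reports a failure, the problem most likely lies in the PyTorch/CUDA installation (here running under WSL2) rather than in POT.

```python
try:
    import torch
except ImportError:  # PyTorch is not installed at all
    torch = None


def check_cuda_generator():
    """Repeat the call that fails in ot/backend.py, but without POT."""
    if torch is None:
        return "PyTorch not installed"
    # Nothing to test when PyTorch reports no usable GPU.
    if not torch.cuda.is_available():
        return "CUDA not available to PyTorch"
    try:
        torch.Generator("cuda")  # same call as in the traceback above
    except RuntimeError as err:
        return f"torch.Generator('cuda') failed: {err}"
    return "torch.Generator('cuda') created successfully"


if __name__ == "__main__":
    print(check_cuda_generator())
```

Running this in the same conda environment and from the same kind of process as the failing code gives the most comparable result.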

Code sample

import torch
import ot

device = torch.device("cuda")
n1 = 10
rank = 5
p1 = torch.ones(n1, device=device) / n1  # uniform source weights

Q = torch.rand(n1, rank, device=device)  # random cost matrix of shape (n1, rank)
Q = ot.bregman.sinkhorn_stabilized(
    p1, torch.ones(rank, device=device) / rank, Q, reg=1e-1
)

Environment:

  • OS: Ubuntu 22 (running under WSL2)
  • Python version: 3.11
  • How POT was installed: conda
  • GPU-related details:
    • CUDA version: 12.1
    • GPU model and configuration: NVIDIA GeForce GTX 1650

Output of the following code snippet:

>>> import platform; print(platform.platform())
Linux-5.15.153.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
>>> import sys; print("Python", sys.version)
Python 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0]
>>> import numpy; print("NumPy", numpy.__version__)
NumPy 1.24.3
>>> import scipy; print("SciPy", scipy.__version__)
SciPy 1.14.0
>>> import ot; print("POT", ot.__version__)
POT 0.9.4
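The output above covers NumPy, SciPy and POT but not PyTorch. A small sketch like the following (the helper name torch_cuda_report is hypothetical) could collect the PyTorch and CUDA details as well, using only standard torch attributes: torch.__version__, torch.version.cuda, torch.cuda.is_available(), torch.cuda.device_count() and torch.cuda.get_device_name().

```python
try:
    import torch
except ImportError:  # PyTorch is not installed at all
    torch = None


def torch_cuda_report():
    """Collect PyTorch/CUDA details that the snippet above does not print."""
    if torch is None:
        return {"torch": "not installed"}
    report = {
        "torch": torch.__version__,
        "torch_cuda_build": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),
    }
    if report["cuda_available"]:
        report["device_count"] = torch.cuda.device_count()
        try:
            # Querying device properties initializes CUDA, so it may hit the
            # same initialization error as the traceback above.
            report["device_0"] = torch.cuda.get_device_name(0)
        except RuntimeError as err:
            report["device_0"] = f"error: {err}"
    return report


if __name__ == "__main__":
    for key, value in torch_cuda_report().items():
        print(f"{key}: {value}")
```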