CUDA error: initialization error
simon-forb opened this issue · 1 comments
simon-forb commented
Describe the bug
I get the following error:
my_conda_env/python3.11/site-packages/ot/backend.py", line 1822, in __init__
self.rng_cuda_ = torch.Generator("cuda")
^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: initialization error
I am using pot=0.9.4=py311h14de704_0 and pytorch=2.2.2=py3.11_cuda12.1_cudnn8.9.2_0.
To Reproduce
My trace looks like this (inside OT):
File "lib/python3.11/site-packages/ot/bregman/_sinkhorn.py", line 1037, in sinkhorn_stabilized
a, b, M = list_to_array(a, b, M)
^^^^^^^^^^^^^^^^^^^^^^
File "python3.11/site-packages/ot/utils.py", line 68, in list_to_array
nx = get_backend(*lst_not_empty)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "python3.11/site-packages/ot/backend.py", line 222, in get_backend
return _get_backend_instance(backend_impl)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "python3.11/site-packages/ot/backend.py", line 170, in _get_backend_instance
_BACKENDS[backend_impl.__name__] = backend_impl()
^^^^^^^^^^^^^^
File "python3.11/site-packages/ot/backend.py", line 1822, in __init__
self.rng_cuda_ = torch.Generator("cuda")
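The last frame of the trace is a plain PyTorch call, so the error may be reproducible without POT. The following is a minimal sketch of such a check, assuming the failure comes from `torch.Generator("cuda")` itself. The helper name `probe_cuda_generator` is hypothetical and is not part of POT or PyTorch.

```python
import importlib.util


def probe_cuda_generator():
    """Try to create a CUDA random generator, as the last trace frame does.

    Returns a (ok, detail) tuple. Hypothetical helper for illustration only.
    """
    if importlib.util.find_spec("torch") is None:
        return False, "torch is not installed"
    import torch

    try:
        # Same call as ot/backend.py line 1822 in the trace above.
        torch.Generator("cuda")
    except Exception as exc:  # broad on purpose: we only report what happened
        return False, f"{type(exc).__name__}: {exc}"
    return True, "torch.Generator('cuda') succeeded"


if __name__ == "__main__":
    ok, detail = probe_cuda_generator()
    print(ok, detail)
```

If this reports the same initialization error in a fresh interpreter, the problem is likely in the PyTorch/CUDA setup rather than in POT. If it succeeds, the context in which POT is called would be the next thing to look at.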
Code sample
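import torch
import ot

device = torch.device("cuda")
n1 = 10
rank = 5
p1 = torch.ones(n1, device=device) / n1  # uniform source weights
Q = torch.rand(n1, rank, device=device)  # random cost matrix
# This call triggers the RuntimeError shown in the trace above
Q = ot.bregman.sinkhorn_stabilized(
    p1, torch.ones(rank, device=device) / rank, Q, reg=1e-1
)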
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): Ubuntu 22
- Python version: 3.11
- How was POT installed (source, pip, conda): conda
- Only for GPU related bugs:
  - CUDA version: 12.1
  - GPU models and configuration: NVIDIA GeForce GTX 1650
- Any other relevant information:
Output of the following code snippet:
>>> import platform; print(platform.platform())
Linux-5.15.153.1-microsoft-standard-WSL2-x86_64-with-glibc2.35
>>> import sys; print("Python", sys.version)
Python 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0]
>>> import numpy; print("NumPy", numpy.__version__)
NumPy 1.24.3
>>> import scipy; print("SciPy", scipy.__version__)
SciPy 1.14.0
>>> import ot; print("POT", ot.__version__)
POT 0.9.4
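The snippet output above does not cover PyTorch, which is where the error is raised. A short sketch for collecting the corresponding PyTorch and CUDA details could look like the following. The helper name `torch_cuda_report` is hypothetical, and the values it prints depend on the local installation.

```python
import importlib.util


def torch_cuda_report():
    """Collect PyTorch/CUDA facts as printable lines (hypothetical helper)."""
    if importlib.util.find_spec("torch") is None:
        return ["PyTorch not installed"]
    import torch

    lines = [
        f"PyTorch {torch.__version__}",
        # CUDA version PyTorch was built against (None for CPU-only builds)
        f"CUDA (build) {torch.version.cuda}",
        f"CUDA available {torch.cuda.is_available()}",
    ]
    if torch.cuda.is_available():
        lines.append(f"Device {torch.cuda.get_device_name(0)}")
    return lines


if __name__ == "__main__":
    print("\n".join(torch_cuda_report()))
```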