jax-ml/jax

automatic detection for GPU pip install doesn't quite work on ubuntu 20.04

kratsg opened this issue · 4 comments

As stated on the README,

Please let us know on the issue tracker if you run into any errors or problems with the prebuilt wheels.

I tried to pip install via

pip install --upgrade https://storage.googleapis.com/jax-releases/`nvidia-smi | sed -En "s/.* CUDA Version: ([0-9]*)\.([0-9]*).*/cuda\1\2/p"`/jaxlib-0.1.52-`python3 -V | sed -En "s/Python ([0-9]*)\.([0-9]*).*/cp\1\2/p"`-none-manylinux2010_x86_64.whl jax

but I've noticed that nvidia-smi and nvidia-cuda-toolkit report (slightly) different CUDA versions, which seems to be "ok" since one comes from the GPU driver and the other from the CUDA toolkit (probably?).

This causes jax to complain that it can't find CUDA 10.2, because the automated install takes the version number from nvidia-smi instead of from nvcc. See below:

  • nvcc --version has the version from the nvidia-cuda-toolkit (10.1.243)
  • nvidia-smi has the version from the GPU driver, installed via sudo ubuntu-drivers autoinstall (10.2)
$ sudo apt show  nvidia-cuda-toolkit
Package: nvidia-cuda-toolkit
Version: 10.1.243-3
Priority: extra
Section: multiverse/devel
Origin: Ubuntu
$ nvidia-smi
Thu Aug  6 18:06:16 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.100      Driver Version: 440.100      CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  GeForce RTX 208...  Off  | 00000000:26:00.0 Off |                  N/A |
| 41%   29C    P8     1W / 260W |     74MiB / 11016MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0      1296      G   /usr/lib/xorg/Xorg                            56MiB |
|    0      1467      G   /usr/bin/gnome-shell                          16MiB |
+-----------------------------------------------------------------------------+
$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Sun_Jul_28_19:07:16_PDT_2019
Cuda compilation tools, release 10.1, V10.1.243
$ ubuntu-drivers devices
== /sys/devices/pci0000:00/0000:00:03.1/0000:26:00.0 ==
modalias : pci:v000010DEd00001E07sv000010DEsd000012A4bc03sc00i00
vendor   : NVIDIA Corporation
model    : TU102 [GeForce RTX 2080 Ti Rev. A]
driver   : nvidia-driver-418-server - distro non-free
driver   : nvidia-driver-440-server - distro non-free
driver   : nvidia-driver-435 - distro non-free
driver   : nvidia-driver-440 - distro non-free recommended
driver   : xserver-xorg-video-nouveau - distro free builtin
$ ls -lavh /usr/lib/x86_64-linux-gnu/libcuda*
lrwxrwxrwx 1 root root   12 May 29 03:14 /usr/lib/x86_64-linux-gnu/libcuda.so -> libcuda.so.1
-rw-r--r-- 1 root root 703K Aug  9  2019 /usr/lib/x86_64-linux-gnu/libcudadevrt.a
lrwxrwxrwx 1 root root   17 Apr 11 05:56 /usr/lib/x86_64-linux-gnu/libcudart.so -> libcudart.so.10.1
lrwxrwxrwx 1 root root   21 Apr 11 05:56 /usr/lib/x86_64-linux-gnu/libcudart.so.10.1 -> libcudart.so.10.1.243
-rw-r--r-- 1 root root 493K Aug  9  2019 /usr/lib/x86_64-linux-gnu/libcudart.so.10.1.243
-rw-r--r-- 1 root root 868K Aug  9  2019 /usr/lib/x86_64-linux-gnu/libcudart_static.a
lrwxrwxrwx 1 root root   18 May 29 03:14 /usr/lib/x86_64-linux-gnu/libcuda.so.1 -> libcuda.so.440.100
-rw-r--r-- 1 root root  17M May 29 01:32 /usr/lib/x86_64-linux-gnu/libcuda.so.440.100
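
To make the mismatch concrete, here is a minimal Python sketch that applies the same pattern as the sed expression in the install command to the nvidia-smi header line above, plus a similar pattern for the nvcc release line. The nvcc pattern and both function names are my own assumptions for illustration, not something taken from the README.

```python
import re


def cuda_tag_from_nvidia_smi(text):
    """Mirror the install command's sed pattern: 'CUDA Version: X.Y' -> 'cudaXY'."""
    m = re.search(r"CUDA Version: ([0-9]*)\.([0-9]*)", text)
    return f"cuda{m.group(1)}{m.group(2)}" if m else None


def cuda_tag_from_nvcc(text):
    """Assumed pattern for nvcc output: 'release X.Y' -> 'cudaXY'."""
    m = re.search(r"release ([0-9]+)\.([0-9]+)", text)
    return f"cuda{m.group(1)}{m.group(2)}" if m else None


# Lines copied from the terminal output in this report.
smi_line = "| NVIDIA-SMI 440.100      Driver Version: 440.100      CUDA Version: 10.2     |"
nvcc_line = "Cuda compilation tools, release 10.1, V10.1.243"

print(cuda_tag_from_nvidia_smi(smi_line))  # cuda102 (driver-reported version)
print(cuda_tag_from_nvcc(nvcc_line))       # cuda101 (installed toolkit version)
```

The two tags differ, so the install command ends up pointing at the cuda102 wheel directory even though only the 10.1 toolkit is installed.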

Has anyone solved this issue yet? JAX doesn't detect my 2080 Ti, even though TensorFlow finds it.

skye commented

@kratsg thanks for the detailed report! We could change the command to use the output of nvcc --version, but any reason you can't update your cuda toolkit to 10.2? :) AFAIK those versions should not be out of sync...

@robertalanm are you sure you're experiencing the same issue, i.e. the automatic detection command isn't working? I suggest opening a new issue with more information, e.g. how you're installing jaxlib + jax, the output of that, and/or any error messages you're getting.

@kratsg thanks for the detailed report! We could change the command to use the output of nvcc --version, but any reason you can't update your cuda toolkit to 10.2? :) AFAIK those versions should not be out of sync...

This is the part that bothers me a little bit (but it doesn't break, as the API is compatible). However, the nvidia-cuda-toolkit package on Ubuntu 20.04 only ships 10.1 (no 10.2). To get a newer one, I'd have to manually add the CUDA repos and install from there.

skye commented

Ah, I didn't realize the Nvidia driver allows for older toolkit versions. We should use the toolkit version in the auto install command then. Please feel free to submit a PR if you have something that works, otherwise I can change it tomorrow.
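
For reference, one way the toolkit-based selection could look is sketched here in Python rather than as a shell one-liner. This is only an illustration under stated assumptions: `jaxlib_wheel_url` is a hypothetical helper, the `release X.Y` pattern is assumed from the nvcc output in the report, and the URL layout is copied from the install command in the report. It does not check that a wheel actually exists at the resulting URL.

```python
import re
import sys

# Base URL and wheel naming are taken from the install command in the report.
BASE_URL = "https://storage.googleapis.com/jax-releases"


def jaxlib_wheel_url(nvcc_output, jaxlib_version="0.1.52", py=None):
    """Build a jaxlib wheel URL from `nvcc --version` text (hypothetical helper).

    `py` is an optional (major, minor) tuple; it defaults to the running
    interpreter's version.
    """
    m = re.search(r"release ([0-9]+)\.([0-9]+)", nvcc_output)
    if m is None:
        raise ValueError("no 'release X.Y' string found in nvcc output")
    cuda_tag = f"cuda{m.group(1)}{m.group(2)}"
    major, minor = py if py is not None else sys.version_info[:2]
    py_tag = f"cp{major}{minor}"
    return (
        f"{BASE_URL}/{cuda_tag}/"
        f"jaxlib-{jaxlib_version}-{py_tag}-none-manylinux2010_x86_64.whl"
    )


# Example using the nvcc line from the report and Python 3.8.
sample = "Cuda compilation tools, release 10.1, V10.1.243"
print(jaxlib_wheel_url(sample, py=(3, 8)))
```

In practice the `nvcc_output` argument would come from running `nvcc --version` (for example via `subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout`), and the printed URL would be passed to `pip install --upgrade` together with `jax`.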