JAX and TORCH
Closed this issue · 33 comments
Description
When I install only the latest JAX with CUDA (pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html), I can use JAX on the GPU. But when I then install torch (pip install torch), I can no longer use JAX on the GPU: it reports that the installed CUDA or cuSOLVER version is older than the one JAX was built against. Why? Would an older JAX version avoid this? If so, how can I install jax[cuda] at a matching version?
What jax/jaxlib version are you using?
jax-0.4.18 jaxlib-0.4.18+cuda12.cudnn89
Which accelerator(s) are you using?
GPU
Additional system info
3.10.9/Linux
NVIDIA GPU info
No response
That's correct. The current releases of PyTorch and JAX have incompatible CUDA version dependencies.
I reported this issue to the PyTorch developers a while back, but there has been no interest in relaxing their CUDA version dependencies.
My recommendations:
- use a different virtualenv for PyTorch and JAX. This is the simplest solution and probably the best.
- if for some reason you really want PyTorch and JAX in the same virtualenv, install the CPU version of one of them and the CUDA version of the other. That avoids any CUDA version conflicts.
- it may work to simply install JAX after PyTorch, since JAX wants a newer CUDA version than PyTorch's current release does, and in practice NVIDIA's CUDA releases are backwards compatible. I'm not sure if PyTorch enforces a version check, but if not it's highly likely this will work. But I don't think the PyTorch developers support it.
- another solution is to install the CUDA version needed for one of the two, and build the other one from source. For example, JAX will happily build from source with an older CUDA release, it's just the binary distribution that requires a CUDA version matching the version against which it was built.
Does that resolve your problem?
Hope that helps!
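To see which CUDA wheels pip has actually put in a virtualenv (and so which of the two installs won), something like the following may help. It is only a sketch: the helper name is made up, and it assumes the CUDA libraries were installed as `nvidia-*` pip packages, as in the `jax[cuda12_pip]` install above. It will not see a system-wide or conda-provided CUDA.

```python
from importlib import metadata


def installed_nvidia_wheels():
    """Return {name: version} for installed pip packages named 'nvidia-*'.

    Hypothetical helper: it only inspects pip metadata, so CUDA libraries
    installed outside pip are not listed.
    """
    found = {}
    for dist in metadata.distributions():
        try:
            name = (dist.metadata.get("Name") or "").lower()
        except Exception:
            # Skip distributions whose metadata cannot be read.
            continue
        if name.startswith("nvidia-"):
            found[name] = dist.version
    return found


if __name__ == "__main__":
    for name, version in sorted(installed_nvidia_wheels().items()):
        print(f"{name}=={version}")
```

Running it before and after `pip install torch` and comparing the two listings should show which packages were replaced.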
This is quite annoying (and inconvenient) now that people have written torch2jax functionality which allows GPU-accelerated interaction:
https://github.com/samuela/torch2jax
https://github.com/rdyro/torch2jax
Hi @ywsslr, I've been experimenting with using Torch and JAX together for a while. I'm currently working in a Docker container in which they both work on GPU.
JAX was installed according to the official documentation as:
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Here is the Conda YAML of the environment. It probably contains some extra packages, but I hope it helps:
conda environment
name: base
channels:
- nvidia
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- absl-py=1.4.0=py310h06a4308_0
- alsa-lib=1.2.10=hd590300_0
- appdirs=1.4.4=pyhd3eb1b0_0
- asttokens=2.0.5=pyhd3eb1b0_0
- attr=2.5.1=h166bdaf_1
- backcall=0.2.0=pyhd3eb1b0_0
- binutils=2.40=hdd6e379_0
- binutils_impl_linux-64=2.40=hf600244_0
- binutils_linux-64=2.40=hbdbef99_2
- blas=1.0=openblas
- boltons=23.0.0=pyhd8ed1ab_0
- brotli=1.0.9=he6710b0_2
- brotli-python=1.1.0=py310hc6cd4ac_0
- bzip2=1.0.8=h7f98852_4
- c-ares=1.19.1=hd590300_0
- c-compiler=1.6.0=hd590300_0
- ca-certificates=2023.7.22=hbcca054_0
- cairo=1.16.0=hb05425b_5
- certifi=2023.7.22=pyhd8ed1ab_0
- cffi=1.15.1=py310h255011f_3
- charset-normalizer=3.2.0=pyhd8ed1ab_0
- chex=0.1.5=py310h06a4308_0
- click=8.0.4=py310h06a4308_0
- colorama=0.4.6=pyhd8ed1ab_0
- coloredlogs=15.0.1=py310h06a4308_1
- compilers=1.6.0=ha770c72_0
- conda=23.3.1=py310hff52083_0
- conda-package-handling=2.2.0=pyh38be061_0
- conda-package-streaming=0.9.0=pyhd8ed1ab_0
- contourpy=1.0.5=py310hdb19cb5_0
- cryptography=41.0.3=py310h75e40e8_0
- cuda-nvcc=11.3.58=h2467b9f_0
- cuda-version=11.8=h70ddcb2_2
- cudatoolkit=11.8.0=h4ba93d1_12
- cudnn=8.9.2.26=cuda11_0
- cupti=11.8.0=he078b1a_0
- cxx-compiler=1.6.0=h00ab1b0_0
- cycler=0.11.0=pyhd3eb1b0_0
- dbus=1.13.18=hb2f20db_0
- deap=1.4.1=py310h7cbd5c2_0
- decorator=5.1.1=pyhd3eb1b0_0
- dm-tree=0.1.7=py310h6a678d5_1
- docker-pycreds=0.4.0=pyhd3eb1b0_0
- docstring_parser=0.15=pyhd8ed1ab_0
- exceptiongroup=1.0.4=py310h06a4308_0
- executing=0.8.3=pyhd3eb1b0_0
- expat=2.5.0=h6a678d5_0
- filelock=3.9.0=py310h06a4308_0
- flax=0.6.1=pyhd8ed1ab_1
- fmt=9.1.0=h924138e_0
- font-ttf-dejavu-sans-mono=2.37=hd3eb1b0_0
- font-ttf-inconsolata=2.001=hcb22688_0
- font-ttf-source-code-pro=2.030=hd3eb1b0_0
- font-ttf-ubuntu=0.83=h8b1ccd4_0
- fontconfig=2.14.2=h14ed4e7_0
- fonts-anaconda=1=h8fa9717_0
- fonts-conda-ecosystem=1=hd3eb1b0_0
- fonttools=4.25.0=pyhd3eb1b0_0
- fortran-compiler=1.6.0=heb67821_0
- freetype=2.12.1=h4a9f257_0
- frozendict=2.3.8=py310h2372a71_0
- gcc=12.3.0=h8d2909c_2
- gcc_impl_linux-64=12.3.0=he2b93b0_1
- gcc_linux-64=12.3.0=h76fc315_2
- gettext=0.21.1=h27087fc_0
- gfortran=12.3.0=h499e0f7_2
- gfortran_impl_linux-64=12.3.0=hfcedea8_1
- gfortran_linux-64=12.3.0=h7fe76b4_2
- gitdb=4.0.7=pyhd3eb1b0_0
- gitpython=3.1.30=py310h06a4308_0
- glib=2.78.0=hfc55251_0
- glib-tools=2.78.0=hfc55251_0
- gmp=6.2.1=h295c915_3
- gmpy2=2.1.2=py310heeb90bb_0
- graphite2=1.3.14=h295c915_1
- gst-plugins-base=1.22.5=h8e1006c_1
- gstreamer=1.22.5=h98fc4e7_1
- gxx=12.3.0=h8d2909c_2
- gxx_impl_linux-64=12.3.0=he2b93b0_1
- gxx_linux-64=12.3.0=h8a814eb_2
- harfbuzz=8.2.0=h3d44ed6_0
- humanfriendly=10.0=py310h06a4308_1
- icu=73.2=h59595ed_0
- idna=3.4=pyhd8ed1ab_0
- intel-openmp=2023.1.0=hdb19cb5_46305
- ipython=8.15.0=py310h06a4308_0
- jax-dataclasses=1.5.1=pyhd8ed1ab_0
- jaxlie=1.3.3=pyhd8ed1ab_0
- jedi=0.18.1=py310h06a4308_1
- jinja2=3.1.2=py310h06a4308_0
- jsonpatch=1.32=pyhd8ed1ab_0
- jsonpointer=2.4=py310hff52083_0
- kernel-headers_linux-64=2.6.32=he073ed8_16
- keyutils=1.6.1=h166bdaf_0
- kiwisolver=1.4.4=py310h6a678d5_0
- krb5=1.21.2=h659d440_0
- lame=3.100=h7b6447c_0
- lcms2=2.15=h7f713cb_2
- ld_impl_linux-64=2.40=h41732ed_0
- lerc=4.0.0=h27087fc_0
- libarchive=3.6.2=h039dbb9_1
- libcap=2.69=h0f662aa_0
- libclang=15.0.7=default_h7634d5b_3
- libclang13=15.0.7=default_h9986a30_3
- libcups=2.3.3=h4637d8d_4
- libcurl=8.3.0=hca28451_0
- libdeflate=1.19=hd590300_0
- libedit=3.1.20191231=he28a2e2_2
- libev=4.33=h516909a_1
- libevent=2.1.12=hdbd6064_1
- libexpat=2.5.0=hcb278e6_1
- libffi=3.4.2=h7f98852_5
- libflac=1.4.3=h59595ed_0
- libgcc-devel_linux-64=12.3.0=h8bca6fd_1
- libgcc-ng=13.2.0=h807b86a_0
- libgcrypt=1.10.1=h166bdaf_0
- libgfortran-ng=13.2.0=h69a702a_1
- libgfortran5=13.2.0=ha4646dd_1
- libglib=2.78.0=hebfc3b9_0
- libgomp=13.2.0=h807b86a_0
- libgpg-error=1.47=h71f35ed_0
- libiconv=1.17=h166bdaf_0
- libjpeg-turbo=2.1.5.1=hd590300_1
- libllvm15=15.0.7=h5cf9203_3
- libmamba=1.2.0=hcea66bb_0
- libmambapy=1.2.0=py310h1428755_0
- libnghttp2=1.52.0=h61bc06f_0
- libnsl=2.0.0=h7f98852_0
- libogg=1.3.5=h27cfd23_1
- libopenblas=0.3.21=h043d6bf_0
- libopus=1.3.1=h7b6447c_0
- libpng=1.6.39=h5eee18b_0
- libpq=15.4=hfc447b1_0
- libprotobuf=3.20.3=he621ea3_0
- libsanitizer=12.3.0=h0f45ef3_1
- libsndfile=1.2.2=hbc2eb40_0
- libsolv=0.7.24=hfc55251_4
- libsqlite=3.43.0=h2797004_0
- libssh2=1.11.0=h0841786_0
- libstdcxx-devel_linux-64=12.3.0=h8bca6fd_1
- libstdcxx-ng=13.2.0=h7e041cc_0
- libsystemd0=254=h3516f8a_0
- libtiff=4.6.0=h29866fb_1
- libuuid=2.38.1=h0b41bf4_0
- libvorbis=1.3.7=h7b6447c_0
- libwebp-base=1.3.2=h5eee18b_0
- libxcb=1.15=h7f8727e_0
- libxkbcommon=1.5.0=h5d7e998_3
- libxml2=2.11.5=h232c23b_1
- libzlib=1.2.13=hd590300_5
- lz4-c=1.9.4=hcb278e6_0
- lzo=2.10=h516909a_1000
- magma=2.7.1=h2c23e93_0
- mamba=1.2.0=py310h51d5547_0
- markdown-it-py=2.2.0=py310h06a4308_1
- markupsafe=2.1.1=py310h7f8727e_0
- mashumaro=3.6=py310h06a4308_0
- matplotlib=3.7.2=py310h06a4308_0
- matplotlib-base=3.7.2=py310h1128e8f_0
- matplotlib-inline=0.1.6=py310h06a4308_0
- mdurl=0.1.0=py310h06a4308_0
- mkl=2023.1.0=h213fc3f_46343
- mpc=1.1.0=h10f8cd9_1
- mpfr=4.0.2=hb69a4c5_1
- mpg123=1.31.3=hcb278e6_0
- mpmath=1.3.0=py310h06a4308_0
- msgpack-python=1.0.3=py310hd09550d_0
- munkres=1.1.4=py_0
- mysql-common=8.0.33=hf1915f5_4
- mysql-libs=8.0.33=hca2cd23_4
- ncurses=6.4=hcb278e6_0
- networkx=3.1=py310h06a4308_0
- ninja=1.10.2=h06a4308_5
- ninja-base=1.10.2=hd09550d_5
- nspr=4.35=h6a678d5_0
- nss=3.92=h1d7d5a4_0
- numpy=1.25.2=py310heeff2f4_0
- numpy-base=1.25.2=py310h8a23956_0
- openjpeg=2.5.0=h488ebb8_3
- openssl=3.1.2=hd590300_0
- opt_einsum=3.3.0=pyhd3eb1b0_1
- optax=0.1.4=py310h06a4308_0
- overrides=7.4.0=pyhd8ed1ab_0
- packaging=23.1=pyhd8ed1ab_0
- parso=0.8.3=pyhd3eb1b0_0
- pathtools=0.1.2=pyhd3eb1b0_1
- pcre2=10.40=hc3806b6_0
- pexpect=4.8.0=pyhd3eb1b0_3
- pickleshare=0.7.5=pyhd3eb1b0_1003
- pillow=10.0.1=py310h29da1c1_0
- pip=23.2.1=pyhd8ed1ab_0
- pixman=0.40.0=h7f8727e_1
- pluggy=1.3.0=pyhd8ed1ab_0
- ply=3.11=py310h06a4308_0
- pptree=3.1=pyhd8ed1ab_0
- prompt-toolkit=3.0.36=py310h06a4308_0
- protobuf=3.20.3=py310h6a678d5_0
- psutil=5.9.0=py310h5eee18b_0
- ptyprocess=0.7.0=pyhd3eb1b0_2
- pulseaudio-client=16.1=hb77b528_5
- pure_eval=0.2.2=pyhd3eb1b0_0
- pybind11-abi=4=hd8ed1ab_3
- pycosat=0.6.4=py310h5764c6d_1
- pycparser=2.21=pyhd8ed1ab_0
- pygments=2.15.1=py310h06a4308_1
- pyopenssl=23.2.0=pyhd8ed1ab_1
- pyparsing=3.0.9=py310h06a4308_0
- pyqt=5.15.9=py310h04931ad_4
- pyqt5-sip=12.12.2=py310hc6cd4ac_4
- pysocks=1.7.1=pyha2e5f31_6
- python=3.10.8=h4a9ceb5_0_cpython
- python-dateutil=2.8.2=pyhd3eb1b0_0
- python_abi=3.10=3_cp310
- pytorch=2.0.1=gpu_cuda118py310h7799f5a_0
- pyyaml=6.0=py310h5eee18b_1
- qt-main=5.15.8=hc47bfe8_16
- readline=8.2=h8228510_1
- reproc=14.2.4=h0b41bf4_0
- reproc-cpp=14.2.4=hcb278e6_0
- requests=2.31.0=pyhd8ed1ab_0
- rich=13.3.5=py310h06a4308_0
- ruamel.yaml=0.17.32=py310h2372a71_0
- ruamel.yaml.clib=0.2.7=py310h1fa729e_1
- scipy=1.11.1=py310heeff2f4_0
- sentry-sdk=1.9.0=py310h06a4308_0
- setproctitle=1.2.2=py310h7f8727e_0
- setuptools=68.2.2=pyhd8ed1ab_0
- shtab=1.6.4=pyhd8ed1ab_1
- sip=6.7.11=py310hc6cd4ac_0
- six=1.16.0=pyhd3eb1b0_1
- smmap=4.0.0=pyhd3eb1b0_0
- stack_data=0.2.0=pyhd3eb1b0_0
- sympy=1.11.1=py310h06a4308_0
- sysroot_linux-64=2.12=he073ed8_16
- tbb=2021.8.0=hdb19cb5_0
- tk=8.6.12=h27826a3_0
- toml=0.10.2=pyhd3eb1b0_0
- tomli=2.0.1=py310h06a4308_0
- toolz=0.12.0=pyhd8ed1ab_0
- tornado=6.3.2=py310h5eee18b_0
- tqdm=4.66.1=pyhd8ed1ab_0
- traitlets=5.7.1=py310h06a4308_0
- typing-extensions=4.7.1=py310h06a4308_0
- typing_extensions=4.7.1=py310h06a4308_0
- typing_utils=0.1.0=pyhd8ed1ab_0
- tyro=0.5.7=pyhd8ed1ab_0
- tzdata=2023c=h71feb2d_0
- urllib3=2.0.4=pyhd8ed1ab_0
- wandb=0.15.10=pyhd8ed1ab_0
- wcwidth=0.2.5=pyhd3eb1b0_0
- wheel=0.41.2=pyhd8ed1ab_0
- xcb-util=0.4.0=hd590300_1
- xcb-util-image=0.4.0=h8ee46fc_1
- xcb-util-keysyms=0.4.0=h8ee46fc_1
- xcb-util-renderutil=0.3.9=hd590300_1
- xcb-util-wm=0.4.1=h8ee46fc_1
- xkeyboard-config=2.39=hd590300_0
- xorg-kbproto=1.0.7=h7f98852_1002
- xorg-libice=1.1.1=hd590300_0
- xorg-libsm=1.2.4=h7391055_0
- xorg-libx11=1.8.6=h8ee46fc_0
- xorg-libxau=1.0.11=hd590300_0
- xorg-libxext=1.3.4=h0b41bf4_2
- xorg-libxrender=0.9.11=hd590300_0
- xorg-renderproto=0.11.1=h7f98852_1002
- xorg-xextproto=7.3.0=h0b41bf4_1003
- xorg-xf86vidmodeproto=2.3.1=h7f98852_1002
- xorg-xproto=7.0.31=h27cfd23_1007
- xz=5.2.6=h166bdaf_0
- yaml=0.2.5=h7b6447c_0
- yaml-cpp=0.7.0=h27087fc_2
- zlib=1.2.13=hd590300_5
- zstandard=0.19.0=py310h5764c6d_0
- zstd=1.5.5=hfc55251_0
- pip:
- jax==0.4.18
- jaxlib==0.4.18+cuda12.cudnn89
- ml-dtypes==0.3.1
- nvidia-cublas-cu12==12.2.5.6
- nvidia-cuda-cupti-cu12==12.2.142
- nvidia-cuda-nvcc-cu12==12.2.140
- nvidia-cuda-nvrtc-cu12==12.2.140
- nvidia-cuda-runtime-cu12==12.2.140
- nvidia-cudnn-cu12==8.9.4.25
- nvidia-cufft-cu12==11.0.8.103
- nvidia-cusolver-cu12==11.5.2.141
- nvidia-cusparse-cu12==12.1.2.141
- nvidia-nccl-cu12==2.18.3
- nvidia-nvjitlink-cu12==12.2.140
prefix: /conda
- Python 3.10.8
- Ubuntu 22.04
- jax==0.4.18
- jaxlib==0.4.18+cuda12.cudnn89
- Driver Version: 525.125.06
- CUDA Version: 12.0
Thank you all for your help. For some reason I can't try it right now, but I'll try it soon and reply.
OK people, this has been a one-day nightmare, but I finally got this to work on an H100 machine with CUDA 12.2, without sudo.
- First, install CUDA 12.2 (this was already there for me).
- Then install cuDNN 8.9 from the official website, using the tar option: https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar
- Then follow what this guy did to build magma: huggingface/autotrain-advanced#281 (comment)
Then install PyTorch from source as that post says, and voilà!
No promises, but informally we're going to try to make sure at least one JAX release targets a CUDA version that PyTorch also releases for. Right now, that's the CUDA 11.8 release of JAX.
It's not a guarantee, though; it might happen that for some JAX and Pytorch versions there's no intersecting CUDA version.
I hit a similar issue when installing pytorch and jax into the same conda environment: when torch is loaded first, jax.devices() will list only cpu devices.
A short summary of the diagnosis: it turns out that torch is built against cudnn version 8.7 while jaxlib is built against cudnn version 8.8, leading to an exception when executing jax._src.xla_bridge._check_cuda_versions().
Here follows a reproducer:
mamba create -n test-pytorch-jax pytorch pytorch-cuda=11.8 jaxlib=*=*cuda118* jax -c pytorch -c nvidia --no-channel-priority -y
mamba activate test-pytorch-jax
(note: using strict channel priority would lead to a mamba solver problem).
Import torch before checking jax.devices:
>>> import torch
>>> import jax
>>> jax.devices()
CUDA backend failed to initialize: Found cuDNN version 8700, but JAX was built against version 8800, which is newer. The copy of cuDNN that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
Import torch after checking jax.devices:
>>> import jax
>>> jax.devices()
[cuda(id=0), cuda(id=1)]
>>> import torch
>>> jax.__version__
'0.4.23'
>>> torch.__version__
'2.1.2'
>>> from torch._C import _cudnn
>>> _cudnn.getCompileVersion()
(8, 7, 0)
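Because the outcome depends on what the process has already loaded, each import order is best tried in a fresh interpreter. Here is a rough sketch of how that could be scripted. The two snippets mirror the sessions above; `run_snippet` is a made-up helper name, and the script assumes it is launched from the environment that has both packages installed.

```python
import subprocess
import sys

# The two import orders from the sessions above, as standalone scripts.
TORCH_FIRST = "import torch\nimport jax\nprint(jax.devices())"
JAX_FIRST = "import jax\nprint(jax.devices())\nimport torch"


def run_snippet(code, timeout=600):
    """Run `code` in a fresh Python process (hypothetical helper).

    Returns (exit code, stdout, stderr), so a failed import shows up as a
    non-zero exit code instead of an exception in the calling process.
    """
    proc = subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    return proc.returncode, proc.stdout.strip(), proc.stderr.strip()
```

Calling `run_snippet(TORCH_FIRST)` and `run_snippet(JAX_FIRST)` and comparing the printed device lists (plus any cuDNN warning on stderr) should show whether the import order matters in a given environment.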
Notice that the result of jaxlib.cuda._versions.cudnn_get_version() depends on whether torch was imported before or after calling jaxlib.cuda._versions.cudnn_get_version:
>>> import jaxlib.cuda._versions
>>> jaxlib.cuda._versions.cudnn_get_version()
8902
>>> import torch
>>> jaxlib.cuda._versions.cudnn_get_version()
8902
vs
>>> import torch
>>> import jaxlib.cuda._versions
>>> jaxlib.cuda._versions.cudnn_get_version()
8700
That qualifies as an incompatible linkage issue: since libcudnn is dynamically loaded, the result of cudnnGetVersion ought to give the version of the loaded library, not the version of the library that the software was built against. The behavior above suggests that torch was linked with libcudnn statically.
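For reading the integers above: cuDNN 8.x packs its version as major*1000 + minor*100 + patch, so 8700 is 8.7.0 and 8902 is 8.9.2, matching the tuple torch reports. A small decoding sketch follows. The function name is made up, and the encoding is assumed to be the cuDNN 8.x one (later major versions use a different scheme).

```python
def decode_cudnn8_version(encoded):
    """Split a cuDNN 8.x integer version into (major, minor, patch).

    Assumes the cuDNN 8.x encoding: major * 1000 + minor * 100 + patch.
    """
    major, rest = divmod(encoded, 1000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch


print(decode_cudnn8_version(8700))  # (8, 7, 0)
print(decode_cudnn8_version(8902))  # (8, 9, 2)
```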
A possible resolution: Note that cuDNN minor releases are backward compatible with applications built against the same or earlier minor release. Hence, as long as jaxlib and torch are built against libcudnn with the same major version (8), the jax version check ought to ignore cudnn minor versions. Here is a patch:
diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 7977f6329..17c14bc5a 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -263,7 +263,7 @@ def _check_cuda_versions():
cuda_versions.cudnn_build_version,
# NVIDIA promise both backwards and forwards compatibility for cuDNN patch
# versions: https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat
- scale_for_comparison=100,
+ scale_for_comparison=1000,
)
_version_check("cuFFT", cuda_versions.cufft_get_version,
cuda_versions.cufft_build_version,
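To make the effect of that one-line change concrete, here is a simplified model of the comparison. It is not the actual JAX source: the function name is hypothetical, and it only assumes that the check integer-divides both versions by scale_for_comparison before comparing them.

```python
def passes_version_check(installed, built, scale_for_comparison=1):
    """Simplified model of the check (assumption, not the JAX implementation).

    Both versions are integer-divided by `scale_for_comparison`, which drops
    the low-order digits, and the installed one must not be older.
    """
    return installed // scale_for_comparison >= built // scale_for_comparison


# With scale 100, 8700 vs 8800 compares 87 with 88 (major.minor): rejected.
print(passes_version_check(8700, 8800, scale_for_comparison=100))   # False
# With scale 1000, it compares 8 with 8 (major only): accepted.
print(passes_version_check(8700, 8800, scale_for_comparison=1000))  # True
```

Under this model the patch turns the cuDNN check from a major.minor comparison into a major-only one, which is what the compatibility argument above relies on.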
No promises, but informally we're going to try to make sure at least one JAX release targets a CUDA version that PyTorch also releases for. Right now, that's the CUDA 11.8 release of JAX.
The latest pair of versions I could find that were compatible with each other was jax[cuda11-pip,cuda11_pip]==0.4.10 and torch==2.2.1+cu118. The main conflict in later versions of jax is cudnn: jax wants >8.8, but torch wants ==8.7.
One way to check this would be:
cat > requirements.in <<EOF
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
--extra-index-url=https://download.pytorch.org/whl
jax[cuda11_pip]
torch==2.2.1+cu118
EOF
pip-compile
# Check the contents of requirements.txt.
A workaround that works better for us is to use CUDA 11 with Jax, but CUDA 12 with Torch. So basically jax[cuda11_pip] and torch in our requirements file works for us.
A workaround that works better for us is to use CUDA 11 with Jax, but CUDA 12 with Torch.
How did you get this to work? I'm using conda, but after installing pytorch-cuda=12.1 I get the following error from JAX:
E RuntimeError: Unable to initialize backend 'cuda': Unable to load CUDA. Is it installed? (set JAX_PLATFORMS='' to automatically choose an available backend)
We did not have to do anything special. Just installed the two packages in a clean env, and both worked.
The only way I was able to solve the environment with both JAX and PyTorch on CUDA12 was to install some packages from the nvidia channel:
mamba create -n jaxTorch jaxlib pytorch cuda-nvcc -c conda-forge -c nvidia -c pytorch
>>> import torch
>>> import jax
>>> torch.cuda.is_available()
True
>>> jax.devices()
[cuda(id=0)]
>>> import jaxlib.cuda._versions
>>> jaxlib.cuda._versions.cudnn_get_version()
8902
>>> torch._C._cudnn.getCompileVersion()
(8, 9, 2)
conda list
# Name Version Build Channel
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 2_kmp_llvm conda-forge
_sysroot_linux-64_curr_repodata_hack 3 h69a702a_14 conda-forge
binutils_impl_linux-64 2.40 hf600244_0 conda-forge
binutils_linux-64 2.40 hdade7a5_3 conda-forge
blas 2.116 mkl conda-forge
blas-devel 3.9.0 16_linux64_mkl conda-forge
bzip2 1.0.8 hd590300_5 conda-forge
c-ares 1.27.0 hd590300_0 conda-forge
ca-certificates 2024.2.2 hbcca054_0 conda-forge
cuda-cccl_linux-64 12.1.109 ha770c72_0 conda-forge
cuda-cudart 12.1.105 hd3aeb46_0 conda-forge
cuda-cudart-dev 12.1.105 hd3aeb46_0 conda-forge
cuda-cudart-dev_linux-64 12.1.105 h59595ed_0 conda-forge
cuda-cudart-static 12.1.105 hd3aeb46_0 conda-forge
cuda-cudart-static_linux-64 12.1.105 h59595ed_0 conda-forge
cuda-cudart_linux-64 12.1.105 h59595ed_0 conda-forge
cuda-cupti 12.1.105 h59595ed_0 conda-forge
cuda-driver-dev_linux-64 12.1.105 h59595ed_0 conda-forge
cuda-libraries 12.1.0 0 nvidia
cuda-nvcc 12.1.105 hcdd1206_1 conda-forge
cuda-nvcc-dev_linux-64 12.1.105 ha770c72_0 conda-forge
cuda-nvcc-impl 12.1.105 hd3aeb46_0 conda-forge
cuda-nvcc-tools 12.1.105 hd3aeb46_0 conda-forge
cuda-nvcc_linux-64 12.1.105 h8a487aa_1 conda-forge
cuda-nvrtc 12.1.105 hd3aeb46_0 conda-forge
cuda-nvtx 12.1.105 h59595ed_0 conda-forge
cuda-opencl 12.1.105 h59595ed_0 conda-forge
cuda-runtime 12.1.0 0 nvidia
cuda-version 12.1 h1d6eff3_3 conda-forge
cudnn 8.9.7.29 h092f7fd_3 conda-forge
filelock 3.13.3 pyhd8ed1ab_0 conda-forge
gcc_impl_linux-64 12.3.0 he2b93b0_5 conda-forge
gcc_linux-64 12.3.0 h6477408_3 conda-forge
gxx_impl_linux-64 12.3.0 he2b93b0_5 conda-forge
gxx_linux-64 12.3.0 h4a1b8e8_3 conda-forge
icu 73.2 h59595ed_0 conda-forge
importlib-metadata 7.1.0 pyha770c72_0 conda-forge
importlib_metadata 7.1.0 hd8ed1ab_0 conda-forge
jax 0.4.25 pyhd8ed1ab_0 conda-forge
jaxlib 0.4.23 cuda120py312hc008a70_200 conda-forge
jinja2 3.1.3 pyhd8ed1ab_0 conda-forge
kernel-headers_linux-64 3.10.0 h4a8ded7_14 conda-forge
ld_impl_linux-64 2.40 h41732ed_0 conda-forge
libabseil 20240116.1 cxx17_h59595ed_2 conda-forge
libblas 3.9.0 16_linux64_mkl conda-forge
libcblas 3.9.0 16_linux64_mkl conda-forge
libcublas 12.1.0.26 0 nvidia
libcufft 11.0.2.4 0 nvidia
libcufile 1.6.1.9 hd3aeb46_0 conda-forge
libcurand 10.3.2.106 hd3aeb46_0 conda-forge
libcusolver 11.4.4.55 0 nvidia
libcusparse 12.0.2.55 0 nvidia
libexpat 2.6.2 h59595ed_0 conda-forge
libffi 3.4.2 h7f98852_5 conda-forge
libgcc-devel_linux-64 12.3.0 h8bca6fd_105 conda-forge
libgcc-ng 13.2.0 h807b86a_5 conda-forge
libgfortran-ng 13.2.0 h69a702a_5 conda-forge
libgfortran5 13.2.0 ha4646dd_5 conda-forge
libgomp 13.2.0 h807b86a_5 conda-forge
libgrpc 1.62.1 h15f2491_0 conda-forge
libhwloc 2.9.3 default_h554bfaf_1009 conda-forge
libiconv 1.17 hd590300_2 conda-forge
liblapack 3.9.0 16_linux64_mkl conda-forge
liblapacke 3.9.0 16_linux64_mkl conda-forge
libnpp 12.0.2.50 0 nvidia
libnsl 2.0.1 hd590300_0 conda-forge
libnvjitlink 12.1.105 hd3aeb46_0 conda-forge
libnvjpeg 12.1.1.14 0 nvidia
libprotobuf 4.25.3 h08a7969_0 conda-forge
libre2-11 2023.09.01 h5a48ba9_2 conda-forge
libsanitizer 12.3.0 h0f45ef3_5 conda-forge
libsqlite 3.45.2 h2797004_0 conda-forge
libstdcxx-devel_linux-64 12.3.0 h8bca6fd_105 conda-forge
libstdcxx-ng 13.2.0 h7e041cc_5 conda-forge
libuuid 2.38.1 h0b41bf4_0 conda-forge
libxcrypt 4.4.36 hd590300_1 conda-forge
libxml2 2.12.6 h232c23b_1 conda-forge
libzlib 1.2.13 hd590300_5 conda-forge
llvm-openmp 15.0.7 h0cdce71_0 conda-forge
markupsafe 2.1.5 py312h98912ed_0 conda-forge
mkl 2022.1.0 h84fe81f_915 conda-forge
mkl-devel 2022.1.0 ha770c72_916 conda-forge
mkl-include 2022.1.0 h84fe81f_915 conda-forge
ml_dtypes 0.3.2 py312hfb8ada1_0 conda-forge
mpmath 1.3.0 pyhd8ed1ab_0 conda-forge
nccl 2.20.5.1 h3a97aeb_0 conda-forge
ncurses 6.4.20240210 h59595ed_0 conda-forge
networkx 3.2.1 pyhd8ed1ab_0 conda-forge
numpy 1.26.4 py312heda63a1_0 conda-forge
ocl-icd 2.3.2 hd590300_1 conda-forge
openssl 3.2.1 hd590300_1 conda-forge
opt-einsum 3.3.0 hd8ed1ab_2 conda-forge
opt_einsum 3.3.0 pyhc1e730c_2 conda-forge
pip 24.0 pyhd8ed1ab_0 conda-forge
python 3.12.2 hab00c5b_0_cpython conda-forge
python_abi 3.12 4_cp312 conda-forge
pytorch 2.2.1 py3.12_cuda12.1_cudnn8.9.2_0 pytorch
pytorch-cuda 12.1 ha16c6d3_5 pytorch
pytorch-mutex 1.0 cuda pytorch
pyyaml 6.0.1 py312h98912ed_1 conda-forge
re2 2023.09.01 h7f4b329_2 conda-forge
readline 8.2 h8228510_1 conda-forge
scipy 1.12.0 py312heda63a1_2 conda-forge
setuptools 69.2.0 pyhd8ed1ab_0 conda-forge
sympy 1.12 pyh04b8f61_3 conda-forge
sysroot_linux-64 2.17 h4a8ded7_14 conda-forge
tbb 2021.11.0 h00ab1b0_1 conda-forge
tk 8.6.13 noxft_h4845f30_101 conda-forge
typing_extensions 4.10.0 pyha770c72_0 conda-forge
tzdata 2024a h0c530f3_0 conda-forge
wheel 0.43.0 pyhd8ed1ab_0 conda-forge
xz 5.2.6 h166bdaf_0 conda-forge
yaml 0.2.5 h7f98852_2 conda-forge
zipp 3.17.0 pyhd8ed1ab_0 conda-forge
fyi @traversaro
The only way I was able to solve the environment with both JAX and PyTorch on CUDA12 was to install some packages from the nvidia channel:
FYI, at the moment it is not possible to get both jax and pytorch with cuda 12 only using conda-forge dependencies for this reason (I pinned several dependencies to get a clearer error):
traversaro@IITBMP014LW012:~$ mamba create -n jaxtorchcuda pytorch==2.1.2=*cuda* jaxlib==0.4.23=*cuda* jax cuda-version=12.* python==3.11.* cudatoolkit==12.*
Looking for: ['pytorch==2.1.2[build=*cuda*]', 'jaxlib==0.4.23[build=*cuda*]', 'jax', 'cuda-version=12', 'python=3.11', 'cudatoolkit=12']
conda-forge/linux-64 Using cache
conda-forge/noarch Using cache
Could not solve for environment specs
The following packages are incompatible
├─ cuda-version 12** is installable with the potential options
│  ├─ cuda-version [12.0|12.0.0] would require
│  │  └─ cudatoolkit 12.0|12.0.* , which can be installed;
│  ├─ cuda-version 12.1 would require
│  │  └─ cudatoolkit 12.1|12.1.* , which can be installed;
│  ├─ cuda-version 12.2 would require
│  │  └─ cudatoolkit 12.2|12.2.* , which can be installed;
│  ├─ cuda-version 12.3 would require
│  │  └─ cudatoolkit 12.3|12.3.* , which can be installed;
│  └─ cuda-version 12.4 would require
│     └─ cudatoolkit 12.4|12.4.* , which can be installed;
├─ cudatoolkit 12** does not exist (perhaps a typo or a missing channel);
├─ jaxlib 0.4.23 *cuda* is installable with the potential options
│  ├─ jaxlib 0.4.23 would require
│  │  └─ cudatoolkit >=11.8,<12 , which conflicts with any installable versions previously reported;
│  ├─ jaxlib 0.4.23 would require
│  │  └─ libgrpc >=1.62.1,<1.63.0a0 , which requires
│  │     └─ libprotobuf >=4.25.3,<4.25.4.0a0 , which can be installed;
│  ├─ jaxlib 0.4.23 would require
│  │  ├─ cudatoolkit >=11.8,<12 , which conflicts with any installable versions previously reported;
│  │  └─ libgrpc >=1.59.3,<1.60.0a0 , which requires
│  │     └─ libprotobuf >=4.24.4,<4.24.5.0a0 , which conflicts with any installable versions previously reported;
│  └─ jaxlib 0.4.23 would require
│     └─ python_abi 3.12.* *_cp312, which requires
│        └─ python 3.12.* *_cpython, which can be installed;
├─ python 3.11** is not installable because it conflicts with any installable versions previously reported;
└─ pytorch 2.1.2 *cuda* is installable with the potential options
   ├─ pytorch 2.1.2 would require
   │  ├─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
   │  └─ libtorch 2.1.2.* with the potential options
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cpu_generic_*_0, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cpu_generic_*_1, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cpu_generic_*_3, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cpu_mkl_*_100, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cpu_mkl_*_101, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cpu_mkl_*_103, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cuda112_*_300, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  ├─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
   │     │  └─ pytorch 2.1.2 cuda112_*_301, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  ├─ cudatoolkit >=11.8,<12 , which conflicts with any installable versions previously reported;
   │     │  ├─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
   │     │  └─ pytorch 2.1.2 cuda118_*_301, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  ├─ cudatoolkit >=11.8,<12 , which conflicts with any installable versions previously reported;
   │     │  └─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cuda118_*_300, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  ├─ cuda-version >=12.0,<13 , which can be installed (as previously explained);
   │     │  ├─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
   │     │  └─ pytorch 2.1.2 cuda120_*_301, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  ├─ cuda-version >=12.0,<13 , which can be installed (as previously explained);
   │     │  └─ pytorch 2.1.2 cuda120_*_303, which can be installed;
   │     └─ libtorch 2.1.2 would require
   │        └─ pytorch 2.1.2 cuda120_*_300, which can be installed;
   ├─ pytorch 2.1.2 would require
   │  └─ libprotobuf >=4.24.4,<4.24.5.0a0 , which conflicts with any installable versions previously reported;
   ├─ pytorch 2.1.2 would require
   │  └─ python_abi 3.12.* *_cp312, which can be installed (as previously explained);
   └─ pytorch 2.1.2 would require
      ├─ cuda-version >=12.0,<13 , which can be installed (as previously explained);
      └─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported.
Once a conda-forge pytorch version gets compiled with libprotobuf==4.25.3 (i.e. once conda-forge/pytorch-cpu-feedstock#228 is ready and merged; big thanks to the pytorch and jax conda-forge maintainers), it should be possible to install both jax and pytorch with CUDA enabled, using CUDA 12, with conda-forge packages alone.
JAX 0.4.26 relaxed our CUDA version dependencies so the minimum CUDA version for JAX is 12.1. This is a version also supported by PyTorch. Try it out! We're going to try to make sure our supported version range overlaps with at least one PyTorch release.
Note that we dropped support for CUDA 11.
The only way I was able to solve the environment with both JAX and PyTorch on CUDA12 was to install some packages from the nvidia channel:
FYI, at the moment it is not possible to get both jax and pytorch with cuda 12 only using conda-forge dependencies for this reason (I pinned several dependencies to get a clearer error):
After a bunch of fixes from both the jax and pytorch maintainers, it is now (late May 2024) possible to just install jax and pytorch from conda-forge on Linux, and out of the box they will work with GPU/CUDA support without the need for any other conda channel:
$ conda create -c conda-forge -n jaxpytorch pytorch jax
$ conda activate jaxpytorch
$ python
Python 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:38:13) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import jax
>>> torch.cuda.is_available()
True
>>> jax.devices()
[cuda(id=0)]
>>>
If for some reason this command does not install the cuda-enabled jax, you may still be using the classic conda solver. In that case you can force the installation of cuda-enabled jax and pytorch with:
conda create -c conda-forge -n jaxpytorch pytorch=*=cuda* jax jaxlib=*=cuda*
However, this is not necessary if you are using a recent conda install that defaults to the conda-libmamba-solver; see https://www.anaconda.com/blog/a-faster-conda-for-a-growing-community .
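For a quick non-interactive version of the check in the session above, a script along these lines could be used. It is only a sketch: `gpu_report` is a made-up name, and it reports what each framework sees without asserting that a GPU must be present.

```python
def gpu_report():
    """Return a short status string per framework, without raising.

    Hypothetical helper: a framework that is missing or fails to load is
    reported as unavailable instead of aborting the check.
    """
    report = {}
    try:
        import torch
        report["torch"] = f"cuda available: {torch.cuda.is_available()}"
    except Exception as exc:  # not installed, or failed to load
        report["torch"] = f"unavailable ({type(exc).__name__})"
    try:
        import jax
        platforms = sorted({device.platform for device in jax.devices()})
        report["jax"] = "devices on: " + ", ".join(platforms)
    except Exception as exc:  # not installed, or backend failed to start
        report["jax"] = f"unavailable ({type(exc).__name__})"
    return report


if __name__ == "__main__":
    for framework, status in gpu_report().items():
        print(f"{framework}: {status}")
```

If the jax line lists only cpu while torch reports CUDA as available, the environment most likely picked up a CPU-only jaxlib build, which is the case the forced install above is meant to fix.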
conda list for reference
(jaxpytorch) traversaro@IITBMP014LW012:~$ conda list
# packages in environment at /home/traversaro/miniforge3/envs/jaxpytorch:
#
# Name Version Build Channel
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 2_kmp_llvm conda-forge
_sysroot_linux-64_curr_repodata_hack 3 h69a702a_14 conda-forge
binutils_impl_linux-64 2.40 ha1999f0_1 conda-forge
binutils_linux-64 2.40 hdade7a5_3 conda-forge
bzip2 1.0.8 hd590300_5 conda-forge
c-ares 1.28.1 hd590300_0 conda-forge
ca-certificates 2024.2.2 hbcca054_0 conda-forge
cuda-cccl_linux-64 12.5.39 ha770c72_0 conda-forge
cuda-crt-dev_linux-64 12.5.40 ha770c72_0 conda-forge
cuda-crt-tools 12.5.40 ha770c72_0 conda-forge
cuda-cudart 12.5.39 he02047a_0 conda-forge
cuda-cudart-dev 12.5.39 he02047a_0 conda-forge
cuda-cudart-dev_linux-64 12.5.39 h85509e4_0 conda-forge
cuda-cudart-static 12.5.39 he02047a_0 conda-forge
cuda-cudart-static_linux-64 12.5.39 h85509e4_0 conda-forge
cuda-cudart_linux-64 12.5.39 h85509e4_0 conda-forge
cuda-cupti 12.5.39 he02047a_0 conda-forge
cuda-driver-dev_linux-64 12.5.39 h85509e4_0 conda-forge
cuda-nvcc 12.5.40 hcdd1206_0 conda-forge
cuda-nvcc-dev_linux-64 12.5.40 ha770c72_0 conda-forge
cuda-nvcc-impl 12.5.40 hd3aeb46_0 conda-forge
cuda-nvcc-tools 12.5.40 hd3aeb46_0 conda-forge
cuda-nvcc_linux-64 12.5.40 h8a487aa_0 conda-forge
cuda-nvrtc 12.5.40 he02047a_0 conda-forge
cuda-nvtx 12.5.39 he02047a_0 conda-forge
cuda-nvvm-dev_linux-64 12.5.40 ha770c72_0 conda-forge
cuda-nvvm-impl 12.5.40 h59595ed_0 conda-forge
cuda-nvvm-tools 12.5.40 h59595ed_0 conda-forge
cuda-version 12.5 hd4f0392_3 conda-forge
cudnn 8.9.7.29 h092f7fd_3 conda-forge
filelock 3.14.0 pyhd8ed1ab_0 conda-forge
fsspec 2024.5.0 pyhff2d567_0 conda-forge
gcc_impl_linux-64 12.3.0 h58ffeeb_7 conda-forge
gcc_linux-64 12.3.0 h6477408_3 conda-forge
gmp 6.3.0 h59595ed_1 conda-forge
gmpy2 2.1.5 py312h1d5cde6_1 conda-forge
gxx_impl_linux-64 12.3.0 h2a574ab_7 conda-forge
gxx_linux-64 12.3.0 h4a1b8e8_3 conda-forge
icu 73.2 h59595ed_0 conda-forge
importlib-metadata 7.1.0 pyha770c72_0 conda-forge
importlib_metadata 7.1.0 hd8ed1ab_0 conda-forge
jax 0.4.27 pyhd8ed1ab_0 conda-forge
jaxlib 0.4.23 cuda120py312h6027bbc_202 conda-forge
jinja2 3.1.4 pyhd8ed1ab_0 conda-forge
kernel-headers_linux-64 3.10.0 h4a8ded7_14 conda-forge
ld_impl_linux-64 2.40 hf3520f5_1 conda-forge
libabseil 20240116.2 cxx17_h59595ed_0 conda-forge
libblas 3.9.0 22_linux64_openblas conda-forge
libcblas 3.9.0 22_linux64_openblas conda-forge
libcublas 12.5.2.13 he02047a_0 conda-forge
libcufft 11.2.3.18 he02047a_0 conda-forge
libcurand 10.3.6.39 he02047a_0 conda-forge
libcusolver 11.6.2.40 he02047a_0 conda-forge
libcusparse 12.4.1.24 he02047a_0 conda-forge
libexpat 2.6.2 h59595ed_0 conda-forge
libffi 3.4.2 h7f98852_5 conda-forge
libgcc-devel_linux-64 12.3.0 h0223996_107 conda-forge
libgcc-ng 13.2.0 h77fa898_7 conda-forge
libgfortran-ng 13.2.0 h69a702a_7 conda-forge
libgfortran5 13.2.0 hca663fb_7 conda-forge
libgomp 13.2.0 h77fa898_7 conda-forge
libgrpc 1.62.2 h15f2491_0 conda-forge
libhwloc 2.10.0 default_h5622ce7_1001 conda-forge
libiconv 1.17 hd590300_2 conda-forge
liblapack 3.9.0 22_linux64_openblas conda-forge
libmagma 2.7.2 h173bb3b_2 conda-forge
libmagma_sparse 2.7.2 h173bb3b_3 conda-forge
libnsl 2.0.1 hd590300_0 conda-forge
libnvjitlink 12.5.40 he02047a_0 conda-forge
libopenblas 0.3.27 pthreads_h413a1c8_0 conda-forge
libprotobuf 4.25.3 h08a7969_0 conda-forge
libre2-11 2023.09.01 h5a48ba9_2 conda-forge
libsanitizer 12.3.0 hb8811af_7 conda-forge
libsqlite 3.45.3 h2797004_0 conda-forge
libstdcxx-devel_linux-64 12.3.0 h0223996_107 conda-forge
libstdcxx-ng 13.2.0 hc0a3c3a_7 conda-forge
libtorch 2.3.0 cuda120_h2b0da52_301 conda-forge
libuuid 2.38.1 h0b41bf4_0 conda-forge
libuv 1.48.0 hd590300_0 conda-forge
libxcrypt 4.4.36 hd590300_1 conda-forge
libxml2 2.12.7 hc051c1a_0 conda-forge
libzlib 1.2.13 hd590300_5 conda-forge
llvm-openmp 18.1.6 ha31de31_0 conda-forge
markupsafe 2.1.5 py312h98912ed_0 conda-forge
mkl 2023.2.0 h84fe81f_50496 conda-forge
ml_dtypes 0.4.0 py312h1d6d2e6_1 conda-forge
mpc 1.3.1 hfe3b2da_0 conda-forge
mpfr 4.2.1 h9458935_1 conda-forge
mpmath 1.3.0 pyhd8ed1ab_0 conda-forge
nccl 2.21.5.1 h3a97aeb_0 conda-forge
ncurses 6.5 h59595ed_0 conda-forge
networkx 3.3 pyhd8ed1ab_1 conda-forge
numpy 1.26.4 py312heda63a1_0 conda-forge
openssl 3.3.0 h4ab18f5_3 conda-forge
opt-einsum 3.3.0 hd8ed1ab_2 conda-forge
opt_einsum 3.3.0 pyhc1e730c_2 conda-forge
pip 24.0 pyhd8ed1ab_0 conda-forge
python 3.12.3 hab00c5b_0_cpython conda-forge
python_abi 3.12 4_cp312 conda-forge
pytorch 2.3.0 cuda120_py312h26b3cf7_301 conda-forge
re2 2023.09.01 h7f4b329_2 conda-forge
readline 8.2 h8228510_1 conda-forge
scipy 1.13.1 py312hc2bc53b_0 conda-forge
setuptools 70.0.0 pyhd8ed1ab_0 conda-forge
sleef 3.5.1 h9b69904_2 conda-forge
sympy 1.12 pypyh9d50eac_103 conda-forge
sysroot_linux-64 2.17 h4a8ded7_14 conda-forge
tbb 2021.12.0 h297d8ca_1 conda-forge
tk 8.6.13 noxft_h4845f30_101 conda-forge
typing_extensions 4.11.0 pyha770c72_0 conda-forge
tzdata 2024a h0c530f3_0 conda-forge
wheel 0.43.0 pyhd8ed1ab_1 conda-forge
xz 5.2.6 h166bdaf_0 conda-forge
zipp 3.17.0 pyhd8ed1ab_0 conda-forge
zstd 1.5.6 ha6fb4c9_0 conda-forge
Can someone please point out the correct version necessary to get pytorch and jax both with GPU support on CUDA 12 as of July 2024?
I would prefer it to be a standard venv rather than a conda env, but either is fine.
@varadVaidya totally by chance I follow this issue, but in general you may have more success finding help through the official JAX help channels (see https://jax.readthedocs.io/en/latest/beginner_guide.html#finding-help), rather than posting in closed issues.
More on topic, I have no idea about pip/venv with cuda, but for conda the procedure posted in #18032 (comment) is working fine for me (when I originally posted the message I forgot to add the -c conda-forge to ensure it also works on anaconda or miniconda installations of conda that use defaults instead of conda-forge; I just fixed that to avoid confusion).
@eliseoe @bebark @shaikalthaf4 By chance I just noticed that you added a thumbs-down reaction to my previous comment, any reason for doing so? Just fyi, authors do not get (at least by default) notifications for post reactions.
@traversaro I found that running your command with conda will install:
jaxlib conda-forge/linux-64::jaxlib-0.4.27-cpu_py312h17e8b90_0
whereas with mamba the correct version is installed:
mamba create -c conda-forge -n jaxpytorch pytorch jax
jaxlib 0.4.27 cuda120py312h4008524_200 conda-forge/linux-64
Perhaps this is why you got 3 thumbs down.
Interestingly, in my system with:
root@DESKTOP-T0NQNLN:~# conda info
active environment : None
shell level : 0
user config file : /root/.condarc
populated config files : /root/miniforge3/.condarc
/root/.condarc
conda version : 24.3.0
conda-build version : not installed
python version : 3.10.14.final.0
solver : libmamba (default)
virtual packages : __archspec=1=skylake
__conda=24.3.0=0
__cuda=12.0=0
__glibc=2.39=0
__linux=5.15.153.1=0
__unix=0=0
base environment : /root/miniforge3 (writable)
conda av data dir : /root/miniforge3/etc/conda
conda av metadata url : None
channel URLs : https://conda.anaconda.org/conda-forge/linux-64
https://conda.anaconda.org/conda-forge/noarch
package cache : /root/miniforge3/pkgs
/root/.conda/pkgs
envs directories : /root/miniforge3/envs
/root/.conda/envs
platform : linux-64
user-agent : conda/24.3.0 requests/2.31.0 CPython/3.10.14 Linux/5.15.153.1-microsoft-standard-WSL2 ubuntu/24.04 glibc/2.39 solver/libmamba conda-libmamba-solver/24.1.0 libmambapy/1.5.8
UID:GID : 0:0
netrc file : None
offline mode : False
the command
conda create -c conda-forge -n jaxpytorch pytorch jax
installs the CUDA jax, but indeed:
conda create --solver=classic -c conda-forge -n jaxpytorch pytorch jax
installs the CPU jax. Perhaps you are using an old conda version that uses the classic solver by default? (You can check this by posting the conda info output; see https://www.anaconda.com/blog/a-faster-conda-for-a-growing-community.)
However, even with the classic solver, forcing the solver to install the CUDA builds of jaxlib and pytorch works as expected (although the classic solver is much slower):
conda create --solver=classic -c conda-forge -n jaxpytorch pytorch=*=cuda* jax jaxlib=*=cuda*
I edited the original comment accordingly.
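To illustrate why the =*=cuda* pins help, here is a minimal Python sketch of glob matching on build strings. It is only an approximation of how a build-string pin filters candidates, not conda's actual solver logic; is_cuda_build is a hypothetical helper, and the build strings are the ones quoted in this thread.

```python
from fnmatch import fnmatchcase


def is_cuda_build(build_string: str, pattern: str = "cuda*") -> bool:
    """Return True if a conda build string matches the glob used in the pin."""
    return fnmatchcase(build_string, pattern)


# Build strings copied from the comments above.
builds = {
    "jaxlib (classic solver)": "cpu_py312h17e8b90_0",
    "jaxlib (libmamba solver)": "cuda120py312h4008524_200",
    "pytorch": "cuda120_py312h26b3cf7_301",
}

for name, build in builds.items():
    print(f"{name}: {build} -> CUDA build: {is_cuda_build(build)}")
```

With the pin in place, the CPU build of jaxlib no longer matches, so the solver cannot pick it.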
You are right, I'm using the classic solver:
active environment : base
active env location : /home/chovanec/miniconda3
shell level : 1
user config file : /home/chovanec/.condarc
populated config files : /home/chovanec/.condarc
conda version : 23.1.0
conda-build version : not installed
python version : 3.10.8.final.0
virtual packages : __archspec=1=x86_64
__cuda=12.3=0
__glibc=2.31=0
__linux=5.15.153.1=0
__unix=0=0
base environment : /home/chovanec/miniconda3 (writable)
conda av data dir : /home/chovanec/miniconda3/etc/conda
conda av metadata url : None
channel URLs : https://conda.anaconda.org/bioconda/linux-64
https://conda.anaconda.org/bioconda/noarch
https://conda.anaconda.org/conda-forge/linux-64
https://conda.anaconda.org/conda-forge/noarch
https://repo.anaconda.com/pkgs/main/linux-64
https://repo.anaconda.com/pkgs/main/noarch
https://repo.anaconda.com/pkgs/r/linux-64
https://repo.anaconda.com/pkgs/r/noarch
package cache : /home/chovanec/miniconda3/pkgs
/home/chovanec/.conda/pkgs
envs directories : /home/chovanec/miniconda3/envs
/home/chovanec/.conda/envs
platform : linux-64
user-agent : conda/23.1.0 requests/2.28.1 CPython/3.10.8 Linux/5.15.153.1-microsoft-standard-WSL2 ubuntu/20.04.6 glibc/2.31
UID:GID : 1000:1000
netrc file : None
offline mode : False
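The two conda info outputs above differ in whether a solver line is present. As a rough sketch, the snippet below pulls the conda version and solver out of that text. summarize_conda_info is a hypothetical helper, and treating a missing solver line as "classic" is an assumption based on this thread, not documented conda behavior.

```python
import re


def summarize_conda_info(text: str) -> dict:
    """Extract the conda version and solver from `conda info` text output."""
    version = re.search(r"^\s*conda version\s*:\s*(\S+)", text, re.MULTILINE)
    solver = re.search(r"^\s*solver\s*:\s*(\S+)", text, re.MULTILINE)
    return {
        "conda_version": version.group(1) if version else None,
        # Assumption: no "solver" line means an older conda on the classic solver.
        "solver": solver.group(1) if solver else "classic (assumed)",
    }


# Trimmed excerpts of the two outputs posted above.
newer = """
       conda version : 24.3.0
              solver : libmamba (default)
"""
older = """
       conda version : 23.1.0
 conda-build version : not installed
"""

print(summarize_conda_info(newer))
print(summarize_conda_info(older))
```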
The only way I was able to solve the environment with both JAX and PyTorch on CUDA12 was to install some packages from the nvidia channel:
mamba create -n jaxTorch jaxlib pytorch cuda-nvcc -c conda-forge -c nvidia -c pytorch
>>> import torch
>>> import jax
>>> torch.cuda.is_available()
True
>>> jax.devices()
[cuda(id=0)]
>>> import jaxlib.cuda._versions
>>> jaxlib.cuda._versions.cudnn_get_version()
8902
>>> torch._C._cudnn.getCompileVersion()
(8, 9, 2)
fyi @traversaro
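The interactive check above can also be wrapped in a small script. This is a sketch, not an official diagnostic: gpu_report and its import_module parameter are hypothetical names, and it uses only torch.cuda.is_available() and jax.default_backend().

```python
import importlib


def gpu_report(import_module=importlib.import_module) -> dict:
    """Report whether PyTorch and JAX each see a GPU; None means not installed."""
    report = {"torch_cuda": None, "jax_backend": None}
    try:
        torch = import_module("torch")
        report["torch_cuda"] = bool(torch.cuda.is_available())
    except ImportError:
        pass  # PyTorch is not installed in this environment.
    try:
        jax = import_module("jax")
        # default_backend() reports "cpu" when JAX has fallen back to the CPU.
        report["jax_backend"] = jax.default_backend()
    except ImportError:
        pass  # JAX is not installed in this environment.
    return report
```

Calling print(gpu_report()) inside the environment shows both results at once. A jax_backend of "cpu" next to torch_cuda being True would point to the kind of mismatch discussed in this issue.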
Thanks for the solution. However, I found a possible bug: with jax and jaxlib both at version 0.4.30, jax.numpy cannot initialize an array whose shape is larger than (2, 52, 10). Downgrading to 0.4.23 works fine, so to be safe the command could be:
conda create -n _env_name_ jaxlib=0.4.23 pytorch cuda-nvcc python=3.11 -c conda-forge -c nvidia -c pytorch
Python 3.12 is too new for some commonly used packages.
Just a curiosity, are you actually getting any packages from the nvidia or pytorch channel? If the conda-forge channel is used and you are using strict priority, all the packages you get should come from conda-forge, and so I guess you could drop the -c nvidia -c pytorch from your command. However, you can check this by calling conda list and checking from where packages are installed.
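As a quick way to do that check, here is a small sketch that counts packages per channel in the text printed by conda list. channels_used is a hypothetical helper, and it assumes the four-column layout (name, version, build, channel) seen in the listings in this thread. Packages installed with pip usually show pypi in the channel column, so they would stand out here too.

```python
from collections import Counter


def channels_used(conda_list_text: str) -> Counter:
    """Count packages per channel in the text printed by `conda list`."""
    counts = Counter()
    for line in conda_list_text.splitlines():
        fields = line.split()
        # Skip blank lines and the "#" header lines.
        if not fields or fields[0].startswith("#"):
            continue
        # Columns are: name, version, build, channel (channel may be absent).
        channel = fields[3] if len(fields) >= 4 else "(no channel listed)"
        counts[channel] += 1
    return counts


# A few rows copied from a listing earlier in this thread.
sample = """\
# Name Version Build Channel
cudnn 8.9.7.29 h092f7fd_3 conda-forge
jaxlib 0.4.23 cuda120py312h6027bbc_202 conda-forge
pytorch 2.3.0 cuda120_py312h26b3cf7_301 conda-forge
"""

print(channels_used(sample))
```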
I'm not sure; maybe I can test it later. Thanks for pointing it out.
Sorry for the late reply, here are the outputs.
Since jax and the jax CUDA library were manually reinstalled from PyPI, I guess yes, the other packages are installed from conda-forge with priority :)
Not sure how you can end up with jax/jaxlib installed via PyPI if you just created the environment with conda create -n _env_name_ jaxlib=0.4.23 pytorch cuda-nvcc python=3.11 -c conda-forge -c nvidia -c pytorch, but as a general comment, if you are installing something with pip it is a good idea not to also install it via conda, to avoid conflicts.
In my case, the conflict came from torch and jaxlib sticking to different cuDNN versions. Previously I had not used conda-forge to install a cudatoolkit compatible with both torch and jaxlib. I used the pip command from the official JAX documentation, by the way.
Ok, but in that case it is probably a good idea not to install jax and jaxlib from conda, and only install them from pip.
I think the only reason for having jax and jaxlib in the command is to make sure conda-forge can find and install a compatible cuDNN version. I did not test this, so to be safe I recommend, annoyingly, reinstalling jax and jaxlib from pip.
But conda has no idea which version of cuDNN the jaxlib installed via pip requires. If you want to install cuDNN (even a specific version) with conda, just install cudnn. To avoid problems, it is typically best not to install jax or jaxlib via conda if you are installing them via pip.
You are right. I happened to use pip install, and it just found a cuDNN version that meets the requirement, lol.
@traversaro apologies for having to revive this issue, but your solution does not work for me:
$ conda create -c conda-forge -n jaxpytorch pytorch jax
$ conda activate jaxpytorch
$ python
Python 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:23:07) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import jax
>>> torch.cuda.is_available()
True
>>> jax.devices()
CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:
Outdated cuDNN installation found.
Version JAX was built against: 8907
Minimum supported: 9100
Installed version: 8907
The local installation version must be no lower than 9100..(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
The error message suggests I need cudnn>=9.1.0, however the most recent version on conda-forge appears to be 8.9.7.29.
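For reference, a minimal sketch of how the integer codes in that error message can be read. decode_cudnn_version is a hypothetical helper, and the encoding (major * 1000 + minor * 100 + patch) is an assumption that fits the 8902 and (8, 9, 2) pair reported earlier in this thread; newer cuDNN releases may encode their version differently, so this only covers the four-digit codes shown here.

```python
def decode_cudnn_version(code: int) -> tuple:
    """Split a four-digit cuDNN version code into (major, minor, patch).

    Assumption: code = major * 1000 + minor * 100 + patch, which fits the
    8902 <-> (8, 9, 2) pair shown earlier in this thread.
    """
    major, rest = divmod(code, 1000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch


installed, minimum = 8907, 9100  # values from the error message above

print("installed:", decode_cudnn_version(installed))
print("minimum:  ", decode_cudnn_version(minimum))
# The check in the error message is a plain integer comparison.
print("meets minimum:", installed >= minimum)
```

Read this way, the installed code corresponds to the cudnn 8.9.7.29 package in the listing, which is below the reported minimum.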
System info:
Operating System: Ubuntu 22.04.4 LTS
GPU: NVIDIA GeForce GTX 1060 6GB
Graphics Driver: NVIDIA driver metapackage from nvidia-driver-535
mamba list below the fold:
# packages in environment at /home/lucas/mambaforge/envs/jaxpytorch:
#
# Name Version Build Channel
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 2_kmp_llvm conda-forge
_sysroot_linux-64_curr_repodata_hack 3 h69a702a_16 conda-forge
binutils_impl_linux-64 2.40 ha1999f0_7 conda-forge
binutils_linux-64 2.40 hb3c18ed_0 conda-forge
bzip2 1.0.8 h4bc722e_7 conda-forge
c-ares 1.32.3 h4bc722e_0 conda-forge
ca-certificates 2024.7.4 hbcca054_0 conda-forge
cuda-cccl_linux-64 12.5.39 ha770c72_0 conda-forge
cuda-crt-dev_linux-64 12.5.82 ha770c72_0 conda-forge
cuda-crt-tools 12.5.82 ha770c72_0 conda-forge
cuda-cudart 12.5.82 he02047a_0 conda-forge
cuda-cudart-dev 12.5.82 he02047a_0 conda-forge
cuda-cudart-dev_linux-64 12.5.82 h85509e4_0 conda-forge
cuda-cudart-static 12.5.82 he02047a_0 conda-forge
cuda-cudart-static_linux-64 12.5.82 h85509e4_0 conda-forge
cuda-cudart_linux-64 12.5.82 h85509e4_0 conda-forge
cuda-cupti 12.5.82 he02047a_0 conda-forge
cuda-driver-dev_linux-64 12.5.82 h85509e4_0 conda-forge
cuda-nvcc 12.5.82 hcdd1206_0 conda-forge
cuda-nvcc-dev_linux-64 12.5.82 ha770c72_0 conda-forge
cuda-nvcc-impl 12.5.82 hd3aeb46_0 conda-forge
cuda-nvcc-tools 12.5.82 hd3aeb46_0 conda-forge
cuda-nvcc_linux-64 12.5.82 h8a487aa_0 conda-forge
cuda-nvrtc 12.5.82 he02047a_0 conda-forge
cuda-nvtx 12.5.82 he02047a_0 conda-forge
cuda-nvvm-dev_linux-64 12.5.82 ha770c72_0 conda-forge
cuda-nvvm-impl 12.5.82 h59595ed_0 conda-forge
cuda-nvvm-tools 12.5.82 h59595ed_0 conda-forge
cuda-version 12.5 hd4f0392_3 conda-forge
cudnn 8.9.7.29 h092f7fd_3 conda-forge
filelock 3.15.4 pyhd8ed1ab_0 conda-forge
fsspec 2024.6.1 pyhff2d567_0 conda-forge
gcc_impl_linux-64 13.3.0 hfea6d02_0 conda-forge
gcc_linux-64 13.3.0 hc28eda2_0 conda-forge
gmp 6.3.0 hac33072_2 conda-forge
gmpy2 2.1.5 py312h1d5cde6_1 conda-forge
gxx_impl_linux-64 13.3.0 hffce095_0 conda-forge
gxx_linux-64 13.3.0 h6834431_0 conda-forge
icu 75.1 he02047a_0 conda-forge
importlib-metadata 8.2.0 pyha770c72_0 conda-forge
importlib_metadata 8.2.0 hd8ed1ab_0 conda-forge
jax 0.4.31 pyhd8ed1ab_0 conda-forge
jaxlib 0.4.30 cuda120py312h4008524_200 conda-forge
jinja2 3.1.4 pyhd8ed1ab_0 conda-forge
kernel-headers_linux-64 3.10.0 h4a8ded7_16 conda-forge
ld_impl_linux-64 2.40 hf3520f5_7 conda-forge
libabseil 20240116.2 cxx17_he02047a_1 conda-forge
libblas 3.9.0 23_linux64_openblas conda-forge
libcblas 3.9.0 23_linux64_openblas conda-forge
libcublas 12.5.3.2 he02047a_0 conda-forge
libcufft 11.2.3.61 he02047a_0 conda-forge
libcurand 10.3.6.82 he02047a_0 conda-forge
libcusolver 11.6.3.83 he02047a_0 conda-forge
libcusparse 12.5.1.3 he02047a_0 conda-forge
libexpat 2.6.2 h59595ed_0 conda-forge
libffi 3.4.2 h7f98852_5 conda-forge
libgcc-devel_linux-64 13.3.0 h84ea5a7_100 conda-forge
libgcc-ng 14.1.0 h77fa898_0 conda-forge
libgfortran-ng 14.1.0 h69a702a_0 conda-forge
libgfortran5 14.1.0 hc5f4f2c_0 conda-forge
libgomp 14.1.0 h77fa898_0 conda-forge
libgrpc 1.62.2 h15f2491_0 conda-forge
libhwloc 2.11.1 default_hecaa2ac_1000 conda-forge
libiconv 1.17 hd590300_2 conda-forge
liblapack 3.9.0 23_linux64_openblas conda-forge
libmagma 2.7.2 h173bb3b_2 conda-forge
libmagma_sparse 2.7.2 h173bb3b_3 conda-forge
libnsl 2.0.1 hd590300_0 conda-forge
libnvjitlink 12.5.82 he02047a_0 conda-forge
libopenblas 0.3.27 pthreads_hac2b453_1 conda-forge
libprotobuf 4.25.3 h08a7969_0 conda-forge
libre2-11 2023.09.01 h5a48ba9_2 conda-forge
libsanitizer 13.3.0 heb74ff8_0 conda-forge
libsqlite 3.46.0 hde9e2c9_0 conda-forge
libstdcxx-devel_linux-64 13.3.0 h84ea5a7_100 conda-forge
libstdcxx-ng 14.1.0 hc0a3c3a_0 conda-forge
libtorch 2.3.1 cuda120_h2b0da52_300 conda-forge
libuuid 2.38.1 h0b41bf4_0 conda-forge
libuv 1.48.0 hd590300_0 conda-forge
libxcrypt 4.4.36 hd590300_1 conda-forge
libxml2 2.12.7 he7c6b58_4 conda-forge
libzlib 1.3.1 h4ab18f5_1 conda-forge
llvm-openmp 18.1.8 hf5423f3_0 conda-forge
markupsafe 2.1.5 py312h98912ed_0 conda-forge
mkl 2023.2.0 h84fe81f_50496 conda-forge
ml_dtypes 0.4.0 py312h1d6d2e6_1 conda-forge
mpc 1.3.1 hfe3b2da_0 conda-forge
mpfr 4.2.1 h38ae2d0_2 conda-forge
mpmath 1.3.0 pyhd8ed1ab_0 conda-forge
nccl 2.22.3.1 hbc370b7_1 conda-forge
ncurses 6.5 h59595ed_0 conda-forge
networkx 3.3 pyhd8ed1ab_1 conda-forge
numpy 2.0.1 py312h1103770_0 conda-forge
openssl 3.3.1 h4bc722e_2 conda-forge
opt-einsum 3.3.0 hd8ed1ab_2 conda-forge
opt_einsum 3.3.0 pyhc1e730c_2 conda-forge
pip 24.0 pyhd8ed1ab_0 conda-forge
python 3.12.4 h194c7f8_0_cpython conda-forge
python_abi 3.12 4_cp312 conda-forge
pytorch 2.3.1 cuda120_py312h26b3cf7_300 conda-forge
re2 2023.09.01 h7f4b329_2 conda-forge
readline 8.2 h8228510_1 conda-forge
scipy 1.14.0 py312hc2bc53b_1 conda-forge
setuptools 71.0.4 pyhd8ed1ab_0 conda-forge
sleef 3.6.1 h3400bea_1 conda-forge
sympy 1.13.0 pypyh2585a3b_103 conda-forge
sysroot_linux-64 2.17 h4a8ded7_16 conda-forge
tbb 2021.12.0 h434a139_3 conda-forge
tk 8.6.13 noxft_h4845f30_101 conda-forge
typing_extensions 4.12.2 pyha770c72_0 conda-forge
tzdata 2024a h0c530f3_0 conda-forge
wheel 0.43.0 pyhd8ed1ab_1 conda-forge
xz 5.2.6 h166bdaf_0 conda-forge
zipp 3.19.2 pyhd8ed1ab_0 conda-forge
zstd 1.5.6 ha6fb4c9_0 conda-forge
UPDATE: jax=0.4.28 appears to work, so this looks like a bug introduced recently in JAX. cc @hawkinsp
@lucascolley thanks for reporting the issue, can you please open an issue in https://github.com/conda-forge/jaxlib-feedstock and tag me there? Thanks!
Thanks @lucascolley, indeed it seems to be a regression in the conda-forge jax package 0.4.31. I opened conda-forge/jaxlib-feedstock#277 to track the problem.