google/jax

JAX and TORCH

Closed this issue · 33 comments

ywsslr commented

Description

When I only install the latest JAX with CUDA (pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html), I can use JAX with the GPU. But when I install torch afterwards (pip install torch), I can no longer use JAX with the GPU: it tells me that the CUDA or cuSOLVER version is older than the one JAX needs. Why? Can an older JAX version avoid this? If so, how can I install jax[cuda] with a matching version?

What jax/jaxlib version are you using?

jax-0.4.18 jaxlib-0.4.18+cuda12.cudnn89

Which accelerator(s) are you using?

GPU

Additional system info

3.10.9/Linux

NVIDIA GPU info

No response

That's correct. The current releases of PyTorch and JAX have incompatible CUDA version dependencies.

I reported this issue to the PyTorch developers a while back, but there has been no interest in relaxing their CUDA version dependencies.

My recommendations:

  • use a different virtualenv for PyTorch and JAX. This is the simplest solution and probably the best.
  • if for some reason you really want PyTorch and JAX in the same virtualenv, install the CPU version of one of them and the CUDA version of the other. That avoids any CUDA version conflicts.
  • it may work to simply install JAX after PyTorch, since JAX wants a newer CUDA version than PyTorch's current release does, and in practice NVIDIA's CUDA releases are backwards compatible. I'm not sure if PyTorch enforces a version check, but if not it's highly likely this will work. But I don't think the PyTorch developers support it.
  • another solution is to install the CUDA version needed for one of the two, and build the other one from source. For example, JAX will happily build from source with an older CUDA release, it's just the binary distribution that requires a CUDA version matching the version against which it was built.
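As a runtime complement to the second recommendation, JAX reads the `JAX_PLATFORMS` environment variable to decide which backends to initialize. The following is a minimal sketch, assuming you want JAX on CPU and PyTorch on the GPU within one process. It sets the variable before `jax` is first imported. This only controls which backend JAX initializes at runtime and does not resolve install-time dependency conflicts. The helper name `force_jax_cpu` is hypothetical.

```python
import os


def force_jax_cpu() -> None:
    """Hypothetical helper: make JAX use only its CPU backend.

    Call this before `jax` is imported for the first time, since JAX
    picks up JAX_PLATFORMS when its configuration is set up.
    """
    os.environ["JAX_PLATFORMS"] = "cpu"


force_jax_cpu()

try:
    import jax  # JAX should now initialize only the CPU backend

    print(jax.devices())
except ImportError:
    print("jax is not installed in this environment")
```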

Does that resolve your problem?

Hope that helps!

This is quite annoying (and inconvenient) now that people have written torch2jax functionality that allows GPU-accelerated interaction between the two:

https://github.com/samuela/torch2jax
https://github.com/rdyro/torch2jax

Hi @ywsslr, I've been experimenting with using Torch and JAX simultaneously for a while. I'm currently working in a Docker container in which both work on the GPU.

JAX was installed according to the official documentation as:

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Here is the Conda YAML of the environment. It probably includes some extra packages, but I hope it helps:

conda environment
name: base
channels:
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - absl-py=1.4.0=py310h06a4308_0
  - alsa-lib=1.2.10=hd590300_0
  - appdirs=1.4.4=pyhd3eb1b0_0
  - asttokens=2.0.5=pyhd3eb1b0_0
  - attr=2.5.1=h166bdaf_1
  - backcall=0.2.0=pyhd3eb1b0_0
  - binutils=2.40=hdd6e379_0
  - binutils_impl_linux-64=2.40=hf600244_0
  - binutils_linux-64=2.40=hbdbef99_2
  - blas=1.0=openblas
  - boltons=23.0.0=pyhd8ed1ab_0
  - brotli=1.0.9=he6710b0_2
  - brotli-python=1.1.0=py310hc6cd4ac_0
  - bzip2=1.0.8=h7f98852_4
  - c-ares=1.19.1=hd590300_0
  - c-compiler=1.6.0=hd590300_0
  - ca-certificates=2023.7.22=hbcca054_0
  - cairo=1.16.0=hb05425b_5
  - certifi=2023.7.22=pyhd8ed1ab_0
  - cffi=1.15.1=py310h255011f_3
  - charset-normalizer=3.2.0=pyhd8ed1ab_0
  - chex=0.1.5=py310h06a4308_0
  - click=8.0.4=py310h06a4308_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - coloredlogs=15.0.1=py310h06a4308_1
  - compilers=1.6.0=ha770c72_0
  - conda=23.3.1=py310hff52083_0
  - conda-package-handling=2.2.0=pyh38be061_0
  - conda-package-streaming=0.9.0=pyhd8ed1ab_0
  - contourpy=1.0.5=py310hdb19cb5_0
  - cryptography=41.0.3=py310h75e40e8_0
  - cuda-nvcc=11.3.58=h2467b9f_0
  - cuda-version=11.8=h70ddcb2_2
  - cudatoolkit=11.8.0=h4ba93d1_12
  - cudnn=8.9.2.26=cuda11_0
  - cupti=11.8.0=he078b1a_0
  - cxx-compiler=1.6.0=h00ab1b0_0
  - cycler=0.11.0=pyhd3eb1b0_0
  - dbus=1.13.18=hb2f20db_0
  - deap=1.4.1=py310h7cbd5c2_0
  - decorator=5.1.1=pyhd3eb1b0_0
  - dm-tree=0.1.7=py310h6a678d5_1
  - docker-pycreds=0.4.0=pyhd3eb1b0_0
  - docstring_parser=0.15=pyhd8ed1ab_0
  - exceptiongroup=1.0.4=py310h06a4308_0
  - executing=0.8.3=pyhd3eb1b0_0
  - expat=2.5.0=h6a678d5_0
  - filelock=3.9.0=py310h06a4308_0
  - flax=0.6.1=pyhd8ed1ab_1
  - fmt=9.1.0=h924138e_0
  - font-ttf-dejavu-sans-mono=2.37=hd3eb1b0_0
  - font-ttf-inconsolata=2.001=hcb22688_0
  - font-ttf-source-code-pro=2.030=hd3eb1b0_0
  - font-ttf-ubuntu=0.83=h8b1ccd4_0
  - fontconfig=2.14.2=h14ed4e7_0
  - fonts-anaconda=1=h8fa9717_0
  - fonts-conda-ecosystem=1=hd3eb1b0_0
  - fonttools=4.25.0=pyhd3eb1b0_0
  - fortran-compiler=1.6.0=heb67821_0
  - freetype=2.12.1=h4a9f257_0
  - frozendict=2.3.8=py310h2372a71_0
  - gcc=12.3.0=h8d2909c_2
  - gcc_impl_linux-64=12.3.0=he2b93b0_1
  - gcc_linux-64=12.3.0=h76fc315_2
  - gettext=0.21.1=h27087fc_0
  - gfortran=12.3.0=h499e0f7_2
  - gfortran_impl_linux-64=12.3.0=hfcedea8_1
  - gfortran_linux-64=12.3.0=h7fe76b4_2
  - gitdb=4.0.7=pyhd3eb1b0_0
  - gitpython=3.1.30=py310h06a4308_0
  - glib=2.78.0=hfc55251_0
  - glib-tools=2.78.0=hfc55251_0
  - gmp=6.2.1=h295c915_3
  - gmpy2=2.1.2=py310heeb90bb_0
  - graphite2=1.3.14=h295c915_1
  - gst-plugins-base=1.22.5=h8e1006c_1
  - gstreamer=1.22.5=h98fc4e7_1
  - gxx=12.3.0=h8d2909c_2
  - gxx_impl_linux-64=12.3.0=he2b93b0_1
  - gxx_linux-64=12.3.0=h8a814eb_2
  - harfbuzz=8.2.0=h3d44ed6_0
  - humanfriendly=10.0=py310h06a4308_1
  - icu=73.2=h59595ed_0
  - idna=3.4=pyhd8ed1ab_0
  - intel-openmp=2023.1.0=hdb19cb5_46305
  - ipython=8.15.0=py310h06a4308_0
  - jax-dataclasses=1.5.1=pyhd8ed1ab_0
  - jaxlie=1.3.3=pyhd8ed1ab_0
  - jedi=0.18.1=py310h06a4308_1
  - jinja2=3.1.2=py310h06a4308_0
  - jsonpatch=1.32=pyhd8ed1ab_0
  - jsonpointer=2.4=py310hff52083_0
  - kernel-headers_linux-64=2.6.32=he073ed8_16
  - keyutils=1.6.1=h166bdaf_0
  - kiwisolver=1.4.4=py310h6a678d5_0
  - krb5=1.21.2=h659d440_0
  - lame=3.100=h7b6447c_0
  - lcms2=2.15=h7f713cb_2
  - ld_impl_linux-64=2.40=h41732ed_0
  - lerc=4.0.0=h27087fc_0
  - libarchive=3.6.2=h039dbb9_1
  - libcap=2.69=h0f662aa_0
  - libclang=15.0.7=default_h7634d5b_3
  - libclang13=15.0.7=default_h9986a30_3
  - libcups=2.3.3=h4637d8d_4
  - libcurl=8.3.0=hca28451_0
  - libdeflate=1.19=hd590300_0
  - libedit=3.1.20191231=he28a2e2_2
  - libev=4.33=h516909a_1
  - libevent=2.1.12=hdbd6064_1
  - libexpat=2.5.0=hcb278e6_1
  - libffi=3.4.2=h7f98852_5
  - libflac=1.4.3=h59595ed_0
  - libgcc-devel_linux-64=12.3.0=h8bca6fd_1
  - libgcc-ng=13.2.0=h807b86a_0
  - libgcrypt=1.10.1=h166bdaf_0
  - libgfortran-ng=13.2.0=h69a702a_1
  - libgfortran5=13.2.0=ha4646dd_1
  - libglib=2.78.0=hebfc3b9_0
  - libgomp=13.2.0=h807b86a_0
  - libgpg-error=1.47=h71f35ed_0
  - libiconv=1.17=h166bdaf_0
  - libjpeg-turbo=2.1.5.1=hd590300_1
  - libllvm15=15.0.7=h5cf9203_3
  - libmamba=1.2.0=hcea66bb_0
  - libmambapy=1.2.0=py310h1428755_0
  - libnghttp2=1.52.0=h61bc06f_0
  - libnsl=2.0.0=h7f98852_0
  - libogg=1.3.5=h27cfd23_1
  - libopenblas=0.3.21=h043d6bf_0
  - libopus=1.3.1=h7b6447c_0
  - libpng=1.6.39=h5eee18b_0
  - libpq=15.4=hfc447b1_0
  - libprotobuf=3.20.3=he621ea3_0
  - libsanitizer=12.3.0=h0f45ef3_1
  - libsndfile=1.2.2=hbc2eb40_0
  - libsolv=0.7.24=hfc55251_4
  - libsqlite=3.43.0=h2797004_0
  - libssh2=1.11.0=h0841786_0
  - libstdcxx-devel_linux-64=12.3.0=h8bca6fd_1
  - libstdcxx-ng=13.2.0=h7e041cc_0
  - libsystemd0=254=h3516f8a_0
  - libtiff=4.6.0=h29866fb_1
  - libuuid=2.38.1=h0b41bf4_0
  - libvorbis=1.3.7=h7b6447c_0
  - libwebp-base=1.3.2=h5eee18b_0
  - libxcb=1.15=h7f8727e_0
  - libxkbcommon=1.5.0=h5d7e998_3
  - libxml2=2.11.5=h232c23b_1
  - libzlib=1.2.13=hd590300_5
  - lz4-c=1.9.4=hcb278e6_0
  - lzo=2.10=h516909a_1000
  - magma=2.7.1=h2c23e93_0
  - mamba=1.2.0=py310h51d5547_0
  - markdown-it-py=2.2.0=py310h06a4308_1
  - markupsafe=2.1.1=py310h7f8727e_0
  - mashumaro=3.6=py310h06a4308_0
  - matplotlib=3.7.2=py310h06a4308_0
  - matplotlib-base=3.7.2=py310h1128e8f_0
  - matplotlib-inline=0.1.6=py310h06a4308_0
  - mdurl=0.1.0=py310h06a4308_0
  - mkl=2023.1.0=h213fc3f_46343
  - mpc=1.1.0=h10f8cd9_1
  - mpfr=4.0.2=hb69a4c5_1
  - mpg123=1.31.3=hcb278e6_0
  - mpmath=1.3.0=py310h06a4308_0
  - msgpack-python=1.0.3=py310hd09550d_0
  - munkres=1.1.4=py_0
  - mysql-common=8.0.33=hf1915f5_4
  - mysql-libs=8.0.33=hca2cd23_4
  - ncurses=6.4=hcb278e6_0
  - networkx=3.1=py310h06a4308_0
  - ninja=1.10.2=h06a4308_5
  - ninja-base=1.10.2=hd09550d_5
  - nspr=4.35=h6a678d5_0
  - nss=3.92=h1d7d5a4_0
  - numpy=1.25.2=py310heeff2f4_0
  - numpy-base=1.25.2=py310h8a23956_0
  - openjpeg=2.5.0=h488ebb8_3
  - openssl=3.1.2=hd590300_0
  - opt_einsum=3.3.0=pyhd3eb1b0_1
  - optax=0.1.4=py310h06a4308_0
  - overrides=7.4.0=pyhd8ed1ab_0
  - packaging=23.1=pyhd8ed1ab_0
  - parso=0.8.3=pyhd3eb1b0_0
  - pathtools=0.1.2=pyhd3eb1b0_1
  - pcre2=10.40=hc3806b6_0
  - pexpect=4.8.0=pyhd3eb1b0_3
  - pickleshare=0.7.5=pyhd3eb1b0_1003
  - pillow=10.0.1=py310h29da1c1_0
  - pip=23.2.1=pyhd8ed1ab_0
  - pixman=0.40.0=h7f8727e_1
  - pluggy=1.3.0=pyhd8ed1ab_0
  - ply=3.11=py310h06a4308_0
  - pptree=3.1=pyhd8ed1ab_0
  - prompt-toolkit=3.0.36=py310h06a4308_0
  - protobuf=3.20.3=py310h6a678d5_0
  - psutil=5.9.0=py310h5eee18b_0
  - ptyprocess=0.7.0=pyhd3eb1b0_2
  - pulseaudio-client=16.1=hb77b528_5
  - pure_eval=0.2.2=pyhd3eb1b0_0
  - pybind11-abi=4=hd8ed1ab_3
  - pycosat=0.6.4=py310h5764c6d_1
  - pycparser=2.21=pyhd8ed1ab_0
  - pygments=2.15.1=py310h06a4308_1
  - pyopenssl=23.2.0=pyhd8ed1ab_1
  - pyparsing=3.0.9=py310h06a4308_0
  - pyqt=5.15.9=py310h04931ad_4
  - pyqt5-sip=12.12.2=py310hc6cd4ac_4
  - pysocks=1.7.1=pyha2e5f31_6
  - python=3.10.8=h4a9ceb5_0_cpython
  - python-dateutil=2.8.2=pyhd3eb1b0_0
  - python_abi=3.10=3_cp310
  - pytorch=2.0.1=gpu_cuda118py310h7799f5a_0
  - pyyaml=6.0=py310h5eee18b_1
  - qt-main=5.15.8=hc47bfe8_16
  - readline=8.2=h8228510_1
  - reproc=14.2.4=h0b41bf4_0
  - reproc-cpp=14.2.4=hcb278e6_0
  - requests=2.31.0=pyhd8ed1ab_0
  - rich=13.3.5=py310h06a4308_0
  - ruamel.yaml=0.17.32=py310h2372a71_0
  - ruamel.yaml.clib=0.2.7=py310h1fa729e_1
  - scipy=1.11.1=py310heeff2f4_0
  - sentry-sdk=1.9.0=py310h06a4308_0
  - setproctitle=1.2.2=py310h7f8727e_0
  - setuptools=68.2.2=pyhd8ed1ab_0
  - shtab=1.6.4=pyhd8ed1ab_1
  - sip=6.7.11=py310hc6cd4ac_0
  - six=1.16.0=pyhd3eb1b0_1
  - smmap=4.0.0=pyhd3eb1b0_0
  - stack_data=0.2.0=pyhd3eb1b0_0
  - sympy=1.11.1=py310h06a4308_0
  - sysroot_linux-64=2.12=he073ed8_16
  - tbb=2021.8.0=hdb19cb5_0
  - tk=8.6.12=h27826a3_0
  - toml=0.10.2=pyhd3eb1b0_0
  - tomli=2.0.1=py310h06a4308_0
  - toolz=0.12.0=pyhd8ed1ab_0
  - tornado=6.3.2=py310h5eee18b_0
  - tqdm=4.66.1=pyhd8ed1ab_0
  - traitlets=5.7.1=py310h06a4308_0
  - typing-extensions=4.7.1=py310h06a4308_0
  - typing_extensions=4.7.1=py310h06a4308_0
  - typing_utils=0.1.0=pyhd8ed1ab_0
  - tyro=0.5.7=pyhd8ed1ab_0
  - tzdata=2023c=h71feb2d_0
  - urllib3=2.0.4=pyhd8ed1ab_0
  - wandb=0.15.10=pyhd8ed1ab_0
  - wcwidth=0.2.5=pyhd3eb1b0_0
  - wheel=0.41.2=pyhd8ed1ab_0
  - xcb-util=0.4.0=hd590300_1
  - xcb-util-image=0.4.0=h8ee46fc_1
  - xcb-util-keysyms=0.4.0=h8ee46fc_1
  - xcb-util-renderutil=0.3.9=hd590300_1
  - xcb-util-wm=0.4.1=h8ee46fc_1
  - xkeyboard-config=2.39=hd590300_0
  - xorg-kbproto=1.0.7=h7f98852_1002
  - xorg-libice=1.1.1=hd590300_0
  - xorg-libsm=1.2.4=h7391055_0
  - xorg-libx11=1.8.6=h8ee46fc_0
  - xorg-libxau=1.0.11=hd590300_0
  - xorg-libxext=1.3.4=h0b41bf4_2
  - xorg-libxrender=0.9.11=hd590300_0
  - xorg-renderproto=0.11.1=h7f98852_1002
  - xorg-xextproto=7.3.0=h0b41bf4_1003
  - xorg-xf86vidmodeproto=2.3.1=h7f98852_1002
  - xorg-xproto=7.0.31=h27cfd23_1007
  - xz=5.2.6=h166bdaf_0
  - yaml=0.2.5=h7b6447c_0
  - yaml-cpp=0.7.0=h27087fc_2
  - zlib=1.2.13=hd590300_5
  - zstandard=0.19.0=py310h5764c6d_0
  - zstd=1.5.5=hfc55251_0
  - pip:
      - jax==0.4.18
      - jaxlib==0.4.18+cuda12.cudnn89
      - ml-dtypes==0.3.1
      - nvidia-cublas-cu12==12.2.5.6
      - nvidia-cuda-cupti-cu12==12.2.142
      - nvidia-cuda-nvcc-cu12==12.2.140
      - nvidia-cuda-nvrtc-cu12==12.2.140
      - nvidia-cuda-runtime-cu12==12.2.140
      - nvidia-cudnn-cu12==8.9.4.25
      - nvidia-cufft-cu12==11.0.8.103
      - nvidia-cusolver-cu12==11.5.2.141
      - nvidia-cusparse-cu12==12.1.2.141
      - nvidia-nccl-cu12==2.18.3
      - nvidia-nvjitlink-cu12==12.2.140
prefix: /conda

  • Python 3.10.8
  • Ubuntu 22.04
  • jax==0.4.18
  • jaxlib==0.4.18+cuda12.cudnn89
  • Driver Version: 525.125.06
  • CUDA Version: 12.0

ywsslr commented

Thank you all for your help. For some reason I can't try it right now, but I'll try it soon and reply.

OK people, this has been a one-day nightmare, but I finally got this to work on an H100 machine with CUDA 12.2, without sudo.

Then install PyTorch from source, as that post says, and voilà!

No promises, but informally we're going to try to keep at least one JAX release on a CUDA version that PyTorch also releases for. Right now, that's the CUDA 11.8 release of JAX.

It's not a guarantee, though; for some JAX and PyTorch versions there might be no overlapping CUDA version.

I hit a similar issue when installing pytorch and jax into the same conda environment: when torch is loaded first, jax.devices() lists only CPU devices.

A short summary of the diagnosis: torch is built against cuDNN 8.7, while jaxlib is built against cuDNN 8.8. This leads to an exception when jax._src.xla_bridge._check_cuda_versions() runs.

Here follows a reproducer:

mamba create -n test-pytorch-jax pytorch pytorch-cuda=11.8 jaxlib=*=*cuda118* jax -c pytorch -c nvidia --no-channel-priority -y
mamba activate test-pytorch-jax

(note: using strict channel priority would lead to a mamba solver problem).

Import torch before checking jax.devices:

>>> import torch
>>> import jax
>>> jax.devices()
CUDA backend failed to initialize: Found cuDNN version 8700, but JAX was built against version 8800, which is newer. The copy of cuDNN that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]

Import torch after checking jax.devices:

>>> import jax
>>> jax.devices()
[cuda(id=0), cuda(id=1)]
>>> import torch
>>> jax.__version__
'0.4.23'
>>> torch.__version__
'2.1.2'
>>> from torch._C import _cudnn
>>> _cudnn.getCompileVersion()
(8, 7, 0)
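Based on this reproducer, one workaround in this particular environment is to import JAX and force its backend to initialize before PyTorch is imported. The following is a minimal sketch, assuming the same environment as above. It relies on the import-order behavior observed here, not on any documented guarantee from either project, and it cannot help if torch was already imported earlier in the process. The function name `init_jax_before_torch` is hypothetical.

```python
import importlib.util


def init_jax_before_torch():
    """Hypothetical helper: initialize JAX's backends before importing torch.

    Returns (jax_devices, torch_cuda_available), or None if either
    package is missing from the current environment.
    """
    if importlib.util.find_spec("jax") is None or importlib.util.find_spec("torch") is None:
        return None

    import jax

    # jax.devices() triggers backend initialization, so JAX's CUDA version
    # checks run before torch has had a chance to load its own cuDNN.
    devices = jax.devices()

    import torch  # imported only after JAX's backend is up

    return devices, torch.cuda.is_available()


print(init_jax_before_torch())
```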

Notice that the result of jaxlib.cuda._versions.cudnn_get_version() depends on whether torch is imported before or after the first call to jaxlib.cuda._versions.cudnn_get_version():

>>> import jaxlib.cuda._versions
>>> jaxlib.cuda._versions.cudnn_get_version()
8902
>>> import torch
>>> jaxlib.cuda._versions.cudnn_get_version()
8902

vs

>>> import torch
>>> import jaxlib.cuda._versions
>>> jaxlib.cuda._versions.cudnn_get_version()
8700

That qualifies as an incompatible linkage issue. Since libcudnn is dynamically loaded, cudnnGetVersion ought to report the version of the library that is actually loaded, not the version a piece of software was built against. The behavior above suggests that torch was linked with libcudnn statically.
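To see which cuDNN the current process actually resolves, independent of either framework's Python wrappers, you can call cuDNN's C function `cudnnGetVersion()` (which takes no arguments and returns a `size_t`) through `ctypes`. This is a sketch that assumes a Linux system where the cuDNN 8 library is named `libcudnn.so.8`. If a library with that soname is already loaded in the process, the dynamic loader typically returns that copy. Calling this after `import torch` should therefore report the version torch brought in. The helper name `loaded_cudnn_version` is hypothetical.

```python
import ctypes
from typing import Optional


def loaded_cudnn_version(soname: str = "libcudnn.so.8") -> Optional[int]:
    """Hypothetical helper: return cudnnGetVersion() for the library `soname`.

    Returns None if the library cannot be loaded (e.g. no cuDNN installed).
    """
    try:
        lib = ctypes.CDLL(soname)
    except OSError:
        return None
    # C signature: size_t cudnnGetVersion(void)
    lib.cudnnGetVersion.restype = ctypes.c_size_t
    lib.cudnnGetVersion.argtypes = []
    return int(lib.cudnnGetVersion())


print(loaded_cudnn_version())
```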

A possible resolution: Note that cuDNN minor releases are backward compatible with applications built against the same or earlier minor release. Hence, as long as jaxlib and torch are built against libcudnn with the same major version (8), the jax version check ought to ignore cudnn minor versions. Here is a patch:

diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py
index 7977f6329..17c14bc5a 100644
--- a/jax/_src/xla_bridge.py
+++ b/jax/_src/xla_bridge.py
@@ -263,7 +263,7 @@ def _check_cuda_versions():
       cuda_versions.cudnn_build_version,
       # NVIDIA promise both backwards and forwards compatibility for cuDNN patch
       # versions: https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat
-      scale_for_comparison=100,
+      scale_for_comparison=1000,
   )
   _version_check("cuFFT", cuda_versions.cufft_get_version,
                  cuda_versions.cufft_build_version,
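You can check the effect of `scale_for_comparison` with plain integer arithmetic. cuDNN 8.x encodes its version as `major * 1000 + minor * 100 + patch`, so 8.7.0 is 8700 and 8.9.2 is 8902, matching the numbers printed above. Below is a simplified sketch of a check that compares versions after integer division by the scale. It is an illustration, not a copy of JAX's actual `_version_check`, and the function name `passes_version_check` is hypothetical.

```python
def passes_version_check(installed: int, built_against: int, scale_for_comparison: int) -> bool:
    """Hypothetical simplified check: the installed library must not be older
    than the build version once both are divided by scale_for_comparison."""
    return installed // scale_for_comparison >= built_against // scale_for_comparison


installed_cudnn = 8700  # cuDNN 8.7.0, as loaded via torch in the session above
jax_build_cudnn = 8800  # cuDNN 8.8.0, what jaxlib was built against

# Scale 100 drops only the patch digits: 87 < 88, so the check fails.
print(passes_version_check(installed_cudnn, jax_build_cudnn, 100))   # False
# Scale 1000 also drops the minor digit: 8 >= 8, so the check passes.
print(passes_version_check(installed_cudnn, jax_build_cudnn, 1000))  # True
```

With the patched scale, any cuDNN 8.x release would be accepted. That is exactly the relaxation the patch proposes, and it is only safe if the minor-version compatibility described above holds.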

The latest compatible version pair I could find was jax[cuda11-pip,cuda11_pip]==0.4.10 and torch==2.2.1+cu118. The main conflict with later JAX versions is cuDNN: JAX wants >8.8, but torch wants ==8.7.

One way to check this would be:

cat > requirements.in <<EOF
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
--extra-index-url=https://download.pytorch.org/whl

jax[cuda11_pip]
torch==2.2.1+cu118
EOF

pip-compile

# Check the contents of requirements.txt.
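After installing, you can see which NVIDIA runtime wheels ended up in the environment (for example the `nvidia-cudnn-cu12` wheel listed in the environment above) by enumerating installed distributions with `importlib.metadata`. The following is a minimal sketch. The `nvidia-` prefix filter is an assumption based on the package names in this thread, and the helper name `nvidia_wheels` is hypothetical.

```python
from importlib import metadata


def nvidia_wheels() -> dict:
    """Hypothetical helper: map installed 'nvidia-*' distributions to versions."""
    found = {}
    for dist in metadata.distributions():
        name = dist.metadata["Name"] or ""
        if name.lower().startswith("nvidia-"):
            found[name] = dist.version
    return dict(sorted(found.items()))


for name, version in nvidia_wheels().items():
    print(f"{name}=={version}")
```

If two wheels for the same library appear (e.g. both a `-cu11` and a `-cu12` cuDNN), that is a strong hint that JAX and PyTorch pulled in different CUDA stacks.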

A workaround that works better for us is to use CUDA 11 with JAX but CUDA 12 with Torch. So basically, listing jax[cuda11_pip] and torch in our requirements file works for us.

How did you get this to work? I'm using conda, but after installing pytorch-cuda=12.1 I get the following error from JAX:

E   RuntimeError: Unable to initialize backend 'cuda': Unable to load CUDA. Is it installed? (set JAX_PLATFORMS='' to automatically choose an available backend)

We did not have to do anything special. Just installed the two packages in a clean env, and both worked.

The only way I was able to solve the environment with both JAX and PyTorch on CUDA12 was to install some packages from the nvidia channel:

mamba create -n jaxTorch jaxlib pytorch cuda-nvcc -c conda-forge -c nvidia -c pytorch

>>> import torch
>>> import jax
>>> torch.cuda.is_available()
True
>>> jax.devices()
[cuda(id=0)]
>>> import jaxlib.cuda._versions
>>> jaxlib.cuda._versions.cudnn_get_version()
8902
>>> torch._C._cudnn.getCompileVersion()
(8, 9, 2)

conda list
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
_sysroot_linux-64_curr_repodata_hack 3                   h69a702a_14    conda-forge
binutils_impl_linux-64    2.40                 hf600244_0    conda-forge
binutils_linux-64         2.40                 hdade7a5_3    conda-forge
blas                      2.116                       mkl    conda-forge
blas-devel                3.9.0            16_linux64_mkl    conda-forge
bzip2                     1.0.8                hd590300_5    conda-forge
c-ares                    1.27.0               hd590300_0    conda-forge
ca-certificates           2024.2.2             hbcca054_0    conda-forge
cuda-cccl_linux-64        12.1.109             ha770c72_0    conda-forge
cuda-cudart               12.1.105             hd3aeb46_0    conda-forge
cuda-cudart-dev           12.1.105             hd3aeb46_0    conda-forge
cuda-cudart-dev_linux-64  12.1.105             h59595ed_0    conda-forge
cuda-cudart-static        12.1.105             hd3aeb46_0    conda-forge
cuda-cudart-static_linux-64 12.1.105             h59595ed_0    conda-forge
cuda-cudart_linux-64      12.1.105             h59595ed_0    conda-forge
cuda-cupti                12.1.105             h59595ed_0    conda-forge
cuda-driver-dev_linux-64  12.1.105             h59595ed_0    conda-forge
cuda-libraries            12.1.0                        0    nvidia
cuda-nvcc                 12.1.105             hcdd1206_1    conda-forge
cuda-nvcc-dev_linux-64    12.1.105             ha770c72_0    conda-forge
cuda-nvcc-impl            12.1.105             hd3aeb46_0    conda-forge
cuda-nvcc-tools           12.1.105             hd3aeb46_0    conda-forge
cuda-nvcc_linux-64        12.1.105             h8a487aa_1    conda-forge
cuda-nvrtc                12.1.105             hd3aeb46_0    conda-forge
cuda-nvtx                 12.1.105             h59595ed_0    conda-forge
cuda-opencl               12.1.105             h59595ed_0    conda-forge
cuda-runtime              12.1.0                        0    nvidia
cuda-version              12.1                 h1d6eff3_3    conda-forge
cudnn                     8.9.7.29             h092f7fd_3    conda-forge
filelock                  3.13.3             pyhd8ed1ab_0    conda-forge
gcc_impl_linux-64         12.3.0               he2b93b0_5    conda-forge
gcc_linux-64              12.3.0               h6477408_3    conda-forge
gxx_impl_linux-64         12.3.0               he2b93b0_5    conda-forge
gxx_linux-64              12.3.0               h4a1b8e8_3    conda-forge
icu                       73.2                 h59595ed_0    conda-forge
importlib-metadata        7.1.0              pyha770c72_0    conda-forge
importlib_metadata        7.1.0                hd8ed1ab_0    conda-forge
jax                       0.4.25             pyhd8ed1ab_0    conda-forge
jaxlib                    0.4.23          cuda120py312hc008a70_200    conda-forge
jinja2                    3.1.3              pyhd8ed1ab_0    conda-forge
kernel-headers_linux-64   3.10.0              h4a8ded7_14    conda-forge
ld_impl_linux-64          2.40                 h41732ed_0    conda-forge
libabseil                 20240116.1      cxx17_h59595ed_2    conda-forge
libblas                   3.9.0            16_linux64_mkl    conda-forge
libcblas                  3.9.0            16_linux64_mkl    conda-forge
libcublas                 12.1.0.26                     0    nvidia
libcufft                  11.0.2.4                      0    nvidia
libcufile                 1.6.1.9              hd3aeb46_0    conda-forge
libcurand                 10.3.2.106           hd3aeb46_0    conda-forge
libcusolver               11.4.4.55                     0    nvidia
libcusparse               12.0.2.55                     0    nvidia
libexpat                  2.6.2                h59595ed_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-devel_linux-64     12.3.0             h8bca6fd_105    conda-forge
libgcc-ng                 13.2.0               h807b86a_5    conda-forge
libgfortran-ng            13.2.0               h69a702a_5    conda-forge
libgfortran5              13.2.0               ha4646dd_5    conda-forge
libgomp                   13.2.0               h807b86a_5    conda-forge
libgrpc                   1.62.1               h15f2491_0    conda-forge
libhwloc                  2.9.3           default_h554bfaf_1009    conda-forge
libiconv                  1.17                 hd590300_2    conda-forge
liblapack                 3.9.0            16_linux64_mkl    conda-forge
liblapacke                3.9.0            16_linux64_mkl    conda-forge
libnpp                    12.0.2.50                     0    nvidia
libnsl                    2.0.1                hd590300_0    conda-forge
libnvjitlink              12.1.105             hd3aeb46_0    conda-forge
libnvjpeg                 12.1.1.14                     0    nvidia
libprotobuf               4.25.3               h08a7969_0    conda-forge
libre2-11                 2023.09.01           h5a48ba9_2    conda-forge
libsanitizer              12.3.0               h0f45ef3_5    conda-forge
libsqlite                 3.45.2               h2797004_0    conda-forge
libstdcxx-devel_linux-64  12.3.0             h8bca6fd_105    conda-forge
libstdcxx-ng              13.2.0               h7e041cc_5    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libxcrypt                 4.4.36               hd590300_1    conda-forge
libxml2                   2.12.6               h232c23b_1    conda-forge
libzlib                   1.2.13               hd590300_5    conda-forge
llvm-openmp               15.0.7               h0cdce71_0    conda-forge
markupsafe                2.1.5           py312h98912ed_0    conda-forge
mkl                       2022.1.0           h84fe81f_915    conda-forge
mkl-devel                 2022.1.0           ha770c72_916    conda-forge
mkl-include               2022.1.0           h84fe81f_915    conda-forge
ml_dtypes                 0.3.2           py312hfb8ada1_0    conda-forge
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
nccl                      2.20.5.1             h3a97aeb_0    conda-forge
ncurses                   6.4.20240210         h59595ed_0    conda-forge
networkx                  3.2.1              pyhd8ed1ab_0    conda-forge
numpy                     1.26.4          py312heda63a1_0    conda-forge
ocl-icd                   2.3.2                hd590300_1    conda-forge
openssl                   3.2.1                hd590300_1    conda-forge
opt-einsum                3.3.0                hd8ed1ab_2    conda-forge
opt_einsum                3.3.0              pyhc1e730c_2    conda-forge
pip                       24.0               pyhd8ed1ab_0    conda-forge
python                    3.12.2          hab00c5b_0_cpython    conda-forge
python_abi                3.12                    4_cp312    conda-forge
pytorch                   2.2.1           py3.12_cuda12.1_cudnn8.9.2_0    pytorch
pytorch-cuda              12.1                 ha16c6d3_5    pytorch
pytorch-mutex             1.0                        cuda    pytorch
pyyaml                    6.0.1           py312h98912ed_1    conda-forge
re2                       2023.09.01           h7f4b329_2    conda-forge
readline                  8.2                  h8228510_1    conda-forge
scipy                     1.12.0          py312heda63a1_2    conda-forge
setuptools                69.2.0             pyhd8ed1ab_0    conda-forge
sympy                     1.12               pyh04b8f61_3    conda-forge
sysroot_linux-64          2.17                h4a8ded7_14    conda-forge
tbb                       2021.11.0            h00ab1b0_1    conda-forge
tk                        8.6.13          noxft_h4845f30_101    conda-forge
typing_extensions         4.10.0             pyha770c72_0    conda-forge
tzdata                    2024a                h0c530f3_0    conda-forge
wheel                     0.43.0             pyhd8ed1ab_0    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
yaml                      0.2.5                h7f98852_2    conda-forge
zipp                      3.17.0             pyhd8ed1ab_0    conda-forge

fyi @traversaro

FYI, at the moment it is not possible to get both jax and pytorch with CUDA 12 using only conda-forge dependencies, for the following reason (I pinned several dependencies to get a clearer error):

traversaro@IITBMP014LW012:~$ mamba create -n jaxtorchcuda pytorch==2.1.2=*cuda* jaxlib==0.4.23=*cuda* jax cuda-version=12.* python==3.11.* cudatoolkit==12.*

Looking for: ['pytorch==2.1.2[build=*cuda*]', 'jaxlib==0.4.23[build=*cuda*]', 'jax', 'cuda-version=12', 'python=3.11', 'cudatoolkit=12']

conda-forge/linux-64                                        Using cache
conda-forge/noarch                                          Using cache
Could not solve for environment specs
The following packages are incompatible
├─ cuda-version 12**  is installable with the potential options
│  ├─ cuda-version [12.0|12.0.0] would require
│  │  └─ cudatoolkit 12.0|12.0.* , which can be installed;
│  ├─ cuda-version 12.1 would require
│  │  └─ cudatoolkit 12.1|12.1.* , which can be installed;
│  ├─ cuda-version 12.2 would require
│  │  └─ cudatoolkit 12.2|12.2.* , which can be installed;
│  ├─ cuda-version 12.3 would require
│  │  └─ cudatoolkit 12.3|12.3.* , which can be installed;
│  └─ cuda-version 12.4 would require
│     └─ cudatoolkit 12.4|12.4.* , which can be installed;
├─ cudatoolkit 12**  does not exist (perhaps a typo or a missing channel);
├─ jaxlib 0.4.23 *cuda* is installable with the potential options
│  ├─ jaxlib 0.4.23 would require
│  │  └─ cudatoolkit >=11.8,<12 , which conflicts with any installable versions previously reported;
│  ├─ jaxlib 0.4.23 would require
│  │  └─ libgrpc >=1.62.1,<1.63.0a0 , which requires
│  │     └─ libprotobuf >=4.25.3,<4.25.4.0a0 , which can be installed;
│  ├─ jaxlib 0.4.23 would require
│  │  ├─ cudatoolkit >=11.8,<12 , which conflicts with any installable versions previously reported;
│  │  └─ libgrpc >=1.59.3,<1.60.0a0 , which requires
│  │     └─ libprotobuf >=4.24.4,<4.24.5.0a0 , which conflicts with any installable versions previously reported;
│  └─ jaxlib 0.4.23 would require
│     └─ python_abi 3.12.* *_cp312, which requires
│        └─ python 3.12.* *_cpython, which can be installed;
├─ python 3.11**  is not installable because it conflicts with any installable versions previously reported;
└─ pytorch 2.1.2 *cuda* is installable with the potential options
   ├─ pytorch 2.1.2 would require
   │  ├─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
   │  └─ libtorch 2.1.2.*  with the potential options
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cpu_generic_*_0, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cpu_generic_*_1, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cpu_generic_*_3, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cpu_mkl_*_100, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cpu_mkl_*_101, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cpu_mkl_*_103, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cuda112_*_300, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  ├─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
   │     │  └─ pytorch 2.1.2 cuda112_*_301, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  ├─ cudatoolkit >=11.8,<12 , which conflicts with any installable versions previously reported;
   │     │  ├─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
   │     │  └─ pytorch 2.1.2 cuda118_*_301, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  ├─ cudatoolkit >=11.8,<12 , which conflicts with any installable versions previously reported;
   │     │  └─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
   │     ├─ libtorch 2.1.2 would require
   │     │  └─ pytorch 2.1.2 cuda118_*_300, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  ├─ cuda-version >=12.0,<13 , which can be installed (as previously explained);
   │     │  ├─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported;
   │     │  └─ pytorch 2.1.2 cuda120_*_301, which can be installed;
   │     ├─ libtorch 2.1.2 would require
   │     │  ├─ cuda-version >=12.0,<13 , which can be installed (as previously explained);
   │     │  └─ pytorch 2.1.2 cuda120_*_303, which can be installed;
   │     └─ libtorch 2.1.2 would require
   │        └─ pytorch 2.1.2 cuda120_*_300, which can be installed;
   ├─ pytorch 2.1.2 would require
   │  └─ libprotobuf >=4.24.4,<4.24.5.0a0 , which conflicts with any installable versions previously reported;
   ├─ pytorch 2.1.2 would require
   │  └─ python_abi 3.12.* *_cp312, which can be installed (as previously explained);
   └─ pytorch 2.1.2 would require
      ├─ cuda-version >=12.0,<13 , which can be installed (as previously explained);
      └─ libprotobuf >=4.25.1,<4.25.2.0a0 , which conflicts with any installable versions previously reported.

Once a conda-forge pytorch build is compiled against libprotobuf==4.25.3 (i.e. once conda-forge/pytorch-cpu-feedstock#228 is ready and merged; big thanks to the pytorch and jax conda-forge maintainers), it should be possible to install both jax and pytorch with CUDA 12 support using only conda-forge packages.
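
Many branches of the solver output above fail on the same narrow libprotobuf pins (for example `>=4.25.1,<4.25.2.0a0`), which is why a pytorch rebuild against libprotobuf 4.25.3 is needed before both packages can share one environment. A rough sketch of that range check in Python; `parse_version` and `satisfies` are hypothetical helpers written for illustration, not conda's actual version matcher, and they simply drop pre-release tags such as `a0`:

```python
import operator
import re

# Comparison operators that appear in the solver output above.
_OPS = {
    ">=": operator.ge,
    "<=": operator.le,
    ">": operator.gt,
    "<": operator.lt,
    "==": operator.eq,
}


def parse_version(version: str) -> tuple[int, ...]:
    """Turn '4.25.2.0a0' into (4, 25, 2, 0).

    Simplification (assumption): only the leading digits of each dot-separated
    part are kept, so pre-release tags like 'a0' are dropped. Real conda
    version ordering is richer than this.
    """
    parts = []
    for part in version.split("."):
        digits = re.match(r"\d*", part).group()
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def _padded(a: tuple[int, ...], b: tuple[int, ...]):
    # Pad with zeros so that 4.25.2 and 4.25.2.0 compare as equal.
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)), b + (0,) * (width - len(b))


def satisfies(version: str, spec: str) -> bool:
    """Check a version against a comma-separated spec like '>=4.25.1,<4.25.2.0a0'."""
    for clause in spec.split(","):
        clause = clause.strip()
        match = re.fullmatch(r"(>=|<=|>|<|==)?\s*(\S+)", clause)
        if match is None:
            raise ValueError(f"unsupported clause: {clause!r}")
        op = _OPS[match.group(1) or "=="]
        left, right = _padded(parse_version(version), parse_version(match.group(2)))
        if not op(left, right):
            return False
    return True


if __name__ == "__main__":
    pin = ">=4.25.1,<4.25.2.0a0"  # pin copied from the libtorch 2.1.2 lines above
    for candidate in ["4.25.1", "4.25.3"]:
        print(candidate, satisfies(candidate, pin))
```

With this simplification, 4.25.1 satisfies the pin while 4.25.2 and 4.25.3 do not, mirroring the "conflicts with any installable versions previously reported" lines in the solver output.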

JAX 0.4.26 relaxed our CUDA version dependencies, so the minimum CUDA version for JAX is now 12.1. This version is also supported by PyTorch. Try it out! We're going to try to make sure our supported version range overlaps with at least one PyTorch release.

Note that we dropped support for CUDA 11.

After a number of fixes from both the jax and pytorch maintainers, it is now (late May 2024) possible to install jax and pytorch from conda-forge on Linux and have them work out of the box with GPU/CUDA support, without using any other conda channel:

$ conda create -c conda-forge -n jaxpytorch pytorch jax
$ conda activate jaxpytorch
$ python
Python 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:38:13) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import jax
>>> torch.cuda.is_available()
True
>>> jax.devices()
[cuda(id=0)]
>>>

If for some reason this command does not install the CUDA-enabled jax, you may still be using the classic conda solver. In that case you can force the installation of CUDA-enabled jax and pytorch with:

conda create -c conda-forge -n jaxpytorch "pytorch=*=cuda*" jax "jaxlib=*=cuda*"

However, this is not necessary if you are using a recent conda installation that defaults to the conda-libmamba-solver, see https://www.anaconda.com/blog/a-faster-conda-for-a-growing-community .
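
A script version of the interactive check above can be handier when comparing several environments. This is a sketch: `check_gpu_backends` is a hypothetical helper name, and it only relies on public attributes (`torch.version.cuda`, `torch.cuda.is_available()`, `jax.default_backend()`, `jax.devices()`). Either package may be missing; it is then reported as not installed.

```python
"""Report whether PyTorch and JAX each see a GPU from one interpreter.

A sketch of the interactive check above; either package may be missing.
"""
import importlib.util


def check_gpu_backends() -> dict:
    """Return {'torch': {...}, 'jax': {...}} with basic CUDA/GPU information."""
    report = {}

    if importlib.util.find_spec("torch") is None:
        report["torch"] = {"installed": False}
    else:
        import torch

        report["torch"] = {
            "installed": True,
            "version": torch.__version__,
            # CUDA version this torch build was compiled with (None for CPU-only builds).
            "built_with_cuda": torch.version.cuda,
            "gpu_available": torch.cuda.is_available(),
        }

    if importlib.util.find_spec("jax") is None:
        report["jax"] = {"installed": False}
    else:
        import jax

        report["jax"] = {
            "installed": True,
            "version": jax.__version__,
            # "gpu" if the CUDA backend initialised; "cpu" if JAX fell back,
            # e.g. after a "CUDA backend failed to initialize" warning.
            "default_backend": jax.default_backend(),
            "devices": [str(d) for d in jax.devices()],
        }

    return report


if __name__ == "__main__":
    for name, info in check_gpu_backends().items():
        print(name, info)
```

In a working environment like the one above, one would expect `gpu_available` to be True for torch and `default_backend` to be "gpu" for jax.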

conda list for reference
(jaxpytorch) traversaro@IITBMP014LW012:~$ conda list
# packages in environment at /home/traversaro/miniforge3/envs/jaxpytorch:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
_sysroot_linux-64_curr_repodata_hack 3                   h69a702a_14    conda-forge
binutils_impl_linux-64    2.40                 ha1999f0_1    conda-forge
binutils_linux-64         2.40                 hdade7a5_3    conda-forge
bzip2                     1.0.8                hd590300_5    conda-forge
c-ares                    1.28.1               hd590300_0    conda-forge
ca-certificates           2024.2.2             hbcca054_0    conda-forge
cuda-cccl_linux-64        12.5.39              ha770c72_0    conda-forge
cuda-crt-dev_linux-64     12.5.40              ha770c72_0    conda-forge
cuda-crt-tools            12.5.40              ha770c72_0    conda-forge
cuda-cudart               12.5.39              he02047a_0    conda-forge
cuda-cudart-dev           12.5.39              he02047a_0    conda-forge
cuda-cudart-dev_linux-64  12.5.39              h85509e4_0    conda-forge
cuda-cudart-static        12.5.39              he02047a_0    conda-forge
cuda-cudart-static_linux-64 12.5.39              h85509e4_0    conda-forge
cuda-cudart_linux-64      12.5.39              h85509e4_0    conda-forge
cuda-cupti                12.5.39              he02047a_0    conda-forge
cuda-driver-dev_linux-64  12.5.39              h85509e4_0    conda-forge
cuda-nvcc                 12.5.40              hcdd1206_0    conda-forge
cuda-nvcc-dev_linux-64    12.5.40              ha770c72_0    conda-forge
cuda-nvcc-impl            12.5.40              hd3aeb46_0    conda-forge
cuda-nvcc-tools           12.5.40              hd3aeb46_0    conda-forge
cuda-nvcc_linux-64        12.5.40              h8a487aa_0    conda-forge
cuda-nvrtc                12.5.40              he02047a_0    conda-forge
cuda-nvtx                 12.5.39              he02047a_0    conda-forge
cuda-nvvm-dev_linux-64    12.5.40              ha770c72_0    conda-forge
cuda-nvvm-impl            12.5.40              h59595ed_0    conda-forge
cuda-nvvm-tools           12.5.40              h59595ed_0    conda-forge
cuda-version              12.5                 hd4f0392_3    conda-forge
cudnn                     8.9.7.29             h092f7fd_3    conda-forge
filelock                  3.14.0             pyhd8ed1ab_0    conda-forge
fsspec                    2024.5.0           pyhff2d567_0    conda-forge
gcc_impl_linux-64         12.3.0               h58ffeeb_7    conda-forge
gcc_linux-64              12.3.0               h6477408_3    conda-forge
gmp                       6.3.0                h59595ed_1    conda-forge
gmpy2                     2.1.5           py312h1d5cde6_1    conda-forge
gxx_impl_linux-64         12.3.0               h2a574ab_7    conda-forge
gxx_linux-64              12.3.0               h4a1b8e8_3    conda-forge
icu                       73.2                 h59595ed_0    conda-forge
importlib-metadata        7.1.0              pyha770c72_0    conda-forge
importlib_metadata        7.1.0                hd8ed1ab_0    conda-forge
jax                       0.4.27             pyhd8ed1ab_0    conda-forge
jaxlib                    0.4.23          cuda120py312h6027bbc_202    conda-forge
jinja2                    3.1.4              pyhd8ed1ab_0    conda-forge
kernel-headers_linux-64   3.10.0              h4a8ded7_14    conda-forge
ld_impl_linux-64          2.40                 hf3520f5_1    conda-forge
libabseil                 20240116.2      cxx17_h59595ed_0    conda-forge
libblas                   3.9.0           22_linux64_openblas    conda-forge
libcblas                  3.9.0           22_linux64_openblas    conda-forge
libcublas                 12.5.2.13            he02047a_0    conda-forge
libcufft                  11.2.3.18            he02047a_0    conda-forge
libcurand                 10.3.6.39            he02047a_0    conda-forge
libcusolver               11.6.2.40            he02047a_0    conda-forge
libcusparse               12.4.1.24            he02047a_0    conda-forge
libexpat                  2.6.2                h59595ed_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-devel_linux-64     12.3.0             h0223996_107    conda-forge
libgcc-ng                 13.2.0               h77fa898_7    conda-forge
libgfortran-ng            13.2.0               h69a702a_7    conda-forge
libgfortran5              13.2.0               hca663fb_7    conda-forge
libgomp                   13.2.0               h77fa898_7    conda-forge
libgrpc                   1.62.2               h15f2491_0    conda-forge
libhwloc                  2.10.0          default_h5622ce7_1001    conda-forge
libiconv                  1.17                 hd590300_2    conda-forge
liblapack                 3.9.0           22_linux64_openblas    conda-forge
libmagma                  2.7.2                h173bb3b_2    conda-forge
libmagma_sparse           2.7.2                h173bb3b_3    conda-forge
libnsl                    2.0.1                hd590300_0    conda-forge
libnvjitlink              12.5.40              he02047a_0    conda-forge
libopenblas               0.3.27          pthreads_h413a1c8_0    conda-forge
libprotobuf               4.25.3               h08a7969_0    conda-forge
libre2-11                 2023.09.01           h5a48ba9_2    conda-forge
libsanitizer              12.3.0               hb8811af_7    conda-forge
libsqlite                 3.45.3               h2797004_0    conda-forge
libstdcxx-devel_linux-64  12.3.0             h0223996_107    conda-forge
libstdcxx-ng              13.2.0               hc0a3c3a_7    conda-forge
libtorch                  2.3.0           cuda120_h2b0da52_301    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libuv                     1.48.0               hd590300_0    conda-forge
libxcrypt                 4.4.36               hd590300_1    conda-forge
libxml2                   2.12.7               hc051c1a_0    conda-forge
libzlib                   1.2.13               hd590300_5    conda-forge
llvm-openmp               18.1.6               ha31de31_0    conda-forge
markupsafe                2.1.5           py312h98912ed_0    conda-forge
mkl                       2023.2.0         h84fe81f_50496    conda-forge
ml_dtypes                 0.4.0           py312h1d6d2e6_1    conda-forge
mpc                       1.3.1                hfe3b2da_0    conda-forge
mpfr                      4.2.1                h9458935_1    conda-forge
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
nccl                      2.21.5.1             h3a97aeb_0    conda-forge
ncurses                   6.5                  h59595ed_0    conda-forge
networkx                  3.3                pyhd8ed1ab_1    conda-forge
numpy                     1.26.4          py312heda63a1_0    conda-forge
openssl                   3.3.0                h4ab18f5_3    conda-forge
opt-einsum                3.3.0                hd8ed1ab_2    conda-forge
opt_einsum                3.3.0              pyhc1e730c_2    conda-forge
pip                       24.0               pyhd8ed1ab_0    conda-forge
python                    3.12.3          hab00c5b_0_cpython    conda-forge
python_abi                3.12                    4_cp312    conda-forge
pytorch                   2.3.0           cuda120_py312h26b3cf7_301    conda-forge
re2                       2023.09.01           h7f4b329_2    conda-forge
readline                  8.2                  h8228510_1    conda-forge
scipy                     1.13.1          py312hc2bc53b_0    conda-forge
setuptools                70.0.0             pyhd8ed1ab_0    conda-forge
sleef                     3.5.1                h9b69904_2    conda-forge
sympy                     1.12            pypyh9d50eac_103    conda-forge
sysroot_linux-64          2.17                h4a8ded7_14    conda-forge
tbb                       2021.12.0            h297d8ca_1    conda-forge
tk                        8.6.13          noxft_h4845f30_101    conda-forge
typing_extensions         4.11.0             pyha770c72_0    conda-forge
tzdata                    2024a                h0c530f3_0    conda-forge
wheel                     0.43.0             pyhd8ed1ab_1    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
zipp                      3.17.0             pyhd8ed1ab_0    conda-forge
zstd                      1.5.6                ha6fb4c9_0    conda-forge

Can someone please point out the correct versions needed to get pytorch and jax both with GPU support on CUDA 12 as of July 2024?
I would prefer a standard venv rather than a conda env, but either is fine.

@varadVaidya I only follow this issue by chance, but in general you may have more success finding help through the official jax help channels (see https://jax.readthedocs.io/en/latest/beginner_guide.html#finding-help) rather than by posting in closed issues.

More on topic: I have no idea about pip/venv with CUDA, but for conda the procedure posted in #18032 (comment) works fine for me. (When I originally posted that message I forgot to add -c conda-forge, which is needed for it to also work on Anaconda or Miniconda installations that use the defaults channel instead of conda-forge; I have now fixed that to avoid confusion.)

@eliseoe @bebark @shaikalthaf4 By chance I just noticed that you added a 👎🏽 reaction to my previous comment; is there any reason for doing so? Just FYI, authors do not get (at least by default) notifications for post reactions.

@traversaro I found that running your command with conda will install:
jaxlib conda-forge/linux-64::jaxlib-0.4.27-cpu_py312h17e8b90_0
whereas with mamba the correct version is installed:
mamba create -c conda-forge -n jaxpytorch pytorch jax
jaxlib 0.4.27 cuda120py312h4008524_200 conda-forge/linux-64
Perhaps this is why you got 3 thumbs down
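
The difference between the two installs is visible in the jaxlib build string: `cpu_py312h17e8b90_0` versus `cuda120py312h4008524_200`. A sketch that classifies selected packages in `conda list` output by the prefix of their build string (the same idea as a `*=cuda*` build selector); it assumes the usual four-column `name version build channel` layout, and `classify_builds` is a hypothetical helper:

```python
def classify_builds(conda_list_text: str, names=("jaxlib", "pytorch", "libtorch")) -> dict:
    """Map selected package names to 'cuda', 'cpu' or 'other' from their build string.

    Assumes whitespace-separated `conda list` output (name, version, build,
    channel); comment lines starting with '#' are skipped.
    """
    result = {}
    for line in conda_list_text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        fields = line.split()
        if len(fields) < 3 or fields[0] not in names:
            continue
        build = fields[2]
        if build.startswith("cuda"):
            result[fields[0]] = "cuda"
        elif build.startswith("cpu"):
            result[fields[0]] = "cpu"
        else:
            result[fields[0]] = "other"
    return result


if __name__ == "__main__":
    # The jaxlib line is reassembled from the jaxlib-0.4.27-cpu_py312h17e8b90_0
    # spec above; the pytorch line comes from an environment list in this thread.
    sample = """
# Name      Version   Build                        Channel
jaxlib      0.4.27    cpu_py312h17e8b90_0          conda-forge
pytorch     2.3.0     cuda120_py312h26b3cf7_301    conda-forge
"""
    print(classify_builds(sample))  # {'jaxlib': 'cpu', 'pytorch': 'cuda'}
```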

@peterch405

Interestingly, on my system with:

root@DESKTOP-T0NQNLN:~# conda info

     active environment : None
            shell level : 0
       user config file : /root/.condarc
 populated config files : /root/miniforge3/.condarc
                          /root/.condarc
          conda version : 24.3.0
    conda-build version : not installed
         python version : 3.10.14.final.0
                 solver : libmamba (default)
       virtual packages : __archspec=1=skylake
                          __conda=24.3.0=0
                          __cuda=12.0=0
                          __glibc=2.39=0
                          __linux=5.15.153.1=0
                          __unix=0=0
       base environment : /root/miniforge3  (writable)
      conda av data dir : /root/miniforge3/etc/conda
  conda av metadata url : None
           channel URLs : https://conda.anaconda.org/conda-forge/linux-64
                          https://conda.anaconda.org/conda-forge/noarch
          package cache : /root/miniforge3/pkgs
                          /root/.conda/pkgs
       envs directories : /root/miniforge3/envs
                          /root/.conda/envs
               platform : linux-64
             user-agent : conda/24.3.0 requests/2.31.0 CPython/3.10.14 Linux/5.15.153.1-microsoft-standard-WSL2 ubuntu/24.04 glibc/2.39 solver/libmamba conda-libmamba-solver/24.1.0 libmambapy/1.5.8
                UID:GID : 0:0
             netrc file : None
           offline mode : False

the command

conda create -c conda-forge -n jaxpytorch pytorch jax

installs the cuda jax, but indeed:

conda create --solver=classic -c conda-forge -n jaxpytorch pytorch jax

installs the CPU jax. Perhaps you are using an old conda version that uses the classic solver by default? (You can check this in the output of conda info, see https://www.anaconda.com/blog/a-faster-conda-for-a-growing-community.)

However, even with the classic solver, forcing it to install the CUDA builds of jaxlib and pytorch works as expected (although the classic solver is much slower):

conda create --solver=classic -c conda-forge -n jaxpytorch "pytorch=*=cuda*" jax "jaxlib=*=cuda*"

I edited the original comment accordingly.

You are right, I'm using the classic solver:

     active environment : base
    active env location : /home/chovanec/miniconda3
            shell level : 1
       user config file : /home/chovanec/.condarc
 populated config files : /home/chovanec/.condarc
          conda version : 23.1.0
    conda-build version : not installed
         python version : 3.10.8.final.0
       virtual packages : __archspec=1=x86_64
                          __cuda=12.3=0
                          __glibc=2.31=0
                          __linux=5.15.153.1=0
                          __unix=0=0
       base environment : /home/chovanec/miniconda3  (writable)
      conda av data dir : /home/chovanec/miniconda3/etc/conda
  conda av metadata url : None
           channel URLs : https://conda.anaconda.org/bioconda/linux-64
                          https://conda.anaconda.org/bioconda/noarch
                          https://conda.anaconda.org/conda-forge/linux-64
                          https://conda.anaconda.org/conda-forge/noarch
                          https://repo.anaconda.com/pkgs/main/linux-64
                          https://repo.anaconda.com/pkgs/main/noarch
                          https://repo.anaconda.com/pkgs/r/linux-64
                          https://repo.anaconda.com/pkgs/r/noarch
          package cache : /home/chovanec/miniconda3/pkgs
                          /home/chovanec/.conda/pkgs
       envs directories : /home/chovanec/miniconda3/envs
                          /home/chovanec/.conda/envs
               platform : linux-64
             user-agent : conda/23.1.0 requests/2.28.1 CPython/3.10.8 Linux/5.15.153.1-microsoft-standard-WSL2 ubuntu/20.04.6 glibc/2.31
                UID:GID : 1000:1000
             netrc file : None
           offline mode : False

The only way I was able to solve the environment with both JAX and PyTorch on CUDA 12 was to install some packages from the nvidia channel:

mamba create -n jaxTorch jaxlib pytorch cuda-nvcc -c conda-forge -c nvidia -c pytorch
>>> import torch
>>> import jax
>>> torch.cuda.is_available()
True
>>> jax.devices()
[cuda(id=0)]
>>> import jaxlib.cuda._versions
>>> jaxlib.cuda._versions.cudnn_get_version()
8902
>>> torch._C._cudnn.getCompileVersion()
(8, 9, 2)

conda list
fyi @traversaro
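
The session above reports the two cuDNN versions in different shapes: jaxlib gives a single integer (8902) while PyTorch gives a tuple (8, 9, 2). A small sketch for comparing them, assuming the cuDNN 8.x-style encoding `major*1000 + minor*100 + patch` (consistent with 8902 and (8, 9, 2) here); the helper names are hypothetical, and the two private calls used in the session (`jaxlib.cuda._versions.cudnn_get_version()`, `torch._C._cudnn.getCompileVersion()`) may change between releases. Note that PyTorch's call reports the version it was compiled against.

```python
def decode_cudnn_version(code: int) -> tuple[int, int, int]:
    """Split an integer like 8902 into (major, minor, patch).

    Assumes the major*1000 + minor*100 + patch encoding, which matches
    8902 <-> (8, 9, 2) in the session above. Other cuDNN releases may encode
    versions differently, so treat this as a sketch.
    """
    major, rest = divmod(code, 1000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch


def same_cudnn(jaxlib_code: int, torch_version: tuple[int, int, int]) -> bool:
    """Compare jaxlib's integer cuDNN version with PyTorch's (major, minor, patch)."""
    return decode_cudnn_version(jaxlib_code) == tuple(torch_version)


if __name__ == "__main__":
    # Values copied from the session above; in a live environment they would
    # come from the two private calls shown there.
    print(decode_cudnn_version(8902))   # (8, 9, 2)
    print(same_cudnn(8902, (8, 9, 2)))  # True
```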

Thanks for the solution. However, I have found a possible bug: with jax and jaxlib both at version 0.4.30, jax.numpy cannot initialize an array larger than (2, 52, 10), so I had to downgrade jax to 0.4.23, after which it works fine. To be safe, the command could be:

conda create -n _env_name_ jaxlib=0.4.23 pytorch cuda-nvcc python=3.11 -c conda-forge -c nvidia -c pytorch

Python 3.12 is too new for some commonly used packages.

Just a curiosity, are you actually getting any packages from the nvidia or pytorch channel? If conda-forge channel is used and you are using strict priority, all the packages you get should come from conda-forge, and so I guess you could drop the -c nvidia -c pytorch from your command. However, you can check this by calling conda list and checking from where packages are installed.
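
The check suggested above (which channel each package came from) can be scripted by reading the Channel column of `conda list`. A sketch, assuming the raw four-column output of `conda list` is passed in; the helper names are hypothetical. Packages installed with pip normally show up with the channel `pypi`, which also makes it easy to spot a jax or jaxlib that came from pip.

```python
from collections import Counter


def channels_by_package(conda_list_text: str) -> dict[str, str]:
    """Map each package name to the channel listed by `conda list`.

    Assumes the raw four-column output (name, version, build, channel);
    comment lines and lines with fewer than four columns are skipped.
    """
    channels = {}
    for line in conda_list_text.splitlines():
        fields = line.split()
        if len(fields) < 4 or fields[0].startswith("#"):
            continue
        channels[fields[0]] = fields[3]
    return channels


def channel_summary(conda_list_text: str) -> Counter:
    """Count how many packages came from each channel."""
    return Counter(channels_by_package(conda_list_text).values())


if __name__ == "__main__":
    # Hypothetical sample: the pip-installed jax line is made up for illustration.
    sample = """
# Name      Version   Build                        Channel
cudnn       8.9.7.29  h092f7fd_3                   conda-forge
pytorch     2.3.1     cuda120_py312h26b3cf7_300    conda-forge
jax         0.4.30    pypi_0                       pypi
"""
    print(channel_summary(sample))
    print([name for name, ch in channels_by_package(sample).items() if ch == "pypi"])
```

In practice the input would be the text printed by `conda list` inside the activated environment.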

I'm not sure; maybe I can run a test later. Thanks for pointing it out.

Sorry for the late reply, here is the output:
[screenshot of the output]
Since jax and the jax CUDA libraries were manually reinstalled from PyPI, I guess that, yes, the packages are preferentially installed from conda-forge :)

Not sure how you can end up with jax/jaxlib installed via PyPI if you just created the environment with conda create -n _env_name_ jaxlib=0.4.23 pytorch cuda-nvcc python=3.11 -c conda-forge -c nvidia -c pytorch, but as a general comment, if you are installing something with pip it is a good idea not to also install it via conda, to avoid conflicts.

In my case, the conflict comes from torch and jaxlib sticking to different cuDNN versions; previously I hadn't turned to conda-forge to install a cudatoolkit compatible with both torch and jaxlib. I used the pip command from the official jax documentation, by the way.

Ok, but in that case it is probably a good idea not to install jax and jaxlib from conda, and only install them from pip.

I think the only reason for including jax and jaxlib in the conda command is to make sure conda-forge can find and install a compatible cuDNN version. I haven't tested it, so to be safe I recommend (annoyingly) reinstalling jax and jaxlib from pip.

But conda has no idea which version of cuDNN the jaxlib installed via pip requires. If you want to install cuDNN (even a specific version) with conda, just install cudnn; to avoid problems, it is typically best not to install jax or jaxlib via conda if you are installing them via pip.

You are right: I accidentally used pip install, and it just happened to find a cuDNN version that meets the requirement, lol.

@traversaro apologies for reviving this issue, but your solution does not work for me:

$ conda create -c conda-forge -n jaxpytorch pytorch jax
$ conda activate jaxpytorch
$ python
Python 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:23:07) [GCC 12.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> import jax
>>> torch.cuda.is_available()
True
>>> jax.devices()
CUDA backend failed to initialize: Unable to use CUDA because of the following issues with CUDA components:
Outdated cuDNN installation found.
Version JAX was built against: 8907
Minimum supported: 9100
Installed version: 8907
The local installation version must be no lower than 9100..(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]

The error message suggests I need cudnn>=9.1.0; however, the most recent version on conda-forge appears to be 8.9.7.29.
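
The warning above packs three integers into one message. A sketch that extracts them and converts them to dotted versions, assuming the `major*1000 + minor*100 + patch` encoding (it maps 8907 to 8.9.7, matching the installed cudnn 8.9.7.29, and 9100 to 9.1.0, matching the reading of the message above); `parse_cudnn_warning` and `as_dotted` are hypothetical helpers that only rely on the wording quoted in this comment:

```python
import re

# Labels copied from the JAX warning quoted above.
_FIELDS = {
    "built_against": r"Version JAX was built against:\s*(\d+)",
    "minimum": r"Minimum supported:\s*(\d+)",
    "installed": r"Installed version:\s*(\d+)",
}


def parse_cudnn_warning(message: str) -> dict[str, int]:
    """Extract the cuDNN version integers that appear in the warning text."""
    found = {}
    for key, pattern in _FIELDS.items():
        match = re.search(pattern, message)
        if match:
            found[key] = int(match.group(1))
    return found


def as_dotted(code: int) -> str:
    """Render e.g. 8907 as '8.9.7' (assumes major*1000 + minor*100 + patch)."""
    major, rest = divmod(code, 1000)
    minor, patch = divmod(rest, 100)
    return f"{major}.{minor}.{patch}"


if __name__ == "__main__":
    warning = """Outdated cuDNN installation found.
Version JAX was built against: 8907
Minimum supported: 9100
Installed version: 8907"""
    info = parse_cudnn_warning(warning)
    ok = info["installed"] >= info["minimum"]
    print({key: as_dotted(value) for key, value in info.items()})
    print("ok" if ok else "installed cuDNN is older than the minimum supported")
```

Decoded this way, the message also shows an odd detail: the minimum supported version (9.1.0) is higher than the version JAX reports being built against (8.9.7).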

System info:

Operating System: Ubuntu 22.04.4 LTS
GPU: NVIDIA GeForce GTX 1060 6GB
Graphics Driver: NVIDIA driver metapackage from nvidia-driver-535

mamba list below the fold:

# packages in environment at /home/lucas/mambaforge/envs/jaxpytorch:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
_sysroot_linux-64_curr_repodata_hack 3                   h69a702a_16    conda-forge
binutils_impl_linux-64    2.40                 ha1999f0_7    conda-forge
binutils_linux-64         2.40                 hb3c18ed_0    conda-forge
bzip2                     1.0.8                h4bc722e_7    conda-forge
c-ares                    1.32.3               h4bc722e_0    conda-forge
ca-certificates           2024.7.4             hbcca054_0    conda-forge
cuda-cccl_linux-64        12.5.39              ha770c72_0    conda-forge
cuda-crt-dev_linux-64     12.5.82              ha770c72_0    conda-forge
cuda-crt-tools            12.5.82              ha770c72_0    conda-forge
cuda-cudart               12.5.82              he02047a_0    conda-forge
cuda-cudart-dev           12.5.82              he02047a_0    conda-forge
cuda-cudart-dev_linux-64  12.5.82              h85509e4_0    conda-forge
cuda-cudart-static        12.5.82              he02047a_0    conda-forge
cuda-cudart-static_linux-64 12.5.82              h85509e4_0    conda-forge
cuda-cudart_linux-64      12.5.82              h85509e4_0    conda-forge
cuda-cupti                12.5.82              he02047a_0    conda-forge
cuda-driver-dev_linux-64  12.5.82              h85509e4_0    conda-forge
cuda-nvcc                 12.5.82              hcdd1206_0    conda-forge
cuda-nvcc-dev_linux-64    12.5.82              ha770c72_0    conda-forge
cuda-nvcc-impl            12.5.82              hd3aeb46_0    conda-forge
cuda-nvcc-tools           12.5.82              hd3aeb46_0    conda-forge
cuda-nvcc_linux-64        12.5.82              h8a487aa_0    conda-forge
cuda-nvrtc                12.5.82              he02047a_0    conda-forge
cuda-nvtx                 12.5.82              he02047a_0    conda-forge
cuda-nvvm-dev_linux-64    12.5.82              ha770c72_0    conda-forge
cuda-nvvm-impl            12.5.82              h59595ed_0    conda-forge
cuda-nvvm-tools           12.5.82              h59595ed_0    conda-forge
cuda-version              12.5                 hd4f0392_3    conda-forge
cudnn                     8.9.7.29             h092f7fd_3    conda-forge
filelock                  3.15.4             pyhd8ed1ab_0    conda-forge
fsspec                    2024.6.1           pyhff2d567_0    conda-forge
gcc_impl_linux-64         13.3.0               hfea6d02_0    conda-forge
gcc_linux-64              13.3.0               hc28eda2_0    conda-forge
gmp                       6.3.0                hac33072_2    conda-forge
gmpy2                     2.1.5           py312h1d5cde6_1    conda-forge
gxx_impl_linux-64         13.3.0               hffce095_0    conda-forge
gxx_linux-64              13.3.0               h6834431_0    conda-forge
icu                       75.1                 he02047a_0    conda-forge
importlib-metadata        8.2.0              pyha770c72_0    conda-forge
importlib_metadata        8.2.0                hd8ed1ab_0    conda-forge
jax                       0.4.31             pyhd8ed1ab_0    conda-forge
jaxlib                    0.4.30          cuda120py312h4008524_200    conda-forge
jinja2                    3.1.4              pyhd8ed1ab_0    conda-forge
kernel-headers_linux-64   3.10.0              h4a8ded7_16    conda-forge
ld_impl_linux-64          2.40                 hf3520f5_7    conda-forge
libabseil                 20240116.2      cxx17_he02047a_1    conda-forge
libblas                   3.9.0           23_linux64_openblas    conda-forge
libcblas                  3.9.0           23_linux64_openblas    conda-forge
libcublas                 12.5.3.2             he02047a_0    conda-forge
libcufft                  11.2.3.61            he02047a_0    conda-forge
libcurand                 10.3.6.82            he02047a_0    conda-forge
libcusolver               11.6.3.83            he02047a_0    conda-forge
libcusparse               12.5.1.3             he02047a_0    conda-forge
libexpat                  2.6.2                h59595ed_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-devel_linux-64     13.3.0             h84ea5a7_100    conda-forge
libgcc-ng                 14.1.0               h77fa898_0    conda-forge
libgfortran-ng            14.1.0               h69a702a_0    conda-forge
libgfortran5              14.1.0               hc5f4f2c_0    conda-forge
libgomp                   14.1.0               h77fa898_0    conda-forge
libgrpc                   1.62.2               h15f2491_0    conda-forge
libhwloc                  2.11.1          default_hecaa2ac_1000    conda-forge
libiconv                  1.17                 hd590300_2    conda-forge
liblapack                 3.9.0           23_linux64_openblas    conda-forge
libmagma                  2.7.2                h173bb3b_2    conda-forge
libmagma_sparse           2.7.2                h173bb3b_3    conda-forge
libnsl                    2.0.1                hd590300_0    conda-forge
libnvjitlink              12.5.82              he02047a_0    conda-forge
libopenblas               0.3.27          pthreads_hac2b453_1    conda-forge
libprotobuf               4.25.3               h08a7969_0    conda-forge
libre2-11                 2023.09.01           h5a48ba9_2    conda-forge
libsanitizer              13.3.0               heb74ff8_0    conda-forge
libsqlite                 3.46.0               hde9e2c9_0    conda-forge
libstdcxx-devel_linux-64  13.3.0             h84ea5a7_100    conda-forge
libstdcxx-ng              14.1.0               hc0a3c3a_0    conda-forge
libtorch                  2.3.1           cuda120_h2b0da52_300    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libuv                     1.48.0               hd590300_0    conda-forge
libxcrypt                 4.4.36               hd590300_1    conda-forge
libxml2                   2.12.7               he7c6b58_4    conda-forge
libzlib                   1.3.1                h4ab18f5_1    conda-forge
llvm-openmp               18.1.8               hf5423f3_0    conda-forge
markupsafe                2.1.5           py312h98912ed_0    conda-forge
mkl                       2023.2.0         h84fe81f_50496    conda-forge
ml_dtypes                 0.4.0           py312h1d6d2e6_1    conda-forge
mpc                       1.3.1                hfe3b2da_0    conda-forge
mpfr                      4.2.1                h38ae2d0_2    conda-forge
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
nccl                      2.22.3.1             hbc370b7_1    conda-forge
ncurses                   6.5                  h59595ed_0    conda-forge
networkx                  3.3                pyhd8ed1ab_1    conda-forge
numpy                     2.0.1           py312h1103770_0    conda-forge
openssl                   3.3.1                h4bc722e_2    conda-forge
opt-einsum                3.3.0                hd8ed1ab_2    conda-forge
opt_einsum                3.3.0              pyhc1e730c_2    conda-forge
pip                       24.0               pyhd8ed1ab_0    conda-forge
python                    3.12.4          h194c7f8_0_cpython    conda-forge
python_abi                3.12                    4_cp312    conda-forge
pytorch                   2.3.1           cuda120_py312h26b3cf7_300    conda-forge
re2                       2023.09.01           h7f4b329_2    conda-forge
readline                  8.2                  h8228510_1    conda-forge
scipy                     1.14.0          py312hc2bc53b_1    conda-forge
setuptools                71.0.4             pyhd8ed1ab_0    conda-forge
sleef                     3.6.1                h3400bea_1    conda-forge
sympy                     1.13.0          pypyh2585a3b_103    conda-forge
sysroot_linux-64          2.17                h4a8ded7_16    conda-forge
tbb                       2021.12.0            h434a139_3    conda-forge
tk                        8.6.13          noxft_h4845f30_101    conda-forge
typing_extensions         4.12.2             pyha770c72_0    conda-forge
tzdata                    2024a                h0c530f3_0    conda-forge
wheel                     0.43.0             pyhd8ed1ab_1    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
zipp                      3.19.2             pyhd8ed1ab_0    conda-forge
zstd                      1.5.6                ha6fb4c9_0    conda-forge

UPDATE: jax=0.4.28 appears to work, so this looks like a bug introduced recently in JAX. cc @hawkinsp

@lucascolley thanks for reporting the issue, can you please open an issue in https://github.com/conda-forge/jaxlib-feedstock and tag me there? Thanks!

Thanks @lucascolley, indeed it seems a regression in the conda-forge jax package 0.4.31, I opened conda-forge/jaxlib-feedstock#277 to track the problem.