jax-ml/jax

CUDA12 plugin segfaults when older version of JAX is installed


Description

Since version 0.4.32 of the jax-cuda12-pjrt and jax-cuda12-plugin packages, having an older version of jax/jaxlib installed alongside them causes a segfault as soon as an array is created:

>>> import jax
>>> jax.numpy.array(0)
Segmentation fault (core dumped)

This happens even when a GPU is not used (e.g. JAX_PLATFORMS=cpu).
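Because the crash kills the interpreter, it cannot be caught with try/except. A minimal sketch of detecting it from a parent process instead is shown here; the helper name `crashed_with_segfault` is hypothetical, and the return-code check assumes a POSIX system, where a child killed by a signal reports the negated signal number.

```python
import signal
import subprocess
import sys

# The two-line reproduction from the report, run in a child interpreter.
REPRO = "import jax; jax.numpy.array(0)"


def crashed_with_segfault(code=REPRO, env=None):
    """Run `code` in a fresh Python process and report whether it died of SIGSEGV.

    Hypothetical helper, POSIX only: subprocess reports a child killed by
    signal N with returncode == -N.
    """
    result = subprocess.run(
        [sys.executable, "-c", code],
        env=env,  # None inherits the parent environment
        capture_output=True,
    )
    return result.returncode == -signal.SIGSEGV
```

To check the CPU-only case, pass a copy of `os.environ` with `JAX_PLATFORMS` set to `cpu` as `env`.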

Previously (<=0.4.31), the plugin would raise an exception instead, warning of the version mismatch:

RuntimeError: Unable to initialize backend 'cuda': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.54) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

With those earlier versions, setting JAX_PLATFORMS=cpu also let you use JAX normally despite the mismatched CUDA plugin version.


You might ask why anyone would install mismatched versions of jax and the CUDA plugin packages. Because JAX does not pin the versions of these dependencies automatically, it can easily happen that JAX+CUDA is installed into an environment first and another package later makes pip downgrade jax. Since the only symptom is a segfault, it is then very difficult for the user to work out what went wrong.
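Until the plugin reports the mismatch itself, the installed versions can be compared without importing jax, which avoids triggering the crash. The following is a rough sketch using only the standard library; the helper names are hypothetical, and treating any difference from the jax version as a mismatch is a conservative assumption, not JAX's actual compatibility rule.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from this report.
PACKAGES = ("jax", "jaxlib", "jax-cuda12-pjrt", "jax-cuda12-plugin")


def installed_versions(packages=PACKAGES):
    """Return {name: version string or None} without importing jax."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None  # package is not installed
    return found


def find_mismatches(versions):
    """Return installed packages whose version differs from jax's.

    Assumption: exact equality is used as a conservative check; the real
    compatibility rules between jax, jaxlib and the plugins may be looser.
    """
    reference = versions.get("jax")
    if reference is None:
        return {}
    return {
        name: v
        for name, v in versions.items()
        if v is not None and v != reference
    }


if __name__ == "__main__":
    versions = installed_versions()
    for name, v in versions.items():
        print(f"{name}: {v or 'not installed'}")
    mismatched = find_mismatches(versions)
    if mismatched:
        print("Versions differing from jax:", mismatched)
```

Run against the environment listed under System info, this would flag both jax-cuda12-pjrt and jax-cuda12-plugin at 0.4.33 against jax 0.4.28.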

System info (python version, jaxlib version, accelerator, etc.)

(import jax; jax.print_environment_info() crashes)

Linux
Python 3.10.12

jax-0.4.28
jaxlib-0.4.28
jax-cuda12-pjrt-0.4.33
jax-cuda12-plugin-0.4.33