CUDA12 plugin segfaults when older version of JAX is installed
Description
Since version 0.4.32 of the jax-cuda12-pjrt and jax-cuda12-plugin packages, installing an older version of jax/jaxlib alongside them causes a segfault as soon as an array is instantiated:
>>> import jax
>>> jax.numpy.array(0)
Segmentation fault (core dumped)
This happens even when a GPU is not used (e.g. JAX_PLATFORMS=cpu).
Previously (<=0.4.31), the plugin would raise an exception instead, warning of the version mismatch:
RuntimeError: Unable to initialize backend 'cuda': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.54) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
With those earlier versions, setting the platform to CPU also let you use JAX normally despite the mismatched CUDA plugin version.
You might ask why anyone would install mismatched versions of the jax and CUDA plugin packages. Because JAX doesn't manage the versions of these dependencies automatically, it can easily happen that JAX+CUDA is installed into an environment first and another package later requests a lower JAX version from pip. With only a segfault to go on, it is then very difficult for the user to understand what went wrong.
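For anyone hitting this, here is a rough sketch of how the mismatch could be spotted without importing jax at all, using only importlib.metadata from the standard library. The helper names (installed_versions, find_mismatches) are made up for illustration, and the check assumes that the plugin packages are expected to carry the same version string as jaxlib, which may be stricter than what JAX actually requires.

```python
from importlib import metadata

# Distribution names as they appear in this report.
CORE = ("jax", "jaxlib")
PLUGINS = ("jax-cuda12-pjrt", "jax-cuda12-plugin")


def installed_versions(names, lookup=metadata.version):
    """Return {name: version or None} without importing jax."""
    found = {}
    for name in names:
        try:
            found[name] = lookup(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def find_mismatches(versions):
    """List plugin packages whose version differs from jaxlib's.

    Assumption (not confirmed by JAX): plugin and jaxlib versions
    are expected to be identical strings.
    """
    base = versions.get("jaxlib")
    if base is None:
        return []
    return [
        (name, ver, base)
        for name, ver in versions.items()
        if name in PLUGINS and ver is not None and ver != base
    ]


if __name__ == "__main__":
    versions = installed_versions(CORE + PLUGINS)
    for name, ver in versions.items():
        print(f"{name}: {ver or 'not installed'}")
    for name, ver, base in find_mismatches(versions):
        print(f"WARNING: {name} {ver} does not match jaxlib {base}")
```

Because it never imports jax, this runs even in an environment where array creation segfaults, and with the versions listed under System info it would flag both plugin packages.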
System info (python version, jaxlib version, accelerator, etc.)
(import jax; jax.print_environment_info() crashes)
Linux
Python 3.10.12
jax-0.4.28
jaxlib-0.4.28
jax-cuda12-pjrt-0.4.33
jax-cuda12-plugin-0.4.33