openxla/xla

Support builds with cuDNN 9

joelberkeley opened this issue · 1 comment

At the moment, builds use the TensorFlow build Docker image. That image ships cuDNN 8, so there doesn't appear to be a way to build or run with cuDNN 9. Meanwhile, cuDNN 8 is now archived on the NVIDIA website.

You can use the containers from JAX-Toolbox, as they ship cuDNN 9:
https://github.com/NVIDIA/JAX-Toolbox

docker pull ghcr.io/nvidia/jax:jax

This includes the nightly JAX/XLA version.
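To confirm that an image actually ships cuDNN 9 rather than cuDNN 8, one option is to read the version macros from cuDNN's `cudnn_version.h` header. This is a minimal sketch, assuming the header is at `/usr/include/cudnn_version.h`. That path and the helper names `cudnn_version` and `is_cudnn9` are hypothetical, not part of XLA or JAX-Toolbox, and cuDNN may be installed elsewhere in a given image.

```shell
#!/usr/bin/env bash
# Hypothetical helper: print the cuDNN version (MAJOR.MINOR.PATCHLEVEL) declared in a
# cudnn_version.h header. The default path is an assumption; pass the real path if it differs.
cudnn_version() {
  local header="${1:-/usr/include/cudnn_version.h}"
  if [ ! -r "$header" ]; then
    echo "cudnn_version: cannot read $header" >&2
    return 1
  fi
  # Match only the exact '#define CUDNN_<FIELD> <value>' lines, not the derived CUDNN_VERSION.
  awk '$1 == "#define" && $2 == "CUDNN_MAJOR"      { major = $3 }
       $1 == "#define" && $2 == "CUDNN_MINOR"      { minor = $3 }
       $1 == "#define" && $2 == "CUDNN_PATCHLEVEL" { patch = $3 }
       END { print major "." minor "." patch }' "$header"
}

# Hypothetical helper: succeed only if the header declares cuDNN major version 9.
is_cudnn9() {
  local version
  version="$(cudnn_version "$@")" || return 1
  # Strip everything from the first dot to keep only the major version.
  [ "${version%%.*}" = "9" ]
}
```

For example, in a shell inside a container started from that image (say with `docker run -it ghcr.io/nvidia/jax:jax bash`), source the script and run `is_cudnn9 && echo "cuDNN 9"`, passing the header path as an argument if it lives somewhere other than the assumed default.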