google-deepmind/ferminet

Jax error running on A100 GPU (everything is okay on CPU)

Closed this issue · 2 comments

Hi,

I get an error in train.py at line 229: new_params, state, stats = optimizer.step(......)

The error output is shown below:

2022-04-14 12:46:59.552761: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2141] Execution of replica 1 failed: INTERNAL: CustomCall failed: jaxlib/cusolver_kernels.cc:44: operation cusolverDnCreate(&handle) failed: cuSolver internal error
2022-04-14 12:47:09.554416: F external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2288] Replicated computation launch failed, but not all replicas terminated. Aborting process to work around deadlock. Failure message (there may have been multiple failures, see the error log for all failures):

CustomCall failed: jaxlib/cusolver_kernels.cc:44: operation cusolverDnCreate(&handle) failed: cuSolver internal error
Fatal Python error: Aborted

I don't get any error when running on CPU, but on GPU I always get this error.
Could you help me solve this problem? Thank you.

I am using JAX 0.3.5 with jaxlib 0.3.5 (the cuda11.cudnn82 build).

This looks like a JAX/cuDNN install or usage issue rather than something specific to our code. My only suggestion is to try lowering the fraction of GPU memory that JAX preallocates (e.g. to 0.8); see https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html.
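
As a minimal sketch of that suggestion: the JAX GPU memory allocation page linked above describes the XLA_PYTHON_CLIENT_MEM_FRACTION environment variable, which has to be set before JAX initializes the GPU backend. The helper name set_jax_gpu_memory_fraction below is hypothetical, written only for illustration, and whether 0.8 is low enough to avoid the cuSolver failure on a given machine is an assumption, not something verified here.

```python
import os


def set_jax_gpu_memory_fraction(fraction: float = 0.8) -> None:
    """Ask XLA to preallocate only `fraction` of each GPU's memory.

    Hypothetical helper for illustration. It only sets the documented
    XLA_PYTHON_CLIENT_MEM_FRACTION environment variable, so it must be
    called before JAX is imported and initializes its GPU backend.
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction}")
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(fraction)


# Set the limit first, then import JAX and run training as usual.
set_jax_gpu_memory_fraction(0.8)
# import jax  # import only after the variable has been set
```

The same effect can be had by exporting the variable in the shell before launching the training script. The same documentation page also describes XLA_PYTHON_CLIENT_PREALLOCATE=false, which disables preallocation entirely and may be worth trying if lowering the fraction does not help.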