jax-ml/jax

`CUDA_ERROR_ILLEGAL_ADDRESS`


Description

Hello,

we have some code (which we can share) that crashes with the error `CUDA_ERROR_ILLEGAL_ADDRESS` when running JAX on CUDA. It looks like we may have hit an XLA compilation bug.
The code is complex (and full of `.at[].set()` calls), and our attempts to produce a minimal working example (MWE) have failed, because small changes to the Python code make the error disappear.

We can share the script (and a `requirements.txt` file) to reproduce it, but would the HLO alone be enough for you to reproduce it?
Let us know if there is anything else we can provide that would help with debugging.
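If only the HLO of a specific jitted function is needed, JAX's ahead-of-time lowering API can print it directly. The sketch below is illustrative only: `update` is a hypothetical stand-in for the real code, `dialect="hlo"` returns the HLO before XLA's optimization passes, and `.compile().as_text()` returns the optimized HLO for whatever backend is active.

```python
import jax
import jax.numpy as jnp


def update(x, i, v):
    # Hypothetical stand-in using the .at[].set() pattern described above.
    return x.at[i].set(v)


x = jnp.zeros(8)
lowered = jax.jit(update).lower(x, 3, 1.0)

# HLO text of the lowered computation, before XLA optimizations.
hlo_text = lowered.as_text(dialect="hlo")

# Optimized HLO produced by compiling for the current backend.
compiled_text = lowered.compile().as_text()

print(hlo_text)
```

Note that this captures one function at a time; since the crash only shows up in the full program, a whole-program dump is likely more useful for reproducing it.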

cc @AmedSho97

E1120 10:59:55.894873   35984 pjrt_stream_executor_client.cc:3067] Execution of replica 0 failed: INTERNAL: Failed to retrieve branch_index value on stream 0x5583cca477e0: CUDA error: Could not synchronize CUDA stream: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered.
Traceback (most recent call last):
  File "/mnt/beegfs/workdir/ahmedeo.shokry/Hubbard_4x4_BF_conv_U_3_6_trans_inv/test_trans_HFSD.py", line 146, in <module>
    print(vstate.expect(H))
          ^^^^^^^^^^^^^^^^
  File "/mnt/beegfs/workdir/ahmedeo.shokry/mambaforge/envs/netket_sharding_local/lib/python3.11/site-packages/netket/utils/timing.py", line 230, in timed_function
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/mnt/beegfs/workdir/ahmedeo.shokry/mambaforge/envs/netket_sharding_local/lib/python3.11/site-packages/netket/vqs/mc/mc_state/state.py", line 631, in expect
    return expect(self, O, self.chunk_size)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/beegfs/workdir/ahmedeo.shokry/mambaforge/envs/netket_sharding_local/lib/python3.11/site-packages/plum/function.py", line 383, in __call__
    return _convert(method(*args, **kw_args), return_type)
                    ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/beegfs/workdir/ahmedeo.shokry/mambaforge/envs/netket_sharding_local/lib/python3.11/site-packages/netket/vqs/mc/mc_state/expect.py", line 111, in expect
    return _expect(
           ^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to retrieve branch_index value on stream 0x5583cca477e0: CUDA error: Could not synchronize CUDA stream: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
2024-11-20 10:59:55.937334: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:1144] failed to unload module 0x5583eb358f00; leaking: INTERNAL: CUDA error: : CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered

System info (python version, jaxlib version, accelerator, etc.)

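Note: we were running on a GPU node with 4 A100 GPUs. The system info below was printed on the login node (`cholesky-login01`), not on the GPU node, so `jax.devices` shows only a CPU device.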

jax: 0.4.33
jaxlib: 0.4.33
numpy: 2.0.2
python: 3.11.10 | packaged by conda-forge | (main, Sep 22 2024, 14:10:38) [GCC 13.3.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='cholesky-login01', release='3.10.0-1062.el7.x86_64', version='#1 SMP Wed Aug 7 18:08:02 UTC 2019', machine='x86_64')

Well, we'd need to reproduce it somehow. It's possible the HLO alone will suffice to reproduce it, so we can certainly try that. Can you share an HLO dump? (Set `XLA_FLAGS=--xla_dump_to=/somewhere`, then zip up `/somewhere` and attach it to the issue.)
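A minimal sketch of producing that dump from inside a Python script rather than the shell. The directory `/tmp/xla_dump`, the `step` function, and the archive name `xla_dump` are placeholders, not part of the original request. `XLA_FLAGS` is set before `jax` is imported so that XLA sees it when the backend initializes.

```python
import os
import shutil

# Placeholder dump location; any writable directory works.
DUMP_DIR = "/tmp/xla_dump"

# Set the flag before importing jax so XLA reads it at backend startup.
os.environ["XLA_FLAGS"] = f"--xla_dump_to={DUMP_DIR}"

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Hypothetical stand-in for the real computation (e.g. vstate.expect(H)).
    return x.at[0].set(jnp.sum(x))


# Running the jitted function triggers compilation, which writes the dump.
step(jnp.arange(4.0)).block_until_ready()

# Zip everything XLA wrote so it can be attached to the issue.
archive = shutil.make_archive("xla_dump", "zip", root_dir=DUMP_DIR)
print(archive)
```

In the real script, the call to `step` would be replaced by the code that crashes, so the dump contains the modules compiled for that run.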