`CUDA_ERROR_ILLEGAL_ADDRESS`
Description
Hello,
we have code (which we can share) that crashes with the error `CUDA_ERROR_ILLEGAL_ADDRESS` on jax:CUDA. It looks like we may have hit an XLA compilation bug.
The code is complex (and full of `.at[].set()`), and our attempts to produce an MWE have failed because small changes to the Python code make the error disappear.
We can share the script (and a requirements.txt file) to reproduce it, but maybe you just need the HLO code to reproduce it?
Let us know if we can provide something that can be useful to debug it.
cc @AmedSho97
E1120 10:59:55.894873 35984 pjrt_stream_executor_client.cc:3067] Execution of replica 0 failed: INTERNAL: Failed to retrieve branch_index value on stream 0x5583cca477e0: CUDA error: Could not synchronize CUDA stream: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered.
Traceback (most recent call last):
File "/mnt/beegfs/workdir/ahmedeo.shokry/Hubbard_4x4_BF_conv_U_3_6_trans_inv/test_trans_HFSD.py", line 146, in <module>
print(vstate.expect(H))
^^^^^^^^^^^^^^^^
File "/mnt/beegfs/workdir/ahmedeo.shokry/mambaforge/envs/netket_sharding_local/lib/python3.11/site-packages/netket/utils/timing.py", line 230, in timed_function
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/mnt/beegfs/workdir/ahmedeo.shokry/mambaforge/envs/netket_sharding_local/lib/python3.11/site-packages/netket/vqs/mc/mc_state/state.py", line 631, in expect
return expect(self, O, self.chunk_size)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/beegfs/workdir/ahmedeo.shokry/mambaforge/envs/netket_sharding_local/lib/python3.11/site-packages/plum/function.py", line 383, in __call__
return _convert(method(*args, **kw_args), return_type)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/beegfs/workdir/ahmedeo.shokry/mambaforge/envs/netket_sharding_local/lib/python3.11/site-packages/netket/vqs/mc/mc_state/expect.py", line 111, in expect
return _expect(
^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to retrieve branch_index value on stream 0x5583cca477e0: CUDA error: Could not synchronize CUDA stream: CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
2024-11-20 10:59:55.937334: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:1144] failed to unload module 0x5583eb358f00; leaking: INTERNAL: CUDA error: : CUDA_ERROR_ILLEGAL_ADDRESS: an illegal memory access was encountered
System info (python version, jaxlib version, accelerator, etc.)
Note: we were running on a GPU node with 4 A100 GPUs.
jax: 0.4.33
jaxlib: 0.4.33
numpy: 2.0.2
python: 3.11.10 | packaged by conda-forge | (main, Sep 22 2024, 14:10:38) [GCC 13.3.0]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='cholesky-login01', release='3.10.0-1062.el7.x86_64', version='#1 SMP Wed Aug 7 18:08:02 UTC 2019', machine='x86_64')
Well, we'd need to reproduce it somehow. It's possible the HLO will suffice to reproduce it, so we could certainly try that. Can you share an HLO dump? (Set `XLA_FLAGS=--xla_dump_to=/somewhere`, zip up `/somewhere`, and attach it to the issue.)
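For reference, the dump-and-zip steps could be scripted roughly as follows. This is only a sketch: the helper names `run_with_hlo_dump` and `zip_dump` are made up for illustration, and it assumes the failing script can be launched as a plain `python script.py` child process. The only XLA-specific piece is the `--xla_dump_to` flag mentioned above, which has to be in the environment before JAX initializes.

```python
import os
import shutil
import subprocess
import sys


def run_with_hlo_dump(script, dump_dir):
    """Run `script` in a child process with XLA HLO dumping enabled.

    Hypothetical helper: it only sets XLA_FLAGS for the child process.
    """
    env = dict(os.environ)
    # Append to any existing XLA_FLAGS instead of overwriting them.
    existing = env.get("XLA_FLAGS", "")
    env["XLA_FLAGS"] = f"{existing} --xla_dump_to={dump_dir}".strip()
    # check=False: the script is expected to crash, and the dump
    # directory is still wanted afterwards.
    return subprocess.run([sys.executable, str(script)], env=env, check=False)


def zip_dump(dump_dir, archive_base):
    """Zip the dump directory and return the path of the created archive.

    `archive_base` is the archive path without the ".zip" suffix and
    should live outside `dump_dir`.
    """
    return shutil.make_archive(str(archive_base), "zip", root_dir=str(dump_dir))
```

Calling `run_with_hlo_dump("test_trans_HFSD.py", "/somewhere")` followed by `zip_dump("/somewhere", "hlo_dump")` would then leave an `hlo_dump.zip` in the current directory that can be attached to the issue.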