clami66/AF_unmasked

Out of memory issue

builab opened this issue · 3 comments

I tried to predict a 12-mer complex built from 3 proteins (about 4,500 residues in total). I have templates for almost everything except six 170-residue domains, and I limited all the xyz_max_hits options to 1.
I ran it on an RTX 4090 and got this error:

```
I0315 22:53:33.550827 140012347488064 run_alphafold.py:230] Running model model_1_multimer_v3_pred_1 on CpaFGH
I0315 22:53:33.551195 140012347488064 model.py:165] Running predict with shape(feat) = {'aatype': (4539,), 'residue_index': (4539,), 'seq_length': (), 'msa': (512, 4539), 'num_alignments': (), 'template_aatype': (4, 4539), 'template_all_atom_mask': (4, 4539, 37), 'template_all_atom_positions': (4, 4539, 37, 3), 'asym_id': (4539,), 'sym_id': (4539,), 'entity_id': (4539,), 'deletion_matrix': (512, 4539), 'deletion_mean': (4539,), 'all_atom_mask': (4539, 37), 'all_atom_positions': (4539, 37, 3), 'assembly_num_chains': (), 'entity_mask': (4539,), 'num_templates': (), 'cluster_bias_mask': (512,), 'bert_mask': (512, 4539), 'seq_mask': (4539,), 'msa_mask': (512, 4539)}
2024-03-15 22:54:07.499744: W external/org_tensorflow/tensorflow/tsl/framework/bfc_allocator.cc:290] Allocator (GPU_0_bfc) ran out of memory trying to allocate 9.84GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
Traceback (most recent call last):
  File "/storage/software/AF_unmasked/run_alphafold.py", line 504, in <module>
    app.run(main)
  File "/storage/software/anaconda3/envs/AF_unmasked/lib/python3.9/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/storage/software/anaconda3/envs/AF_unmasked/lib/python3.9/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/storage/software/AF_unmasked/run_alphafold.py", line 479, in main
    predict_structure(
  File "/storage/software/AF_unmasked/run_alphafold.py", line 238, in predict_structure
    prediction_result = model_runner.predict(processed_feature_dict,
  File "/storage/software/AF_unmasked/alphafold/model/model.py", line 167, in predict
    result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)
  File "/storage/software/anaconda3/envs/AF_unmasked/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/storage/software/anaconda3/envs/AF_unmasked/lib/python3.9/site-packages/jax/_src/api.py", line 622, in cache_miss
    execute = dispatch.xla_call_impl_lazy(fun, *tracers, **params)
  File "/storage/software/anaconda3/envs/AF_unmasked/lib/python3.9/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl_lazy
    return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
  File "/storage/software/anaconda3/envs/AF_unmasked/lib/python3.9/site-packages/jax/linear_util.py", line 303, in memoized_fun
    ans = call(fun, *args)
  File "/storage/software/anaconda3/envs/AF_unmasked/lib/python3.9/site-packages/jax/_src/dispatch.py", line 359, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/storage/software/anaconda3/envs/AF_unmasked/lib/python3.9/site-packages/jax/_src/dispatch.py", line 996, in compile
    self._executable = XlaCompiledComputation.from_xla_computation(
  File "/storage/software/anaconda3/envs/AF_unmasked/lib/python3.9/site-packages/jax/_src/dispatch.py", line 1194, in from_xla_computation
    compiled = compile_or_get_cached(backend, xla_computation, options,
  File "/storage/software/anaconda3/envs/AF_unmasked/lib/python3.9/site-packages/jax/_src/dispatch.py", line 1077, in compile_or_get_cached
    return backend_compile(backend, serialized_computation, compile_options,
  File "/storage/software/anaconda3/envs/AF_unmasked/lib/python3.9/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/storage/software/anaconda3/envs/AF_unmasked/lib/python3.9/site-packages/jax/_src/dispatch.py", line 1012, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 10565271552 bytes.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
```


The strange thing is that it fails to allocate only 10.6 GB. What is the size limit on a 4090 (a 24 GB card) when the MSA is limited to 1 sequence?
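The allocator warning (9.84GiB) and the RESOURCE_EXHAUSTED error (10565271552 bytes) report the same request in different units: the first uses binary gibibytes, the second raw bytes. A quick conversion, plain arithmetic with nothing AlphaFold-specific, shows they agree. Note that this figure is the size of one allocation request that did not fit in the memory still available, not the total footprint of the run.

```python
# Convert the failed allocation size reported in the log into GiB and GB.
FAILED_ALLOC_BYTES = 10_565_271_552  # from "Out of memory while trying to allocate ... bytes"

GIB = 1024 ** 3  # binary unit, as used by the BFC allocator warning ("9.84GiB")
GB = 1000 ** 3   # decimal unit, as in "10.6 GB"


def describe(num_bytes: int) -> str:
    """Return num_bytes expressed in both GiB and GB, rounded to 2 decimals."""
    return f"{num_bytes / GIB:.2f} GiB = {num_bytes / GB:.2f} GB"


if __name__ == "__main__":
    print(describe(FAILED_ALLOC_BYTES))  # prints "9.84 GiB = 10.57 GB"
```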

Just to add: changing the MSA limit from 1 to 100 to 200 doesn't change the error message; the amount of memory it fails to allocate stays the same. Why would that happen?

Hi again,

Even when the MSA contains only a few sequences, AlphaFold zero-pads it to a minimum depth of 512, so in your case the MSA is at least 512 x 4539 (this is the 'msa': (512, 4539) shape in your log). I have not tested whether the padding can be disabled, or whether disabling it would reduce the memory footprint. I will look into it, but I don't have the time right now.
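A minimal sketch of why the MSA limit doesn't change the error, assuming the padding works like the zero-padding described above. This is an illustrative reimplementation with a hypothetical helper name (pad_msa_rows), not AlphaFold's own code: any MSA with fewer than 512 rows is padded up to 512, so 1, 100 and 200 sequences all end up as the same (512, 4539) array.

```python
import numpy as np

MIN_MSA_DEPTH = 512  # minimum MSA depth after padding, per the reply above


def pad_msa_rows(msa: np.ndarray, min_depth: int = MIN_MSA_DEPTH) -> np.ndarray:
    """Zero-pad a (num_seqs, num_res) MSA with extra rows up to min_depth.

    Hypothetical helper for illustration only. MSAs that already have at
    least min_depth rows are returned with their shape unchanged.
    """
    num_seqs, _ = msa.shape
    missing = max(0, min_depth - num_seqs)
    # Pad only along the sequence axis (rows); the residue axis is untouched.
    return np.pad(msa, ((0, missing), (0, 0)), mode="constant", constant_values=0)


if __name__ == "__main__":
    num_res = 4539  # total length of the complex, taken from the log
    for num_seqs in (1, 100, 200):
        msa = np.ones((num_seqs, num_res), dtype=np.int32)
        print(num_seqs, pad_msa_rows(msa).shape)
```

Each of the three runs prints (512, 4539), so any memory that scales with MSA depth is identical for limits of 1, 100 or 200.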

For the time being, you can either use multiple GPUs at once (with unified memory enabled) so that you have access to more VRAM, or run on a larger GPU (I know that a target of this size works fine on an A100).
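For the unified-memory route, one common setup uses the two environment variables that AlphaFold's own docker/run_docker.py sets so that targets too long for GPU memory can still run: TF_FORCE_UNIFIED_MEMORY=1 and XLA_PYTHON_CLIENT_MEM_FRACTION=4.0, plus CUDA_VISIBLE_DEVICES to choose which GPUs are visible. The sketch below only builds that environment and launches the script. The GPU IDs, the helper names and the bare "run_alphafold.py" invocation are assumptions; pass your usual flags as extra_args. With unified memory, allocations beyond the card's capacity can spill into host RAM, which is typically slower.

```python
import os
import subprocess
import sys


def unified_memory_env(gpu_ids: list[int]) -> dict[str, str]:
    """Return a copy of the current environment with unified memory enabled.

    Hypothetical helper. The two memory variables mirror AlphaFold's
    docker/run_docker.py; gpu_ids is whatever devices you want to expose.
    """
    env = dict(os.environ)  # copy, so os.environ itself is not modified
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    env["TF_FORCE_UNIFIED_MEMORY"] = "1"
    # Fraction of GPU memory JAX preallocates; values above 1.0 only make
    # sense with unified memory, which can oversubscribe into host RAM.
    env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "4.0"
    return env


def run_prediction(extra_args: list[str], gpu_ids: list[int]) -> int:
    """Launch run_alphafold.py (assumed to be in the working directory)."""
    cmd = [sys.executable, "run_alphafold.py", *extra_args]
    return subprocess.run(cmd, env=unified_memory_env(gpu_ids), check=False).returncode
```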