HeliXonProtein/OmegaFold

OOM issue when working with gradients

WANG-CR opened this issue · 1 comment

Description

Hello, I encountered an issue when trying to train your model.
I simply removed the Python decorator @torch.no_grad(), which causes two problems:

  1. The softmax function called with in_place=True is not differentiable. I fixed this by explicitly passing in_place=False (see the sketch after this list).
  2. Out of memory. I tried truncating the amino acid sequence, but it keeps running out of memory until only about 15 residues are left.
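
To make the first point concrete, here is a minimal PyTorch sketch (not OmegaFold's own softmax) of why an in-place write to an autograd intermediate breaks differentiation, which is what passing in_place=False avoids:

```python
import torch

x = torch.randn(4, requires_grad=True)
y = torch.softmax(x, dim=-1)
y.mul_(2)  # in-place edit of a tensor that softmax's backward pass still needs

try:
    y.sum().backward()
except RuntimeError as err:
    # PyTorch detects the in-place modification via its version counter
    print("autograd error:", err)
```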

For reference, I am using a V100 32 GB GPU. Could you share how you handled these problems during training? In particular, how many resources were used to train OmegaFold, and how much GPU memory is needed to fit the whole model?

To Reproduce

  1. Remove the Python decorator @torch.no_grad() in omegafold/__main__.py (see the illustration after this list).
  2. Execute python main.py INPUT_FILE.fasta OUTPUT_DIRECTORY
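
As a rough illustration of what removing the decorator changes (this is a generic toy model, not OmegaFold's actual entry point): under torch.no_grad() no computation graph is built, so intermediate activations are freed immediately; without it, every layer's activations are retained for backward, which is what drives memory usage up on long sequences.

```python
import torch

model = torch.nn.Linear(16, 16)
x = torch.randn(2, 16)

# As shipped: inference under no_grad, no graph or activations are retained
with torch.no_grad():
    out = model(x)
print(out.requires_grad)  # False

# With the decorator removed: activations are kept for the backward pass
out = model(x)
print(out.requires_grad)  # True
```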

Hi,

This code is specialized for inference only. As you can see, we have tried really hard to make it run under a moderate GPU memory budget, at least for inference. For training, we used a couple of hundred Nvidia A100 GPUs with 80 GB of memory each, so it is indeed costly, and even then we had to rely on gradient rematerialization.
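
For readers unfamiliar with the term, gradient rematerialization (activation checkpointing) recomputes a block's activations during the backward pass instead of storing them, trading extra compute for a much smaller memory footprint. A minimal, generic PyTorch sketch (not OmegaFold's actual training code; the Block module here is made up for illustration):

```python
import torch
from torch.utils import checkpoint


class Block(torch.nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim, dim),
            torch.nn.ReLU(),
            torch.nn.Linear(dim, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


block = Block(256)
x = torch.randn(8, 256, requires_grad=True)

# Activations inside the block are not stored; they are recomputed on backward.
# use_reentrant=False is the recommended mode in recent PyTorch versions.
y = checkpoint.checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```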