THUNLP-MT/dyMEAN

Checkpoint-related questions

TangYiChing opened this issue · 9 comments

Hello,

I am very interested in your work and would like to explore it further for our research. Could you help me with two questions about the checkpoints?

  1. What is the difference between using struct_prediction.ckpt and the other checkpoints for the structure prediction task?
  2. I called torch.load() outside the dyMEAN folder and got this error: ModuleNotFoundError: No module named 'models'. What would you suggest for loading a checkpoint in a custom script in my own working directory instead of in dyMEAN?

Hi, thank you for your interest in our work! Here are my answers to your questions:

  1. struct_prediction.ckpt is used to predict the structure of a docked antibody-antigen complex given the epitope on the antigen and the sequences of the antibody's heavy and light chains. The other checkpoints are for design tasks, where the CDRs of the antibody are unknown. Therefore, only struct_prediction.ckpt should be used for the structure prediction task.
  2. This happens because I saved the entire model object directly with torch.save(), so loading it requires the code that defines the model to be importable. I suggest two ways to solve the problem:
    • Simple but temporary solution: set the environment variable PYTHONPATH=/path/to/dyMEAN for your script so that Python can find the repo's modules. For example, if you have cloned the repo to /data/dyMEAN, set PYTHONPATH=/data/dyMEAN. However, if there are naming conflicts with your working repo (e.g., your code also has a folder named utils, as dyMEAN does), this solution may cause problems.
    • Complicated but permanent solution: load the model inside the dyMEAN folder, then use torch.save(model.state_dict(), '/path/to/state_dict') to save only the model weights. Next, copy the models folder to your working repo and write code to load the weights, like:
        import torch
        from models import dyMEANModel
    
        model = dyMEANModel(
          embed_size=64,
          hidden_size=128,
          n_channel=14,
          num_classes=20,
          struct_only=True,
          bind_dist_cutoff=6.6
        )
        model.load_state_dict(torch.load('/path/to/state_dict'))
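To illustrate why the full-model checkpoint breaks outside the repo, here is a minimal stdlib sketch of the mechanism. torch.save() on a whole module pickles it, and pickle records each object's class by its module path (here `models.DemoModel`) rather than by its code, so unpickling must be able to import that module. The `models` module and `DemoModel` class below are hypothetical stand-ins created at runtime, not dyMEAN code.

```python
import pickle
import sys
import types

# Hypothetical stand-in for dyMEAN's `models` package, created at runtime.
models = types.ModuleType("models")
sys.modules["models"] = models


class DemoModel:
    """Toy model whose class is registered under the `models` module."""

    def __init__(self):
        self.weights = [0.1, 0.2, 0.3]


# Pretend the class was defined in `models`, as dyMEANModel is in dyMEAN.
DemoModel.__module__ = "models"
models.DemoModel = DemoModel

# Saving the whole object stores a reference to "models.DemoModel", not the class code.
blob = pickle.dumps(DemoModel())


def try_load(data):
    """Unpickle `data`, returning the object or the import error message."""
    try:
        return pickle.loads(data)
    except ModuleNotFoundError as err:
        return str(err)


loaded_inside_repo = try_load(blob)  # `models` is importable here

# Simulate running from a directory where `models` cannot be imported.
del sys.modules["models"]
loaded_outside_repo = try_load(blob)

print(type(loaded_inside_repo).__name__)  # DemoModel
print(loaded_outside_repo)  # No module named 'models'
```

The first solution makes `models` importable again; the second sidesteps the problem, since a state_dict file stores parameter tensors keyed by name rather than a reference to the model class.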


Do these parameters in dyMEANModel apply to all checkpoints?

Mostly, yes. The exact configurations for each checkpoint are listed in the README, so you can check these parameters there before loading a particular checkpoint.

Thank you for helping me resolve the problems!

You are welcome. Feel free to post questions here if you encounter any further problems.

Could you please confirm whether the customized approach below is exactly what you meant?

The typical dyMEAN approach to load a checkpoint

import torch

# load model
model = torch.load(args.ckpt, map_location='cpu')
print(f'Model type: {type(model)}')
device = torch.device('cpu' if args.gpu == -1 else f'cuda:{args.gpu}')
model.to(device)
model.eval()

A customized way of using dyMEAN's model weights (without loading the pickled checkpoint file directly)
This approach requires the model architecture (e.g., dyMEANModel) and the corresponding configuration (e.g., struct_prediction.json) for each checkpoint (e.g., struct_prediction.ckpt).

import torch
from models import dyMEANModel

#1. load model settings
model = dyMEANModel(
      embed_size=64,
      hidden_size=128,
      n_channel=14,
      num_classes=20,
      struct_only=True,
      bind_dist_cutoff=6.6
    )

#2. save the weights only
weight_path = 'model_weights.pt'
torch.save(model.state_dict(), weight_path)

#3. load model weights for inference
model.load_state_dict(torch.load(weight_path))
device = torch.device('cpu' if args.gpu == -1 else f'cuda:{args.gpu}')
model.to(device)
model.eval()

Following the discussion above, I was able to load the model weights and run structure_prediction.py to generate PDB files. However, an unexpected error occurred when using the relax function (i.e., enable_openmm_relax=True). Do you know what causes this error and how to fix it?

2024-03-31 22:37:55::INFO::Openmm relaxing...
Traceback (most recent call last):
  File "./scripts/test_6ml8_4o58.py", line 29, in <module>
    structure_prediction(
  File "./scripts/test_6ml8_4o58.py", line 7, in structure_prediction
    design(ckpt, gpu, pdbs, epitope_defs, seqs, out_dir, identifiers,
  File "/home/antibody-engineering/scripts/dyMEAN/api/design.py", line 212, in design
    openmm_relax(mod_pdb, mod_pdb,
  File "/home/antibody-engineering/scripts/dyMEAN/utils/relax.py", line 74, in openmm_relax
    modeller.addHydrogens(force_field)
  File "/root/miniconda3/envs/dyMEAN/lib/python3.8/site-packages/openmm/app/modeller.py", line 975, in addHydrogens
    delta *= 0.1*nanometer/norm(delta)
  File "/root/miniconda3/envs/dyMEAN/lib/python3.8/site-packages/openmm/unit/quantity.py", line 408, in __truediv__
    return (self/other._value) / other.unit
  File "/root/miniconda3/envs/dyMEAN/lib/python3.8/site-packages/openmm/unit/quantity.py", line 411, in __truediv__
    return self * pow(other, -1.0)
ZeroDivisionError: 0.0 cannot be raised to a negative power

Even with enable_openmm_relax=False, the generated PDB file looks "inaccurate":
[Figure 1: structure generated with the customized model loading]

I also followed your python -m api.structure_prediction example with my data, and it generates a PDB file like this (note how different it is from the figure above):
[Figure 2: structure generated by python -m api.structure_prediction]

Apart from model loading, I made no modifications to design.py. Here is how I changed it for my case.

from

# load model
device = torch.device('cpu' if gpu == -1 else f'cuda:{gpu}')
model = torch.load(ckpt)
model.to(device)
model.eval()

to

# load model
device = torch.device('cpu' if gpu == -1 else f'cuda:{gpu}')
model = import_dyMEAN_checkpoint(ckpt)
model.to(device)
model.eval()
 

where import_dyMEAN_checkpoint(ckpt) loads the model and weights as you suggested:

import os

import torch

import config_loader as cl

# import dyMEAN
from dyMEAN.models.dyMEAN_model import dyMEANModel

def import_dyMEAN_checkpoint(ckpt):
    """
    :param ckpt: path to a checkpoint downloaded from https://github.com/THUNLP-MT/dyMEAN/releases
    :return model: dyMEANModel with the loaded weights
    """
    # get model settings
    valid_key = {'struct_prediction': 'struct_prediction',
                 'cdrh3_design': 'single_cdr_design',
                 'cdrh3_opt': 'single_cdr_optimize'}
    ckpt_dict = {}
    ckpt_dict['struct_prediction'] = dyMEANModel(embed_size=64,
                                                 hidden_size=128,
                                                 n_channel=14,
                                                 num_classes=20,
                                                 struct_only=True,
                                                 bind_dist_cutoff=6.6,
                                                 )
    ckpt_dict['single_cdr_design'] = dyMEANModel(embed_size=64,
                                                 hidden_size=128,
                                                 n_channel=14,
                                                 num_classes=14,
                                                 struct_only=False,
                                                 bind_dist_cutoff=6.6,
                                                 )
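    # path settings
    # map the checkpoint file name (e.g. struct_prediction.ckpt) to a ckpt_dict key
    basename = valid_key[os.path.splitext(os.path.basename(ckpt))[0]]
    config = cl.read_config()
    proj_path = config['Project']['project_path']
    tmp_path = os.path.join(proj_path, config['Settings']['cache_dir'])
    save_path = f'{tmp_path}/{basename}_model_weights.ckpt'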

    # load the model and save the weight only
    model = ckpt_dict[basename]
    torch.save( model.state_dict(), save_path )

    # load model weights
    model.load_state_dict( torch.load(save_path, map_location='cpu') )
    return model

Hi! From the code, it looks like you are saving randomly initialized weights and then reloading them, so the output is essentially random noise. Regarding the code below:

#1. load model settings

from models import dyMEANModel
model = dyMEANModel(
      embed_size=64,
      hidden_size=128,
      n_channel=14,
      num_classes=20,
      struct_only=True,
      bind_dist_cutoff=6.6
    )

#2. save the weights only
weight_path = 'model_weights.pt'
torch.save(model.state_dict(), weight_path)

#3. load model weights for inference
model.load_state_dict(torch.load(weight_path))
device = torch.device('cpu' if args.gpu == -1 else f'cuda:{args.gpu}')
model.to(device)
model.eval()

Step 2 should be done in the official repo: first load the trained model with model = torch.load('/path/to/struct_prediction.ckpt'), then save its weights with torch.save(model.state_dict(), weight_path). After that, you can load the saved weights in Step 3.
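The corrected flow can be sketched without PyTorch or dyMEAN. `ToyModel` is a hypothetical stand-in that mimics the `state_dict()`/`load_state_dict()` pair of `torch.nn.Module` (a freshly constructed model starts with random weights), and plain `pickle` stands in for `torch.save`/`torch.load`. What matters is which object the weights are taken from in Step 2.

```python
import os
import pickle
import random
import tempfile


class ToyModel:
    """Hypothetical stand-in for dyMEANModel; weights start random, like a fresh nn.Module."""

    def __init__(self, size=4, seed=None):
        rng = random.Random(seed)
        self.weights = [rng.random() for _ in range(size)]

    def state_dict(self):
        # Plain data only (no class reference), like torch's parameter-name -> tensor dict.
        return {"weights": list(self.weights)}

    def load_state_dict(self, state):
        self.weights = list(state["weights"])


weight_path = os.path.join(tempfile.mkdtemp(), "model_weights.pkl")

# Step 2, inside the official repo: `trained` stands in for torch.load('struct_prediction.ckpt').
trained = ToyModel(seed=0)
with open(weight_path, "wb") as f:
    pickle.dump(trained.state_dict(), f)  # stands in for torch.save(model.state_dict(), ...)

# Step 3, in your own repo: build the architecture (random init), then overwrite its weights.
model = ToyModel(seed=1)
with open(weight_path, "rb") as f:
    model.load_state_dict(pickle.load(f))

# The buggy variant saves and reloads the fresh model's own weights, so they stay random.
buggy = ToyModel(seed=2)
buggy.load_state_dict(buggy.state_dict())

print(model.weights == trained.weights)  # True
print(buggy.weights == trained.weights)  # False
```

Only the step that writes the weight file needs the original repo; the file itself holds plain parameter data, so it can be loaded anywhere the architecture can be constructed.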

Also, since you have managed to bring all the dependencies into your own codebase, you could try loading the checkpoint directly with torch.load('/path/to/struct_prediction.ckpt').

As for the OpenMM problem, it may arise from the very inaccurate structure. If you manage to output a reasonable structure like Figure 2 and the problem persists, we can take a closer look.
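The traceback fails inside `Modeller.addHydrogens` when it divides a displacement vector `delta` by its own length, i.e. the vector has length zero. One plausible cause, consistent with the remark above about a very inaccurate structure, is atoms sitting at exactly the same position. As a hedged pre-check before relaxing, the stdlib sketch below groups `ATOM`/`HETATM` records by their coordinates, assuming standard fixed-column PDB formatting; `find_coincident_atoms` and the helper `atom_line` are hypothetical names, not part of dyMEAN or OpenMM.

```python
from collections import defaultdict


def find_coincident_atoms(pdb_text):
    """Group ATOM/HETATM records that share identical (x, y, z) coordinates.

    PDB files store coordinates with 3 decimals, so identical values mean the
    atoms coincide to within 0.001 angstrom.
    """
    groups = defaultdict(list)
    for line in pdb_text.splitlines():
        if not line.startswith(("ATOM", "HETATM")):
            continue
        # Fixed PDB columns: atom name 13-16, chain 22, residue number 23-26, x/y/z 31-54.
        label = f"{line[21]}:{line[22:26].strip()}:{line[12:16].strip()}"
        xyz = (float(line[30:38]), float(line[38:46]), float(line[46:54]))
        groups[xyz].append(label)
    return {xyz: labels for xyz, labels in groups.items() if len(labels) > 1}


def atom_line(serial, name, resname, chain, resnum, x, y, z):
    """Build a fixed-column PDB ATOM record (only the fields parsed above matter)."""
    return (f"ATOM  {serial:>5} {name:<4} {resname:>3} {chain}{resnum:>4}    "
            f"{x:>8.3f}{y:>8.3f}{z:>8.3f}  1.00  0.00")


demo = "\n".join([
    atom_line(1, "N", "GLY", "H", 1, 0.0, 0.0, 0.0),
    atom_line(2, "CA", "GLY", "H", 1, 1.458, 0.0, 0.0),
    atom_line(3, "C", "GLY", "H", 1, 1.458, 0.0, 0.0),  # collapsed onto CA
])
print(find_coincident_atoms(demo))  # {(1.458, 0.0, 0.0): ['H:1:CA', 'H:1:C']}
```

To check a generated file, pass its contents, e.g. `find_coincident_atoms(open(path).read())` with `path` pointing at the unrelaxed PDB. An empty result only rules out exactly overlapping atoms; other degenerate geometries can still produce a zero-length `delta`.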

In the end, I used the "simple but temporary" solution and added the dyMEAN path to sys.path in my custom script, which lives outside the dyMEAN folder. For my use case, this turned out to be the best approach.

# Add the dyMEAN repo to the Python path so that its top-level modules
# (e.g., models, utils) resolve, including when torch.load unpickles a checkpoint
import os, sys
current_dir = os.path.dirname(os.path.realpath(__file__))
dyMEAN_dir = os.path.abspath(os.path.join(current_dir, 'dyMEAN'))
sys.path.append(dyMEAN_dir)

# Now dyMEAN itself can be imported as a package from the script's directory
from dyMEAN.api.design import design
from dyMEAN.api.optimize import optimize, ComplexSummary
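Why the appended path helps, and where the naming-conflict caveat from the first answer comes from, can be shown without dyMEAN itself. This stdlib sketch builds throwaway directories (the names `dyMEAN` and `my_project` are hypothetical) and uses `importlib.machinery.PathFinder.find_spec` to resolve package names against a list that mimics `sys.path`, so the interpreter's real import state is left untouched.

```python
import os
import tempfile
from importlib.machinery import PathFinder


def make_package(root, name):
    """Create an empty package root/name/__init__.py."""
    os.makedirs(os.path.join(root, name))
    open(os.path.join(root, name, "__init__.py"), "w").close()


def resolve(name, search_path):
    """Return the directory a package would be imported from, or None if not found."""
    spec = PathFinder.find_spec(name, search_path)
    return None if spec is None else os.path.dirname(os.path.dirname(spec.origin))


workspace = tempfile.mkdtemp()
repo = os.path.join(workspace, "dyMEAN")  # hypothetical clone of the repo
mine = os.path.join(workspace, "my_project")  # hypothetical working repo
make_package(repo, "models")
make_package(repo, "utils")
make_package(mine, "utils")  # name clash with dyMEAN's utils

# Mimic sys.path: the script's own directory is searched first.
search_path = [mine]
print(resolve("models", search_path))  # None, i.e. ModuleNotFoundError on import

# The fix: append the repo, as sys.path.append(dyMEAN_dir) does.
search_path.append(repo)
print(resolve("models", search_path) == repo)  # True
print(resolve("utils", search_path) == mine)  # True: your utils shadows dyMEAN's
```

Because the script directory precedes appended entries, a local `utils` package wins over dyMEAN's, and dyMEAN modules that import `utils` would get the wrong one. Inserting the repo at the front instead flips the conflict rather than removing it.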