This repository provides the code associated with "Jointly Embedding Protein Structures and Sequences through Residue Level Alignment" by Foster Birnbaum, Saachi Jain, Aleksander Madry, and Amy E. Keating.
The rla_env.yml file specifies the needed requirements.
Model weights and data are available for download here. The model weights folder should be downloaded to the home directory. The data are provided as a zipped folder containing the train/validation/test data splits in WebDataset format.
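After unzipping, each data split consists of WebDataset shards, which are ordinary tar archives in which files sharing a key (the file name up to its first dot) form one sample. To peek inside a shard without extra dependencies, the standard-library sketch below groups tar members into samples. It assumes nothing about the field names RLA uses; split_key_ext and iter_samples are illustrative helpers, not part of the repository, and the webdataset package (or the repository's own loaders) remains the way to feed data to the model.

```python
import tarfile


def split_key_ext(member_name):
    """Split 'dir/sample001.seq.txt' into ('dir/sample001', 'seq.txt').

    Follows the WebDataset convention: the sample key is the path up to the
    first '.' in the file's base name; the remainder names the field.
    """
    head, sep, base = member_name.rpartition("/")
    key, _, ext = base.partition(".")
    return head + sep + key, ext


def iter_samples(tar_path):
    """Yield one dict per sample: {'__key__': key, field: raw bytes, ...}.

    Assumes, as WebDataset does, that files of one sample are stored
    consecutively in the archive.
    """
    with tarfile.open(tar_path, "r") as tar:  # "r" also handles .tar.gz etc.
        current = None
        for member in tar:
            if not member.isfile():
                continue
            key, ext = split_key_ext(member.name)
            if current is not None and current["__key__"] != key:
                yield current
                current = None
            if current is None:
                current = {"__key__": key}
            current[ext] = tar.extractfile(member).read()
        if current is not None:
            yield current
```

For example, printing sorted(next(iter_samples("path/to/shard.tar"))) would list the field names stored for the first sample of a shard.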
Example inference data are provided in the example_data folder.
To package a .pdb file into the WebDataset format that RLA reads most easily, follow this example.
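The linked example defines the exact format RLA expects. As a rough illustration of the two steps involved, the sketch below extracts the sequence and backbone (N, CA, C, O) coordinates of one chain from a .pdb file and writes them into a WebDataset-style tar shard. The single "<key>.json" member with "seq" and "coords" fields is a hypothetical schema, not the one used by RLA, and the parser is deliberately simple: standard amino acids in ATOM records only, the first model only, and the first alternate location of each atom.

```python
import io
import json
import tarfile

# Standard amino acids only; nonstandard residues (e.g. MSE in HETATM records) are skipped.
THREE_TO_ONE = {
    "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
    "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
    "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
    "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
}
BACKBONE = ("N", "CA", "C", "O")


def parse_pdb_backbone(pdb_text, chain_id):
    """Return (sequence, coords) for one chain of the first model.

    coords[i] lists [x, y, z] for N, CA, C, O of residue i (None if missing).
    """
    residues = {}  # (resSeq, iCode) -> {"aa": one-letter code, atom name: [x, y, z]}
    order = []     # residue ids in file order
    for line in pdb_text.splitlines():
        if line.startswith("ENDMDL"):
            break  # keep only the first model (e.g. of an NMR ensemble)
        if not line.startswith("ATOM") or line[21] != chain_id:
            continue
        res_name = line[17:20].strip()
        if res_name not in THREE_TO_ONE:
            continue
        res_id = (int(line[22:26]), line[26])  # residue number + insertion code
        if res_id not in residues:
            residues[res_id] = {"aa": THREE_TO_ONE[res_name]}
            order.append(res_id)
        atom = line[12:16].strip()
        if atom in BACKBONE and atom not in residues[res_id]:  # first altloc wins
            residues[res_id][atom] = [float(line[30:38]), float(line[38:46]), float(line[46:54])]
    seq = "".join(residues[r]["aa"] for r in order)
    coords = [[residues[r].get(a) for a in BACKBONE] for r in order]
    return seq, coords


def write_shard(tar_path, samples):
    """Write {key: (seq, coords)} as a tar shard with one '<key>.json' member
    per sample (hypothetical schema). Keys must not contain '.'."""
    with tarfile.open(tar_path, "w") as tar:
        for key, (seq, coords) in samples.items():
            payload = json.dumps({"seq": seq, "coords": coords}).encode("utf-8")
            info = tarfile.TarInfo(name=f"{key}.json")
            info.size = len(payload)
            tar.addfile(info, io.BytesIO(payload))
```

For example, write_shard("shard-000000.tar", {"1abc_A": parse_pdb_backbone(pdb_text, "A")}) would produce one sample keyed 1abc_A. Keys avoid dots because WebDataset splits the key from the field name at the first dot.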
Once the model weights folder is downloaded, the model_dir = /c/example/path line in each example notebook must be changed to point to the weights. Additionally, if the computer running the notebook is offline, the args_dict['arch'] = '/c/example/path' line must be changed to point to ESM-2 as downloaded by the transformers module. If the computer is online, the args_dict['arch'] = '/c/example/path' line should be deleted.
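On an offline machine, args_dict['arch'] needs a local directory. If ESM-2 was previously fetched with transformers, it usually sits in the Hugging Face hub cache under models--<org>--<name>/snapshots/<revision>/. The helper below is a sketch that looks for such a snapshot. Its default cache location (HF_HUB_CACHE, else HF_HOME/hub, else ~/.cache/huggingface/hub) is assumed to mirror huggingface_hub's defaults, so pass cache_dir explicitly if your setup differs; find_cached_hf_model is a hypothetical helper, not part of the repository.

```python
import os
from pathlib import Path


def find_cached_hf_model(model_id, cache_dir=None):
    """Return a local snapshot directory for model_id, or None if not cached.

    Prefers the revision recorded in refs/main; otherwise falls back to the
    last snapshot (by name) that contains a config.json.
    """
    if cache_dir is None:
        hf_home = os.environ.get("HF_HOME", str(Path.home() / ".cache" / "huggingface"))
        cache_dir = os.environ.get("HF_HUB_CACHE", os.path.join(hf_home, "hub"))
    repo_dir = Path(cache_dir) / ("models--" + model_id.replace("/", "--"))
    snapshots = repo_dir / "snapshots"
    if not snapshots.is_dir():
        return None
    ref_main = repo_dir / "refs" / "main"
    if ref_main.is_file():
        pinned = snapshots / ref_main.read_text().strip()
        if (pinned / "config.json").exists():
            return str(pinned)
    candidates = sorted(p for p in snapshots.iterdir() if (p / "config.json").exists())
    return str(candidates[-1]) if candidates else None
```

In a notebook this could be used as args_dict['arch'] = find_cached_hf_model("facebook/esm2_..."), where the model id is a placeholder for whichever ESM-2 checkpoint the notebook's default args_dict['arch'] names.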
The core idea behind RLA is to pass a protein's sequence and structure through the corresponding tracks of RLA (RLA-ESM and RLA-COORDinator) to generate sequence and structure embeddings, and then to calculate the residue-level cosine similarity between the resulting embeddings. The cosine similarities are averaged to produce an RLA score that represents the sequence-structure compatibility of the input protein. The following is an example of generating an RLA score, assuming val_loader is a PyTorch dataloader created by processing the WebDataset as shown in the example above.
import torch
from torch.cuda.amp import autocast  # accepts dtype= directly, as used below

import src.data_utils as data_utils

## Get sequence and structure embeddings from RLA
def get_seq_and_struct_features(model, tokenizer, batch):
    seq_batch, coords_batch = batch
    seqs = seq_batch['string_sequence']
    # Tokenize the sequences for the RLA-ESM track
    text_inp = tokenizer(seqs, return_tensors='pt', padding=True, truncation=True, max_length=1024+2)
    text_inp['position_ids'] = seq_batch['pos_embs'][0]
    text_inp = {k: v.to('cuda') for k, v in text_inp.items()}
    # Build the RLA-COORDinator input from the coordinates
    coord_data = data_utils.construct_gnn_inp(coords_batch, device='cuda', half_precision=True)
    gnn_features, text_features, logit_scale = model(text_inp, coord_data) # Get features
    # Post-process the text features and build the matching validity mask
    new_text_features, _, new_text_mask = data_utils.postprocess_text_features(
        text_features=text_features,
        inp_dict=text_inp,
        tokenizer=tokenizer,
        placeholder_mask=seq_batch['placeholder_mask'][0])
    return {
        'text': new_text_features, # text feature
        'gnn': gnn_features, # gnn feature
        'seq_mask_with_burn_in': seq_batch['seq_loss_mask'][0], # sequence mask of what's supervised
        'coord_mask_with_burn_in': coords_batch['coords_loss_mask'][0], # coord mask of what's supervised
        'seq_mask_no_burn_in': new_text_mask.bool(), # sequence mask of what's valid (e.g., not padded)
        'coord_mask_no_burn_in': coords_batch['coords'][1], # coord mask of what's valid
    }

# trained_model and tokenizer are assumed to be loaded as in the example notebooks
all_scores = []
for batch in val_loader:
    with torch.no_grad():
        with autocast(dtype=torch.float16):
            output_dict = get_seq_and_struct_features(trained_model, tokenizer, batch)
            text_feat = output_dict['text']
            gnn_feat = output_dict['gnn'][:, :text_feat.shape[1]] # Remove tail padding
            # Per-residue dot product of sequence and structure embeddings: (B, L, 1, D) @ (B, L, D, 1) -> (B, L)
            scores = (text_feat.unsqueeze(2) @ gnn_feat.unsqueeze(-1)).squeeze(-1).squeeze(-1)
            mask = output_dict['seq_mask_no_burn_in'].float()
            scores = (scores * mask).sum(1) / mask.sum(1) # Calculate RLA score (mean over valid residues)
    all_scores.append(scores.cpu())
all_scores = torch.cat(all_scores) # One RLA score per protein
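The loop above takes a per-residue dot product, which equals the cosine similarity only when both embeddings are L2-normalized (presumably the case for RLA's outputs, since the score is described as a cosine similarity). The dependency-free sketch below spells out the same masked average with explicit normalization, which can be handy for sanity-checking scores on toy inputs; cosine and rla_score are illustrative names, not functions from the repository.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors (0.0 if either is zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0


def rla_score(seq_emb, struct_emb, mask):
    """Mean per-residue cosine similarity over residues where mask is True.

    seq_emb, struct_emb: per-residue embedding vectors (L x D nested lists).
    mask: L booleans marking valid (e.g. non-padded) residues.
    """
    sims = [cosine(s, t) for s, t, m in zip(seq_emb, struct_emb, mask) if m]
    if not sims:
        raise ValueError("mask selects no residues")
    return sum(sims) / len(sims)
```

For instance, with sequence embeddings [[1, 0], [0, 1], [0, 0]], structure embeddings [[2, 0], [1, 1], [0, 0]] and mask [True, True, False], the two valid residues have similarities 1 and 1/sqrt(2), so the score is their mean, about 0.854.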
An example of how to use RLA to rank candidate structures is provided here. The example ranks hundreds of decoy structures for 2 real structures from the PDB and evaluates the ranking by correlating the RLA scores with the decoys' TM-scores. The data are sourced from Roney and Ovchinnikov, 2022.
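The decoy evaluation reduces to correlating each decoy's RLA score with its TM-score. The text does not say which correlation coefficient the example uses; the sketch below assumes Spearman's rank correlation (tied values receive their average rank) and is written in plain Python for illustration. scipy.stats.spearmanr computes the same quantity.

```python
import math


def average_ranks(values):
    """Rank values from 1 (smallest) upward; tied values share their mean rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1  # extend the run of tied values
        mean_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks


def pearson(x, y):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    if sx == 0 or sy == 0:
        raise ValueError("correlation is undefined for constant input")
    return cov / (sx * sy)


def spearman(x, y):
    """Spearman rank correlation: Pearson correlation of the average ranks."""
    return pearson(average_ranks(x), average_ranks(y))
```

One reasonable choice is to compute spearman(rla_scores, tm_scores) separately for each of the two targets, so that differences in score scale between targets do not inflate or mask the per-target ranking quality.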
An example of how to use RLA to predict the effect of mutations is provided here. The example predicts the effects of thousands of single and double amino acid substitutions on the stability of single-chain proteins and compares the predictions with experimentally measured values. The data are sourced from Tsuboyama et al., 2023.
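Scoring substitutions first requires the mutant sequences. The helper below is a hypothetical sketch that applies mutations written as <wild-type><position><mutant> (e.g. A23G), with double mutants joined by ':'. This label format is an assumption, so adapt the parser to the dataset's actual columns. The check that the stated wild-type residue matches the sequence catches numbering offsets early.

```python
def apply_mutations(wt_seq, mutations, offset=1):
    """Return the mutant sequence for labels like 'A23G' or 'A23G:K40R'.

    offset: residue number of wt_seq[0] (1 for 1-based numbering).
    Raises ValueError if a position is out of range or the wild-type
    residue in the label does not match wt_seq.
    """
    mutant = list(wt_seq)
    for mut in mutations.split(":"):
        mut = mut.strip()
        wt_aa, pos, mut_aa = mut[0], int(mut[1:-1]), mut[-1]
        idx = pos - offset
        if not 0 <= idx < len(wt_seq):
            raise ValueError(f"{mut}: position {pos} is outside the sequence")
        if wt_seq[idx] != wt_aa:
            raise ValueError(f"{mut}: expected {wt_aa} at position {pos}, found {wt_seq[idx]}")
        mutant[idx] = mut_aa
    return "".join(mutant)
```

For example, apply_mutations("MKTAY", "K2R:A4G") returns "MRTGY". A mutant could then be scored, for instance, by comparing its RLA score with that of the wild type, but how the linked example turns RLA scores into stability predictions is not described here.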
An example of how to use RLA to predict residue-residue contacts in a protein is provided here.
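Contact predictions are commonly evaluated against contacts defined from the experimental structure. The sketch below assumes one frequently used definition, two residues whose C-alpha atoms lie within 8 Å (many benchmarks use C-beta instead), and reports the precision of the k top-scoring residue pairs separated by at least 6 positions in sequence. Both thresholds are illustrative assumptions, and where the pairwise scores come from is left open; this is not necessarily how the linked example works.

```python
import math


def contact_map(ca_coords, threshold=8.0):
    """L x L booleans: True where two residues' C-alpha atoms are closer
    than `threshold` angstroms."""
    n = len(ca_coords)
    return [[math.dist(ca_coords[i], ca_coords[j]) < threshold for j in range(n)]
            for i in range(n)]


def top_k_precision(scores, contacts, k, min_sep=6):
    """Fraction of the k highest-scoring pairs (i < j, j - i >= min_sep) that
    are true contacts. scores and contacts are L x L nested lists."""
    n = len(scores)
    pairs = [(scores[i][j], i, j) for i in range(n) for j in range(i + min_sep, n)]
    if not pairs:
        raise ValueError("no residue pairs satisfy min_sep")
    pairs.sort(reverse=True)  # highest score first
    top = pairs[:k]
    return sum(contacts[i][j] for _, i, j in top) / len(top)
```

A common convention is to set k to the sequence length L (or L/5) so that precision is comparable across proteins of different sizes.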
To train the model, use the clip_main.py script. For example:
python clip_main.py --config dataset_configs/full_pdb.yaml --training.exp_name experiment_name --model.coordinator_hparams terminator_configs/standard.json