patrickbryant1/Umol

[question]: Around training the model.

alexanderbonnet opened this issue · 5 comments

Hi! First of all, thanks for making your work so readily available.

I am looking to get a PyTorch reproduction of the repository going. I have not run into problems for inference (adapting from OpenFold and converting weights), but am running into a couple of challenges at train time, and wondered if you could help me understand some implementation details.


I see in the make_uniform function of predict.py that a comment says the amino acid type is set to glycine, but the index 0 that is actually used sets the amino acid to alanine. Wouldn't this matter for pseudo_beta_fn and for the inclusion of the ligand in the distogram loss?

# 20, where 20 is 'X'. Put 0 (GLY) for ligand atoms - will take care of lots of mapping inside the net
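To make the index mismatch concrete, here is a minimal sketch that reproduces the residue ordering from AlphaFold's residue_constants.py. It assumes Umol keeps that ordering unchanged, which the thread does not state explicitly; the name unknown_index is only illustrative.

```python
# Residue ordering as defined in AlphaFold's residue_constants.py, copied here
# so the snippet runs without the AlphaFold package installed.
# Assumption: Umol uses this ordering unchanged.
restypes = ['A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I',
            'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
restype_order = {restype: i for i, restype in enumerate(restypes)}
unknown_index = len(restypes)  # the 'X' slot (illustrative name)

print(restype_order['A'])  # index 0 is alanine
print(restype_order['G'])  # glycine sits at index 7
print(unknown_index)       # 20
```

Under that assumption, writing 0 into aatype labels the ligand atoms as alanine, not glycine.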


In folding.py, the backbone_loss uses an "atom14_gt_exists_protein" feature. I presume this contains atom masks for the protein only, as opposed to "atom14_gt_exists", which must contain atoms for both the protein and the ligand?

backbone_mask_protein = batch['atom14_gt_exists_protein'][:,0]

What about in the sidechain_loss?

flat_frames_mask = jnp.reshape(batch['rigidgroups_gt_exists']*batch['rigidgroups_gt_protein_exists'], [-1])


Thanks for your help!

Hi,

I would like to explain this in detail, but don't have time right now. I will provide a better answer in the first weeks of July.

Right now what I can say is that comment is probably wrong. What matters is that each 'ligand atom token' can be mapped to a CA and that the masks use only CA for the losses. Any amino acid selection will take care of this.

Hope this helps (somewhat).

Best,

Patrick

Thanks for the quick answer!

You may disregard the second portion of the question; my issues were due to poor indexing on my part in one of the frame aligned point error losses. All looks good now and I am getting the expected behavior during training.

For the distogram loss, it still looks to me like using any amino acid other than glycine would essentially remove the ligand from the distogram loss, as the distances considered are between CBs (except for glycine, which uses CAs and would therefore be compatible with setting the ligand heavy atoms to CAs).

def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
  """Create pseudo beta features."""
  is_gly = tf.equal(aatype, residue_constants.restype_order['G'])
  ca_idx = residue_constants.atom_order['CA']
  cb_idx = residue_constants.atom_order['CB']
  pseudo_beta = tf.where(
      tf.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
      all_atom_positions[..., ca_idx, :],
      all_atom_positions[..., cb_idx, :])
  if all_atom_masks is not None:
    pseudo_beta_mask = tf.where(
        is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx])
    pseudo_beta_mask = tf.cast(pseudo_beta_mask, tf.float32)
    return pseudo_beta, pseudo_beta_mask
  else:
    return pseudo_beta
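
The concern can be illustrated with a small, dependency-free sketch of the mask selection for a single token. It assumes that a ligand atom populates only the CA slot of the atom37 layout, which is my reading of the setup and is not confirmed in this thread; pseudo_beta_mask_entry is a hypothetical helper, not a function from Umol or AlphaFold.

```python
# atom37 positions of CA and CB in AlphaFold's atom_order, and the aatype
# indices of alanine and glycine in restype_order.
CA_IDX, CB_IDX = 1, 3
ALA, GLY = 0, 7

def pseudo_beta_mask_entry(aatype, atom_mask):
    """Pure-Python analogue of the mask branch of pseudo_beta_fn for one token."""
    return atom_mask[CA_IDX] if aatype == GLY else atom_mask[CB_IDX]

# Hypothetical ligand token: only the CA slot is populated (assumption).
ligand_atom_mask = [0.0] * 37
ligand_atom_mask[CA_IDX] = 1.0

print(pseudo_beta_mask_entry(ALA, ligand_atom_mask))  # CB slot is empty, mask is 0.0
print(pseudo_beta_mask_entry(GLY, ligand_atom_mask))  # CA slot is used, mask is 1.0
```

If the training pipeline fills pseudo_beta for ligand tokens some other way, this concern would not apply.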

Umol/src/net/model/modules.py

Lines 1244 to 1273 in f7cd2b4

def _distogram_log_loss(logits, bin_edges, batch, num_bins):
  """Log loss of a distogram."""
  assert len(logits.shape) == 3
  positions = batch['pseudo_beta']
  mask = batch['pseudo_beta_mask']
  assert positions.shape[-1] == 3
  sq_breaks = jnp.square(bin_edges)
  dist2 = jnp.sum(
      jnp.square(
          jnp.expand_dims(positions, axis=-2) -
          jnp.expand_dims(positions, axis=-3)),
      axis=-1,
      keepdims=True)
  true_bins = jnp.sum(dist2 > sq_breaks, axis=-1)
  errors = softmax_cross_entropy(
      labels=jax.nn.one_hot(true_bins, num_bins), logits=logits)
  square_mask = jnp.expand_dims(mask, axis=-2) * jnp.expand_dims(mask, axis=-1)
  avg_error = (
      jnp.sum(errors * square_mask, axis=(-2, -1)) /
      (1e-6 + jnp.sum(square_mask, axis=(-2, -1))))
  dist2 = dist2[..., 0]
  return dict(loss=avg_error, true_dist=jnp.sqrt(1e-6 + dist2))
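
For readers less familiar with this loss, the two steps that matter here (binning the true distance and averaging over the pair mask) can be sketched in plain Python. The bin edges and error values below are made up for illustration, and true_bin and masked_mean are hypothetical helper names, not functions from the repository.

```python
def true_bin(dist, bin_edges):
    """Bin index: the number of bin edges that the squared distance exceeds."""
    dist2 = dist * dist
    return sum(dist2 > edge * edge for edge in bin_edges)

def masked_mean(errors, mask):
    """Average pairwise errors over pairs in which both tokens are unmasked."""
    n = len(mask)
    num = sum(errors[i][j] * mask[i] * mask[j]
              for i in range(n) for j in range(n))
    den = 1e-6 + sum(mask[i] * mask[j] for i in range(n) for j in range(n))
    return num / den

bin_edges = [4.0, 8.0, 12.0]  # hypothetical edges; three edges give four bins
print(true_bin(3.0, bin_edges))   # 0: below every edge
print(true_bin(9.5, bin_edges))   # 2: above the first two edges
print(true_bin(20.0, bin_edges))  # 3: above every edge

errors = [[1.0, 2.0], [2.0, 5.0]]  # made-up per-pair cross-entropy values
print(masked_mean(errors, [1.0, 0.0]))  # only the (0, 0) pair contributes
```

A token whose pseudo_beta_mask entry is 0 therefore contributes nothing to the loss, whatever its coordinates are.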

I think I should be fine for the most part, but would love to have detailed explanations regardless if you find the time.

Thanks again,
Alexander

Hi,

Great 👍

The distogram is predicted in bins mapped from the pair representation. Therefore, the amino acid type doesn't matter as long as the ground truth coordinates (CB for protein) are provided for that loss.

Hope this helps.

Best,

Patrick

I will close this issue now, please email me if you have any further more detailed questions and I will try my best to answer.

I realize the code is not that neat, but I partly blame DeepMind for this and all their nested calls to all the different files.

Best,

Patrick

> I will close this issue now, please email me if you have any further more detailed questions and I will try my best to answer.

I've been able to get some results I'm pretty happy with so I should be all set, thanks! Will definitely shoot you an e-mail if I need to.

> I realize the code is not that neat, but I partly blame DeepMind for this and all their nested calls to all the different files.

Yup, not the easiest piece of code to work with for sure...

Thanks again,
Alexander