openmm/openmm-torch

Issue with TorchForce from TorchMD-Net model

eva-not opened this issue · 9 comments

I am also posting this issue in the torchmd-net repo, as it may involve both packages.

I am trying to train an NNP on chignolin trajectories in explicit solvent with TorchMD-Net. I then plan to use it in OpenMM as a TorchForce to run implicit-solvent simulations with, hopefully, better accuracy. The idea is similar to https://www.nature.com/articles/s41467-023-41343-1 but without the coarse graining. To train the NNP, I extracted the coordinates and forces of all protein atoms from my explicit-solvent trajectory as npy arrays, along with the partial charges of the protein atoms to use as embeddings, also as an npy array (the files are too large to attach).

I trained the NNP following the tutorials from https://github.com/torchmd/torchmd-protein-thermodynamics. The training config YAML file is attached. The batch size was set to 16, as larger batch sizes caused GPU memory issues. Once training was done, I selected one of the saved PyTorch .ckpt checkpoints and converted it to a TorchScript .pt model with:

import torch
from torchmdnet.models.model import load_model
model = load_model('train_light/epoch=93-val_loss=8.5159-test_loss=2.0015.ckpt')
torch.jit.script(model).save('test.pt')
The .ckpt and .pt files are attached.

I then attempt to use this model to run MD simulations in OpenMM:
from openmm.app import *
from openmm import *
from openmm.unit import *
from sys import stdout
import numpy as np
from openmmtorch import TorchForce

prmtop = AmberPrmtopFile('chignolin_dry.prmtop')  # .prmtop and .inpcrd attached
inpcrd = AmberInpcrdFile('chignolin_dry.inpcrd')
system = prmtop.createSystem(nonbondedMethod=NoCutoff)

torch_force = TorchForce('test.pt')
system.addForce(torch_force)
integrator = LangevinMiddleIntegrator(298.15*kelvin, 1/picosecond, 0.002*picoseconds)
platform = Platform.getPlatformByName('CUDA')
simulation = Simulation(prmtop.topology, system, integrator, platform)
simulation.context.setPositions(inpcrd.positions)
simulation.minimizeEnergy()
simulation.reporters.append(DCDReporter('output.dcd', 1000))
simulation.reporters.append(StateDataReporter(stdout, 1000, step=True, potentialEnergy=True, temperature=True))
simulation.step(100000)

but running it fails with the following OpenMM exception:
OpenMMException: forward() is missing value for argument 'pos'. Declaration: forward(__torch__.torchmdnet.models.model.___torch_mangle_0.TorchMD_Net self, Tensor z, Tensor pos, Tensor? batch=None, Tensor? q=None, Tensor? s=None, Dict(str, Tensor)? extra_args=None) -> ((Tensor, Tensor?))

Any help would be greatly appreciated!
files.zip

Issue in torchmd-net repo: torchmd/torchmd-net#265

Hi! I think this one is an issue for openmm-torch.

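TorchForce expects the module to have a specific set of inputs/outputs. See here.
These are a subset of what TorchMD-Net models take: TorchForce calls forward() with just the positions tensor (plus the box vectors if periodic boundary conditions are enabled), so the positions bind to z and pos is left without a value.

To fix that, we need a Module that translates between TorchMD-Net and TorchForce, in this case simply: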

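import torch

class ForceModule(torch.nn.Module):
    def __init__(self, z):
        super().__init__()
        self.model = torch.jit.load('test.pt')
        # Register z as a buffer so it is moved to the model's device along with it
        self.register_buffer('z', z)

    def forward(self, positions):
        # y is the energy and neg_dy the forces (-dE/dpos), in the model's units
        y, neg_dy = self.model(self.z, positions)
        return y, neg_dy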

Here z holds the atomic numbers, which I obtained like this:

prmtop = AmberPrmtopFile('chignolin_dry.prmtop')
inpcrd = AmberInpcrdFile('chignolin_dry.inpcrd')
system = prmtop.createSystem(nonbondedMethod=NoCutoff)
#Get atomic numbers
z = []
for atom in prmtop.topology.atoms():
    z.append(atom.element.atomic_number)
z = torch.tensor(z, dtype=torch.long)

This is the complete example for convenience:

from openmm.app import *
from openmm import *
from openmm.unit import *
from sys import stdout
import torch
# Importing torchmdnet loads its extensions, which the scripted model may need
import torchmdnet.models.torchmd_gn
from openmmtorch import TorchForce

class ForceModule(torch.nn.Module):
    def __init__(self, z):
        super().__init__()
        self.model = torch.jit.load('test.pt')
        # Register z as a buffer so it is moved to the model's device along with it
        self.register_buffer('z', z)

    def forward(self, positions):
        # y is the energy and neg_dy the forces (-dE/dpos), in the model's units
        y, neg_dy = self.model(self.z, positions)
        return y, neg_dy

prmtop = AmberPrmtopFile('chignolin_dry.prmtop')
inpcrd = AmberInpcrdFile('chignolin_dry.inpcrd')
system = prmtop.createSystem(nonbondedMethod=NoCutoff)
# Get atomic numbers
z = torch.tensor([atom.element.atomic_number for atom in prmtop.topology.atoms()], dtype=torch.long)
module = torch.jit.script(ForceModule(z))
torch_force = TorchForce(module)
# The module returns the forces as well as the energy
torch_force.setOutputsForces(True)
system.addForce(torch_force)
integrator = LangevinMiddleIntegrator(298.15*kelvin, 1/picosecond, 0.002*picoseconds)
platform = Platform.getPlatformByName('CPU')
simulation = Simulation(prmtop.topology, system, integrator, platform)
simulation.context.setPositions(inpcrd.positions)
simulation.minimizeEnergy()
simulation.reporters.append(DCDReporter('output.dcd', 1000))
simulation.reporters.append(StateDataReporter(stdout, 1000, step=True, potentialEnergy=True, temperature=True))
simulation.step(100000)

Note that when I run this example, it fails with the following message:

openmm.OpenMMException: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/torchmdnet/models/utils.py", line 53, in forward
        _4 = torch.format(_0, torch.select(num_pairs, 0, 0), max_pairs)
        _5 = torch.add("AssertionError: ", _4)
        ops.prim.RaiseException(_5)
        ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
      else:
        pass

Traceback of TorchScript, original code (most recent call last):
  File "/home/eva/anaconda3/envs/openmm-torch/lib/python3.11/site-packages/torchmdnet/models/utils.py", line 257, in forward
        if self.check_errors:
            if num_pairs[0] > max_pairs:
                raise AssertionError(
                ~~~~~~~~~~~~~~~~~~~~~
                    "Found num_pairs({}) > max_num_pairs({})".format(
                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
                        num_pairs[0], max_pairs
                        ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                    )
                )
RuntimeError: AssertionError: Found num_pairs(27390
[ CPUIntType{} ]) > max_num_pairs(21248)

This may simply mean that the maximum number of neighbors is too small, or that the simulation is unstable, causing some particle to have too many neighbors.

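As a side note, the error on my machine shows the path "/home/eva/". This is presumably because TorchScript records the source file paths of the original code when the model is scripted, so the traceback points to the machine where test.pt was created.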

Your ForceModule should also do any necessary unit conversions. The positions passed to it are in nm, and it expects the energy to be returned in kJ/mol and the forces to be returned in kJ/mol/nm. If your model uses different units, you can convert them.
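As a minimal sketch of such a conversion, the wrapper below assumes (hypothetically; check the units of your training data) that the model takes positions in Å and returns energies in kcal/mol and forces in kcal/mol/Å. The factors follow from 1 nm = 10 Å and 1 kcal = 4.184 kJ, so forces scale by 4.184 × 10 = 41.84. `UnitConvertingForceModule` and `_StandInModel` are illustrative names, not part of either library; the stand-in only exercises the wrapper without a trained model.

```python
from typing import Optional, Tuple

import torch


class UnitConvertingForceModule(torch.nn.Module):
    """Adapts a TorchMD-Net-style model(z, pos) -> (energy, forces) to TorchForce.

    Assumption (hypothetical): the wrapped model works in Å and kcal/mol.
    OpenMM passes positions in nm and expects kJ/mol and kJ/mol/nm back.
    """

    def __init__(self, model: torch.nn.Module, z: torch.Tensor):
        super().__init__()
        self.model = model
        # A buffer follows the module when it is moved to another device
        self.register_buffer("z", z)

    def forward(self, positions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # nm -> Å before calling the model
        energy, forces = self.model(self.z, positions * 10.0)
        assert forces is not None  # the model must be built to return forces
        # kcal/mol -> kJ/mol, and kcal/mol/Å -> kJ/mol/nm
        return energy * 4.184, forces * (4.184 * 10.0)


class _StandInModel(torch.nn.Module):
    """Toy stand-in for torch.jit.load('test.pt'): E = sum of coordinates, F = 1."""

    def forward(self, z: torch.Tensor, pos: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        return pos.sum(), torch.ones_like(pos)


wrapped = UnitConvertingForceModule(_StandInModel(), torch.tensor([1, 6]))
energy, forces = wrapped(torch.ones(2, 3))  # 2 atoms, every coordinate 1 nm
print(energy.item(), forces[0, 0].item())
```

With the real model, the idea would be to pass `torch.jit.load('test.pt')` and the atomic numbers instead of the stand-in, script the wrapper with `torch.jit.script`, hand it to `TorchForce` and call `setOutputsForces(True)`. If the model was trained in nm and kJ/mol already, the factors should simply be dropped.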

See the README for information about other options you can use in your model.

Thanks! I modified this slightly to pass the data to the GPU:

class ForceModule(torch.nn.Module):
    def __init__(self, z):
        super(ForceModule, self).__init__()
        self.model = torch.jit.load('test.pt').cuda()
        self.z = z.cuda()

    def forward(self, positions):
        positions = positions.cuda()
        y, neg_dy = self.model(self.z, positions)
        return y, neg_dy

I also get the error about the maximum number of pairs now. I used 128 as the maximum number of neighbours when training the model, but I took that value from the coarse-graining tutorials, where chignolin has only 10 CA atoms. Atomistic chignolin has 166 atoms, and 128 × 166 = 21,248, which is the max_num_pairs value in the RuntimeError. I will retrain with more neighbours and see what happens, thanks!
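Note that the pair count in the error, 27,390, equals 166 × 165, which suggests that every atom found all 165 other atoms within the cutoff. One possible explanation, if the model was trained on coordinates in Å, is that OpenMM's positions in nm make the molecule look ten times smaller. Rather than guessing `max_num_neighbors`, one option is to measure it on the training coordinates. The sketch below is a brute-force NumPy check (O(N²) per frame, fine for 166 atoms); `neighbor_stats` and `max_neighbors_over_trajectory` are hypothetical helper names, and `cutoff` should be the model's upper cutoff (e.g. `cutoff_upper` in the training config) in the same length unit as the coordinates.

```python
import numpy as np


def neighbor_stats(positions: np.ndarray, cutoff: float) -> tuple[int, int]:
    """Return (largest per-atom neighbor count, number of ordered pairs within cutoff).

    positions: (n_atoms, 3) array, in the same length unit as `cutoff`.
    """
    diff = positions[:, None, :] - positions[None, :, :]
    dist = np.linalg.norm(diff, axis=-1)
    within = dist < cutoff
    np.fill_diagonal(within, False)  # an atom is not its own neighbor
    return int(within.sum(axis=1).max()), int(within.sum())


def max_neighbors_over_trajectory(coords: np.ndarray, cutoff: float) -> int:
    """coords: (n_frames, n_atoms, 3). Largest neighbor count seen in any frame."""
    return max(neighbor_stats(frame, cutoff)[0] for frame in coords)


# Four atoms on a line, 1 unit apart
line = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
print(neighbor_stats(line, 1.5))   # (2, 6): the two middle atoms have 2 neighbors each
print(neighbor_stats(line, 10.0))  # (3, 12): every ordered pair i != j, i.e. N * (N - 1)
```

Applied to the training coordinates array, `max_neighbors_over_trajectory` gives the value `max_num_neighbors` should at least reach; some headroom seems sensible, since simulated configurations can be denser than those seen in training.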

None of the calls to .cuda() should be necessary; OpenMM-Torch already makes sure that everything is on the right device.
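One caveat, as a hedged note: assuming OpenMM-Torch moves the module with `.to(device)`, that call only moves parameters and buffers, not plain tensor attributes (e.g. `self.z = z`). That would explain why the GPU version above needed `self.z = z.cuda()`; registering `z` with `register_buffer` should let it follow the model. The small sketch below shows the difference using a dtype change, which goes through the same `.to()` mechanism and runs without a GPU (`Holder` is just an illustrative name).

```python
import torch


class Holder(torch.nn.Module):
    def __init__(self, t: torch.Tensor):
        super().__init__()
        self.plain = t.clone()                  # plain attribute: not touched by .to()
        self.register_buffer("buf", t.clone())  # buffer: converted/moved by .to()


m = Holder(torch.zeros(3)).to(torch.float64)
print(m.plain.dtype, m.buf.dtype)  # torch.float32 torch.float64
```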

Keep the units in mind too, as Peter mentioned!
It might be a good idea to provide a more streamlined connection between TorchMD-Net and OpenMM-Torch. I think it makes more sense for that connection to come from TMDNet, which could perhaps provide an OpenMMForceModule class that takes a checkpoint directly and provides a TorchForce.

I am going to close this now, feel free to reopen or open a new issue if you run into more trouble!


@RaulPPelaez Regarding a more streamlined TorchMD-Net/OpenMM-Torch connection: I believe this is the intended purpose of openmm-ml https://github.com/openmm/openmm-ml

The docs say "you can set up a simulation that uses a standard, pretrained model to represent some or all of the interactions in a system." But this would not provide a standard pretrained model; rather, it would let users use their own pretrained ones.
I am not sure this is within the scope of OpenMM-ML. In any case, since it works via OpenMM-Torch, the module written in TorchMD-Net could be copied over to OpenMM-ML.