torchmd/torchmd-net

Problem with torchscript


Hello,

I am running into a problem using TorchScript to integrate a trained TensorNet model with OpenMM for molecular dynamics. This is with the newest version of the code at the time of writing (commit 6694816).

The system

I am running this code on NERSC Perlmutter, which has A100 GPUs (40 GB or 80 GB). I set up the Anaconda environment below by following the install-from-source instructions at https://torchmd-net.readthedocs.io/en/latest/installation.html:

# packages in environment at /global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
_sysroot_linux-64_curr_repodata_hack 3                   h69a702a_14    conda-forge
absl-py                   2.1.0                    pypi_0    pypi
annotated-types           0.6.0              pyhd8ed1ab_0    conda-forge
aom                       3.8.2                h59595ed_0    conda-forge
ase                       3.22.1             pyhd8ed1ab_1    conda-forge
astunparse                1.6.3              pyhd8ed1ab_0    conda-forge
atk-1.0                   2.38.0               hd4edc92_1    conda-forge
binutils_impl_linux-64    2.40                 hf600244_0    conda-forge
binutils_linux-64         2.40                 hdade7a5_3    conda-forge
blas                      2.121                  openblas    conda-forge
blas-devel                3.9.0           21_linux64_openblas    conda-forge
blinker                   1.7.0              pyhd8ed1ab_0    conda-forge
blosc                     1.21.5               h0f2a231_0    conda-forge
brotli                    1.1.0                hd590300_1    conda-forge
brotli-bin                1.1.0                hd590300_1    conda-forge
brotli-python             1.1.0           py311hb755f60_1    conda-forge
brunsli                   0.1                  h9c3ff4c_0    conda-forge
bzip2                     1.0.8                hd590300_5    conda-forge
c-ares                    1.27.0               hd590300_0    conda-forge
c-blosc2                  2.13.2               hb4ffafa_0    conda-forge
ca-certificates           2024.2.2             hbcca054_0    conda-forge
cached-property           1.5.2                hd8ed1ab_1    conda-forge
cached_property           1.5.2              pyha770c72_1    conda-forge
cairo                     1.18.0               h3faef2a_0    conda-forge
captum                    0.6.0              pyhd8ed1ab_0    conda-forge
certifi                   2024.2.2           pyhd8ed1ab_0    conda-forge
charls                    2.4.2                h59595ed_0    conda-forge
charset-normalizer        3.3.2              pyhd8ed1ab_0    conda-forge
click                     8.1.7           unix_pyh707e725_0    conda-forge
colorama                  0.4.6              pyhd8ed1ab_0    conda-forge
contourpy                 1.2.0           py311h9547e67_0    conda-forge
cuda-cccl                 12.0.90              ha770c72_1    conda-forge
cuda-cccl-impl            2.0.1                ha770c72_1    conda-forge
cuda-cccl_linux-64        12.0.90              ha770c72_1    conda-forge
cuda-cudart               12.0.107             h59595ed_1    conda-forge
cuda-cudart-dev           12.0.107             h59595ed_1    conda-forge
cuda-cudart-dev_linux-64  12.0.107             h59595ed_1    conda-forge
cuda-cudart-static        12.0.107             h59595ed_1    conda-forge
cuda-cudart-static_linux-64 12.0.107             h59595ed_1    conda-forge
cuda-driver-dev           12.0.107             hd3aeb46_8    conda-forge
cuda-driver-dev_linux-64  12.0.107             h59595ed_8    conda-forge
cuda-libraries-dev        12.0.0               ha770c72_1    conda-forge
cuda-nvcc                 12.0.76             hba56722_12    conda-forge
cuda-nvcc-dev_linux-64    12.0.76              ha770c72_1    conda-forge
cuda-nvcc-impl            12.0.76              h59595ed_1    conda-forge
cuda-nvcc-tools           12.0.76              h59595ed_1    conda-forge
cuda-nvcc_linux-64        12.0.76             hba56722_12    conda-forge
cuda-nvrtc                12.0.76              hd3aeb46_2    conda-forge
cuda-nvrtc-dev            12.0.76              hd3aeb46_2    conda-forge
cuda-nvtx                 12.0.76              h59595ed_1    conda-forge
cuda-opencl               12.0.76              h59595ed_0    conda-forge
cuda-opencl-dev           12.0.76              ha770c72_0    conda-forge
cuda-profiler-api         12.0.76              ha770c72_0    conda-forge
cuda-version              12.0                 hffde075_3    conda-forge
cudnn                     8.9.7.29             h092f7fd_3    conda-forge
cycler                    0.12.1             pyhd8ed1ab_0    conda-forge
dav1d                     1.2.1                hd590300_0    conda-forge
exceptiongroup            1.2.0              pyhd8ed1ab_2    conda-forge
expat                     2.6.2                h59595ed_0    conda-forge
filelock                  3.13.1             pyhd8ed1ab_0    conda-forge
flake8                    7.0.0              pyhd8ed1ab_0    conda-forge
flask                     3.0.2              pyhd8ed1ab_0    conda-forge
font-ttf-dejavu-sans-mono 2.37                 hab24e00_0    conda-forge
font-ttf-inconsolata      3.000                h77eed37_0    conda-forge
font-ttf-source-code-pro  2.038                h77eed37_0    conda-forge
font-ttf-ubuntu           0.83                 h77eed37_1    conda-forge
fontconfig                2.14.2               h14ed4e7_0    conda-forge
fonts-conda-ecosystem     1                             0    conda-forge
fonts-conda-forge         1                             0    conda-forge
fonttools                 4.50.0          py311h459d7ec_0    conda-forge
freetype                  2.12.1               h267a509_2    conda-forge
fribidi                   1.0.10               h36c2ea0_0    conda-forge
fsspec                    2024.3.1           pyhca7485f_0    conda-forge
gcc                       11.4.0               h7dfb3fc_3    conda-forge
gcc_impl_linux-64         11.4.0               h7aa1c59_5    conda-forge
gcc_linux-64              11.4.0               h0f0c6b6_3    conda-forge
gdk-pixbuf                2.42.10              h829c605_5    conda-forge
gettext                   0.21.1               h27087fc_0    conda-forge
giflib                    5.2.1                h0b41bf4_3    conda-forge
gmp                       6.3.0                h59595ed_1    conda-forge
gmpy2                     2.1.2           py311h6a5fa03_1    conda-forge
graphite2                 1.3.13            h58526e2_1001    conda-forge
graphviz                  9.0.0                h78e8752_1    conda-forge
grpcio                    1.62.1                   pypi_0    pypi
gtk2                      2.24.33              h280cfa0_4    conda-forge
gts                       0.7.6                h977cf35_4    conda-forge
gxx                       11.4.0               h7dfb3fc_3    conda-forge
gxx_impl_linux-64         11.4.0               h7aa1c59_5    conda-forge
gxx_linux-64              11.4.0               h2730b16_3    conda-forge
h5py                      3.10.0          nompi_py311hebc2b07_101    conda-forge
harfbuzz                  8.3.0                h3d44ed6_0    conda-forge
hdf5                      1.14.3          nompi_h4f84152_100    conda-forge
icu                       73.2                 h59595ed_0    conda-forge
idna                      3.6                pyhd8ed1ab_0    conda-forge
imagecodecs               2024.1.1        py311hd0e15ba_2    conda-forge
imageio                   2.34.0             pyh4b66e23_0    conda-forge
importlib-metadata        7.1.0              pyha770c72_0    conda-forge
importlib_metadata        7.1.0                hd8ed1ab_0    conda-forge
iniconfig                 2.0.0              pyhd8ed1ab_0    conda-forge
isodate                   0.6.1              pyhd8ed1ab_0    conda-forge
itsdangerous              2.1.2              pyhd8ed1ab_0    conda-forge
jinja2                    3.1.3              pyhd8ed1ab_0    conda-forge
joblib                    1.3.2              pyhd8ed1ab_0    conda-forge
jxrlib                    1.1                  hd590300_3    conda-forge
kernel-headers_linux-64   3.10.0              h4a8ded7_14    conda-forge
keyutils                  1.6.1                h166bdaf_0    conda-forge
kiwisolver                1.4.5           py311h9547e67_1    conda-forge
krb5                      1.21.2               h659d440_0    conda-forge
lark-parser               0.12.0             pyhd8ed1ab_0    conda-forge
lazy_loader               0.3                pyhd8ed1ab_0    conda-forge
lcms2                     2.16                 hb7c19ff_0    conda-forge
ld_impl_linux-64          2.40                 h41732ed_0    conda-forge
lerc                      4.0.0                h27087fc_0    conda-forge
libabseil                 20230802.1      cxx17_h59595ed_0    conda-forge
libaec                    1.1.3                h59595ed_0    conda-forge
libavif16                 1.0.4                hd9d6309_2    conda-forge
libblas                   3.9.0           21_linux64_openblas    conda-forge
libbrotlicommon           1.1.0                hd590300_1    conda-forge
libbrotlidec              1.1.0                hd590300_1    conda-forge
libbrotlienc              1.1.0                hd590300_1    conda-forge
libcblas                  3.9.0           21_linux64_openblas    conda-forge
libcublas                 12.0.1.189           hd3aeb46_3    conda-forge
libcublas-dev             12.0.1.189           hd3aeb46_3    conda-forge
libcufft                  11.0.0.21            hd3aeb46_2    conda-forge
libcufft-dev              11.0.0.21            hd3aeb46_2    conda-forge
libcufile                 1.5.0.59             hd3aeb46_1    conda-forge
libcufile-dev             1.5.0.59             hd3aeb46_1    conda-forge
libcurand                 10.3.1.50            hd3aeb46_1    conda-forge
libcurand-dev             10.3.1.50            hd3aeb46_1    conda-forge
libcurl                   8.6.0                hca28451_0    conda-forge
libcusolver               11.4.2.57            hd3aeb46_2    conda-forge
libcusolver-dev           11.4.2.57            hd3aeb46_2    conda-forge
libcusparse               12.0.0.76            hd3aeb46_2    conda-forge
libcusparse-dev           12.0.0.76            hd3aeb46_2    conda-forge
libdeflate                1.19                 hd590300_0    conda-forge
libedit                   3.1.20191231         he28a2e2_2    conda-forge
libev                     4.33                 hd590300_2    conda-forge
libexpat                  2.6.2                h59595ed_0    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-devel_linux-64     11.4.0             h922705a_105    conda-forge
libgcc-ng                 13.2.0               h807b86a_5    conda-forge
libgd                     2.3.3                h119a65a_9    conda-forge
libgfortran-ng            13.2.0               h69a702a_5    conda-forge
libgfortran5              13.2.0               ha4646dd_5    conda-forge
libglib                   2.80.0               hf2295e7_1    conda-forge
libgomp                   13.2.0               h807b86a_5    conda-forge
libhwloc                  2.9.3           default_h554bfaf_1009    conda-forge
libhwy                    1.0.7                h00ab1b0_0    conda-forge
libiconv                  1.17                 hd590300_2    conda-forge
libjpeg-turbo             3.0.0                hd590300_1    conda-forge
libjxl                    0.10.1               h5b01ea3_0    conda-forge
liblapack                 3.9.0           21_linux64_openblas    conda-forge
liblapacke                3.9.0           21_linux64_openblas    conda-forge
libllvm14                 14.0.6               hcd5def8_4    conda-forge
libmagma                  2.7.2                h173bb3b_2    conda-forge
libmagma_sparse           2.7.2                h173bb3b_3    conda-forge
libnghttp2                1.58.0               h47da74e_1    conda-forge
libnpp                    12.0.0.30            hd3aeb46_1    conda-forge
libnpp-dev                12.0.0.30            hd3aeb46_1    conda-forge
libnsl                    2.0.1                hd590300_0    conda-forge
libnvjitlink              12.0.76              hd3aeb46_2    conda-forge
libnvjitlink-dev          12.0.76              hd3aeb46_2    conda-forge
libnvjpeg                 12.0.0.28            h59595ed_1    conda-forge
libnvjpeg-dev             12.0.0.28            ha770c72_1    conda-forge
libopenblas               0.3.26          pthreads_h413a1c8_0    conda-forge
libpng                    1.6.43               h2797004_0    conda-forge
libprotobuf               4.25.1               hf27288f_2    conda-forge
librsvg                   2.56.3               he3f83f7_1    conda-forge
libsanitizer              11.4.0               h4dcbe23_5    conda-forge
libsqlite                 3.45.2               h2797004_0    conda-forge
libssh2                   1.11.0               h0841786_0    conda-forge
libstdcxx-devel_linux-64  11.4.0             h922705a_105    conda-forge
libstdcxx-ng              13.2.0               h7e041cc_5    conda-forge
libtiff                   4.6.0                ha9c0a0a_2    conda-forge
libtorch                  2.1.2           cuda120_h2aa5df7_301    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libuv                     1.48.0               hd590300_0    conda-forge
libwebp                   1.3.2                h658648e_1    conda-forge
libwebp-base              1.3.2                hd590300_0    conda-forge
libxcb                    1.15                 h0b41bf4_0    conda-forge
libxcrypt                 4.4.36               hd590300_1    conda-forge
libxml2                   2.12.6               h232c23b_0    conda-forge
libzlib                   1.2.13               hd590300_5    conda-forge
libzopfli                 1.0.3                h9c3ff4c_0    conda-forge
lightning                 2.1.4              pyhd8ed1ab_0    conda-forge
lightning-utilities       0.11.0             pyhd8ed1ab_0    conda-forge
llvm-openmp               18.1.2               h4dfa4b3_0    conda-forge
llvmlite                  0.42.0          py311ha6695c7_1    conda-forge
lz4-c                     1.9.4                hcb278e6_0    conda-forge
lzo                       2.10              h516909a_1000    conda-forge
magma                     2.7.2                h51420fd_3    conda-forge
markdown                  3.6                      pypi_0    pypi
markupsafe                2.1.5           py311h459d7ec_0    conda-forge
matplotlib-base           3.8.3           py311h54ef318_0    conda-forge
mccabe                    0.7.0              pyhd8ed1ab_0    conda-forge
mdtraj                    1.9.9           py311h90fe790_1    conda-forge
mkl                       2023.2.0         h84fe81f_50496    conda-forge
mpc                       1.3.1                hfe3b2da_0    conda-forge
mpfr                      4.2.1                h9458935_0    conda-forge
mpmath                    1.3.0              pyhd8ed1ab_0    conda-forge
munkres                   1.1.4              pyh9f0ad1d_0    conda-forge
nccl                      2.20.5.1             h3a97aeb_0    conda-forge
ncurses                   6.4.20240210         h59595ed_0    conda-forge
networkx                  3.2.1              pyhd8ed1ab_0    conda-forge
nnpops                    0.6             cuda120py311hcbe25e9_7    conda-forge
numba                     0.59.0          py311h96b013e_1    conda-forge
numexpr                   2.8.7           py311h812550d_0
numpy                     1.26.4          py311h64a7726_0    conda-forge
ocl-icd                   2.3.2                hd590300_1    conda-forge
ocl-icd-system            1.0.0                         1    conda-forge
openblas                  0.3.26          pthreads_h7a3da1a_0    conda-forge
openjpeg                  2.5.2                h488ebb8_0    conda-forge
openmm                    8.1.1           py311h11a6390_1    conda-forge
openmm-torch              1.4             cuda120py311h478a873_4    conda-forge
openssl                   3.2.1                hd590300_1    conda-forge
opt_einsum                3.3.0              pyhc1e730c_2    conda-forge
packaging                 24.0               pyhd8ed1ab_0    conda-forge
pandas                    2.2.1           py311h320fe9a_0    conda-forge
pango                     1.52.1               ha41ecd1_0    conda-forge
patsy                     0.5.6              pyhd8ed1ab_0    conda-forge
pcre2                     10.43                hcad00b1_0    conda-forge
pillow                    10.2.0          py311ha6c5da5_0    conda-forge
pip                       24.0               pyhd8ed1ab_0    conda-forge
pixman                    0.43.2               h59595ed_0    conda-forge
pluggy                    1.4.0              pyhd8ed1ab_0    conda-forge
protobuf                  5.26.1                   pypi_0    pypi
psutil                    5.9.8           py311h459d7ec_0    conda-forge
pthread-stubs             0.4               h36c2ea0_1001    conda-forge
py-cpuinfo                9.0.0              pyhd8ed1ab_0    conda-forge
pycodestyle               2.11.1             pyhd8ed1ab_0    conda-forge
pydantic                  2.6.4              pyhd8ed1ab_0    conda-forge
pydantic-core             2.16.3          py311h46250e7_0    conda-forge
pyflakes                  3.2.0              pyhd8ed1ab_0    conda-forge
pynndescent               0.5.11             pyhca7485f_0    conda-forge
pyparsing                 3.1.2              pyhd8ed1ab_0    conda-forge
pysocks                   1.7.1              pyha2e5f31_6    conda-forge
pytables                  3.9.2           py311h10c7f7f_1    conda-forge
pytest                    8.1.1              pyhd8ed1ab_0    conda-forge
python                    3.11.8          hab00c5b_0_cpython    conda-forge
python-dateutil           2.9.0              pyhd8ed1ab_0    conda-forge
python-tzdata             2024.1             pyhd8ed1ab_0    conda-forge
python_abi                3.11                    4_cp311    conda-forge
pytorch                   2.1.2           cuda120_py311h25b6552_301    conda-forge
pytorch-lightning         2.2.1              pyhd8ed1ab_0    conda-forge
pytorch_geometric         2.4.0              pyhd8ed1ab_0    conda-forge
pytz                      2024.1             pyhd8ed1ab_0    conda-forge
pywavelets                1.4.1           py311h1f0f07a_1    conda-forge
pyyaml                    6.0.1           py311h459d7ec_1    conda-forge
rav1e                     0.6.6                he8a937b_2    conda-forge
rdflib                    7.0.0              pyhd8ed1ab_0    conda-forge
readline                  8.2                  h8228510_1    conda-forge
requests                  2.31.0             pyhd8ed1ab_0    conda-forge
scikit-image              0.22.0          py311h320fe9a_2    conda-forge
scikit-learn              1.4.1.post1     py311hc009520_0    conda-forge
scipy                     1.12.0          py311h64a7726_2    conda-forge
setuptools                65.3.0             pyhd8ed1ab_1    conda-forge
setuptools-scm            6.3.2              pyhd8ed1ab_0    conda-forge
setuptools_scm            6.3.2                hd8ed1ab_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
sleef                     3.5.1                h9b69904_2    conda-forge
snappy                    1.1.10               h9fff704_0    conda-forge
statsmodels               0.14.1          py311h1f0f07a_0    conda-forge
svt-av1                   2.0.0                h59595ed_0    conda-forge
sympy                     1.12            pypyh9d50eac_103    conda-forge
sysroot_linux-64          2.17                h4a8ded7_14    conda-forge
tbb                       2021.11.0            h00ab1b0_1    conda-forge
tensorboard               2.16.2                   pypi_0    pypi
tensorboard-data-server   0.7.2                    pypi_0    pypi
threadpoolctl             3.4.0              pyhc1e730c_0    conda-forge
tifffile                  2024.2.12          pyhd8ed1ab_0    conda-forge
tk                        8.6.13          noxft_h4845f30_101    conda-forge
tomli                     2.0.1              pyhd8ed1ab_0    conda-forge
torchani                  2.2.4           cuda120py311he2766f7_3    conda-forge
torchmd-net               2.1.0                     dev_0    <develop>
torchmetrics              1.3.2              pyhd8ed1ab_0    conda-forge
tqdm                      4.66.2             pyhd8ed1ab_0    conda-forge
trimesh                   4.2.0              pyhd8ed1ab_0    conda-forge
typing-extensions         4.10.0               hd8ed1ab_0    conda-forge
typing_extensions         4.10.0             pyha770c72_0    conda-forge
tzdata                    2024a                h0c530f3_0    conda-forge
urllib3                   2.2.1              pyhd8ed1ab_0    conda-forge
werkzeug                  3.0.1              pyhd8ed1ab_0    conda-forge
wheel                     0.43.0             pyhd8ed1ab_0    conda-forge
xorg-kbproto              1.0.7             h7f98852_1002    conda-forge
xorg-libice               1.1.1                hd590300_0    conda-forge
xorg-libsm                1.2.4                h7391055_0    conda-forge
xorg-libx11               1.8.7                h8ee46fc_0    conda-forge
xorg-libxau               1.0.11               hd590300_0    conda-forge
xorg-libxdmcp             1.1.3                h7f98852_0    conda-forge
xorg-libxext              1.3.4                h0b41bf4_2    conda-forge
xorg-libxrender           0.9.11               hd590300_0    conda-forge
xorg-renderproto          0.11.1            h7f98852_1002    conda-forge
xorg-xextproto            7.3.0             h0b41bf4_1003    conda-forge
xorg-xproto               7.0.31            h7f98852_1007    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
yaml                      0.2.5                h7f98852_2    conda-forge
zfp                       1.0.1                h59595ed_0    conda-forge
zipp                      3.17.0             pyhd8ed1ab_0    conda-forge
zlib                      1.2.13               hd590300_5    conda-forge
zlib-ng                   2.0.7                h0b41bf4_0    conda-forge
zstd                      1.5.5                hfc55251_0    conda-forge
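Since the full list is long, a short script that prints only the packages most relevant to this problem may be easier to scan. A minimal sketch using the standard library's importlib.metadata; the distribution names in PACKAGES are assumptions about how each package registers itself, and anything installed without Python metadata (or not installed) is reported as None.

```python
from importlib import metadata

# Distribution names to look up. These are assumptions: conda and pip builds
# may register a package under a different name, or without Python metadata.
PACKAGES = ["torch", "openmm", "openmmtorch", "torchmd-net", "nnpops", "lightning"]


def report_versions(names):
    """Return {name: version string, or None if no metadata is found}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed, or installed without metadata
    return versions


if __name__ == "__main__":
    for name, version in report_versions(PACKAGES).items():
        print(f"{name:12s} {version}")
```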

Setup

I trained a TensorNet model with the ZBL prior using the configuration file below; I included the ZBL prior because my systems contain ions. Training ran on a single A100 GPU and was restarted from the latest checkpoint after 1000 epochs.

activation: silu
aggr: add
atom_filter: -1
attn_activation: silu
batch_size: 16
box_vecs:
- - 15.223
  - 0
  - 0
- - 0
  - 15.223
  - 0
- - 0
  - 0
  - 15.223
coord_files: null
cutoff_lower: 0.0
cutoff_upper: 5.0
dataset: HDF5
dataset_root: licl_1-4_merged_with_metadata.h5
derivative: true
distance_influence: both
early_stopping_patience: 95
ema_alpha_neg_dy: 1.0
ema_alpha_y: 0.0
embed_files: null
embedding_dimension: 64
energy_files: null
force_files: null
inference_batch_size: 16
load_model: LiCl_0_exp/epoch=999-val_loss=0.0000-test_loss=0.0048.ckpt
log_dir: LiCl_exp
lr: 0.001
lr_factor: 0.9
lr_min: 1.0e-07
lr_patience: 5
lr_warmup_steps: 0
max_num_neighbors: 256
max_z: 100
model: tensornet
neg_dy_weight: 1.0
neighbor_embedding: true
ngpus: -1
num_epochs: 10000
num_heads: 2
num_layers: 0
num_nodes: 1
num_rbf: 32
num_workers: 32
output_model: Scalar
precision: 64
prior_model:
- ZBL:
    cutoff_distance: 4.0
    max_num_neighbors: 50
rbf_type: expnorm
redirect: true
reduce_op: add
save_interval: 1
seed: 42
splits: licl_generated_splits.npz
standardize: false
tensorboard_use: true
test_interval: 10
test_size: 0.1
train_size: 0.8
trainable_rbf: true
val_size: 0.1
weight_decay: 0.0
y_weight: 0.0
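For reference, the ZBL prior is a screened nuclear repulsion that dominates at short range, which is why it helps with close ion-ion contacts. A minimal sketch of the textbook ZBL "universal" potential, assuming energies in eV and distances in Angstroms; the torchmd-net ZBL prior may differ in details such as the smooth cutoff applied at cutoff_distance (4.0 here) and its internal units, so this only illustrates the functional form.

```python
import math

# Coefficients and exponents of the ZBL universal screening function.
ZBL_COEFFS = (0.18175, 0.50986, 0.28022, 0.02817)
ZBL_EXPONENTS = (3.19980, 0.94229, 0.40290, 0.20162)
COULOMB_EV_ANGSTROM = 14.399645  # e^2 / (4 pi eps0) in eV * Angstrom


def zbl_screening(x):
    """Universal screening function phi(x), where x = r / a."""
    return sum(c * math.exp(-e * x) for c, e in zip(ZBL_COEFFS, ZBL_EXPONENTS))


def zbl_energy(z1, z2, r):
    """Screened Coulomb repulsion between nuclei z1 and z2 at distance r (Angstrom), in eV.

    No cutoff is applied here, unlike a prior used inside a model.
    """
    a = 0.46850 / (z1 ** 0.23 + z2 ** 0.23)  # ZBL screening length in Angstrom
    return COULOMB_EV_ANGSTROM * z1 * z2 / r * zbl_screening(r / a)


# Example: Li (Z=3) and Cl (Z=17) at 1 Angstrom.
print(zbl_energy(3, 17, 1.0))
```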

I then used the following script to generate the force module. Since I am using periodic boundary conditions, I use the ForceModulePBC class:

import torch
import h5py
import numpy as np
from torchmdnet.models.model import load_model
from simtk.openmm import app
import simtk.openmm as mm
import mdtraj as md
from simtk import unit
from sys import stdout
import os
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('pdb_file', type = str, help = 'The pdb file for initializing atom types')
parser.add_argument('interpretation_method', type = str, help = "What package to use for interpreting the pdb file. Either 'openmm' or 'mdtraj'. The pdb file is only used to set atom types for the model")
parser.add_argument('enable_box_vecs', type = str, help = 'Whether or not to enable box vectors as an input to the forward() method of the generated ForceModule()')
parser.add_argument('model_force_units', type = str, help = 'The units of force that the model uses internally')
parser.add_argument('model_energy_units', type = str, help = 'The units of energy that the model uses internally')

def determine_latest_epoch_model(file_lst):
    #Filter out the pytorch generated torchscript modules
    good_files = [file for file in file_lst if 'generated_mod' not in file]
    #Pair the name with the epoch for sorting
    pair_lst = [(file, int(file.split("=")[1].split("-")[0])) for file in good_files]
    #Sort by the epoch number
    sorted_mods = sorted(pair_lst, key = lambda x : x[1])
    return sorted_mods[-1][0]

# Internally, OpenMM uses kJ/mol for energy and kJ/mol/nm for forces, so these dictionaries
# hold the factors that convert the model's outputs into those units (multiply by them).
# Models are assumed to take coordinates in Angstroms.

ENERGY_UNIT_CONVERSION = {
        'Ha'     : 2625.5,
        'kJ_mol' : 1
}

FORCE_UNIT_CONVERSION = {
        'Ha_A'     : (2625.5 / 0.1),
        'kJ_mol_A' : (1/0.1)
}

directories_to_process = [
    #'011824_H2O_cutoff_upper_5.0_y_weight_1.0_ema_alpha_y_1.0_neg_dy_weight_10.0_num_layers_0_num_heads_2_model_tensornet',
    #'011824_H2O_cutoff_upper_5.0_y_weight_1.0_ema_alpha_y_1.0_neg_dy_weight_10.0_num_layers_1_num_heads_2_model_equivariant-transformer',
    #'011824_H2O_cutoff_upper_5.0_y_weight_1.0_ema_alpha_y_1.0_neg_dy_weight_1.0_num_layers_0_num_heads_2_model_tensornet',
    #'011824_H2O_cutoff_upper_5.0_y_weight_1.0_ema_alpha_y_1.0_neg_dy_weight_1.0_num_layers_1_num_heads_2_model_equivariant-transformer'
    #'021624_H2O_cutoff_upper_5.0_y_weight_0.0_ema_alpha_y_0.0_neg_dy_weight_1.0_num_layers_0_num_heads_2_model_tensornet',
    #'021624_H2O_cutoff_upper_5.0_y_weight_0.5_ema_alpha_y_1.0_neg_dy_weight_0.5_num_layers_0_num_heads_2_model_tensornet',
    #'021624_H2O_cutoff_upper_5.0_y_weight_1.0_ema_alpha_y_1.0_neg_dy_weight_10.0_num_layers_0_num_heads_2_model_tensornet',
    #'021624_H2O_cutoff_upper_5.0_y_weight_1.0_ema_alpha_y_1.0_neg_dy_weight_1.0_num_layers_0_num_heads_2_model_tensornet'
    #'032024_H2O_cutoff_upper_5.0_y_weight_0.0_ema_alpha_y_0.0_neg_dy_weight_1.0_num_layers_0_reduce_op_mean_model_tensornet',
    #'032024_H2O_cutoff_upper_5.0_y_weight_0.0_ema_alpha_y_0.0_neg_dy_weight_1.0_num_layers_0_reduce_op_add_model_tensornet'
    #'032824_H2O_cutoff_upper_5.0_y_weight_0.0_ema_alpha_y_0.0_neg_dy_weight_1.0_num_layers_0_reduce_op_add_model_tensornet'
    #'040124_LiCl_cutoff_upper_5.0_y_weight_0.0_ema_alpha_y_0.0_neg_dy_weight_1.0_num_layers_0_reduce_op_add_model_tensornet'
    '040524_LiCl_cutoff_upper_5.0_num_layers_0_reduce_op_add_model_tensornet'

]

class ForceModule(torch.nn.Module):
    def __init__(self, atom_types, energy_factor, force_factor):
        super().__init__()
        self.z = torch.nn.Parameter(atom_types, requires_grad = False)
        self.model = model
        #store conversion factors used by the model
        self.energy_factor = energy_factor
        self.force_factor = force_factor

    def forward(self, positions):
        positions = positions.to(torch.float32)
        positions = positions * 10 #nm -> A
        energy, force = self.model.forward(self.z, positions)         
        assert(energy is not None)
        assert(force is not None)
        force = force * self.force_factor    #model units -> kJ/mol/nm
        energy = energy * self.energy_factor #model units -> kJ/mol
        return energy, force
            
class ForceModulePBC(torch.nn.Module):
    def __init__(self, atom_types, energy_factor, force_factor):
        super().__init__()
        self.z = torch.nn.Parameter(atom_types, requires_grad = False)
        self.model = model
        #store conversion factors used by the model
        self.energy_factor = energy_factor
        self.force_factor = force_factor

    def forward(self, positions, box_vectors):
        '''
        Takes box vectors (in nm, as passed by OpenMM) to enforce periodic boundary
        conditions; they should have shape (3, 3). For orthorhombic boxes, this is a
        diagonal matrix.
        '''
        positions = positions.to(torch.float32)
        positions = positions * 10 #nm -> A
        # Multiply box vectors by 10 (nm -> A) so they match the positions
        energy, force = self.model.forward(z=self.z, 
                                           pos=positions,
                                           box = box_vectors * 10)         
        assert(energy is not None)
        assert(force is not None)
        force = force * self.force_factor     #model units -> kJ/mol/nm
        energy = energy * self.energy_factor  #model units -> kJ/mol
        return energy, force


if __name__ == "__main__":
    args = parser.parse_args()
    
    if args.interpretation_method == 'openmm':
        pdb = app.PDBFile(args.pdb_file)
        atom_types = [atom.element.atomic_number for atom in pdb.topology.atoms()] 
    elif args.interpretation_method == 'mdtraj':
        pdbtraj = md.load(args.pdb_file)
        atom_types = [atom.element.atomic_number for atom in pdbtraj.topology.atoms]

    FORCE_FACTOR = FORCE_UNIT_CONVERSION[args.model_force_units]
    ENERGY_FACTOR = ENERGY_UNIT_CONVERSION[args.model_energy_units]

    print(f"Using energy conversion {ENERGY_FACTOR}")
    print(f"Using force conversion {FORCE_FACTOR}")
    
    for direc in directories_to_process:
        all_model_files = os.listdir(f"models/{direc}")
        latest_mod = determine_latest_epoch_model(all_model_files) 
        print(direc, latest_mod)
    
        print(atom_types)
        print(len(atom_types))
        #print(atomic_numbers)
        model = load_model(f"models/{direc}/{latest_mod}", derivative = True)
        
        # This script generates TorchScript modules for openmmtorch's TorchForce. If done correctly,
        #   the trained model should act as a neural network potential within OpenMM.
        #   First, wrap the loaded model in a ForceModule class, as required by the
        #   openmmtorch interface.
        
        atom_types = torch.tensor(atom_types, dtype = torch.long)
        if args.enable_box_vecs.lower() == 'true':
            print("Using PBC-enabled ForceModule")
            module = torch.jit.script(ForceModulePBC(atom_types, ENERGY_FACTOR, FORCE_FACTOR))
        else:
            print("Using non-PBC ForceModule")
            module = torch.jit.script(ForceModule(atom_types, ENERGY_FACTOR, FORCE_FACTOR))
        module.save(f'models/{direc}/generated_mod.pt')
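Unrelated to the scripting error below, but worth noting: determine_latest_epoch_model only filters out generated_mod files, so any other file in the model directory (for example a copied input.yaml or a last.ckpt) would make the int(...) parsing raise an IndexError or ValueError. A slightly more defensive sketch, assuming checkpoints follow the epoch=N-...ckpt naming seen in the config's load_model entry; latest_checkpoint and CKPT_PATTERN are hypothetical names, not part of torchmd-net.

```python
import re

# Lightning-style checkpoint names, e.g. "epoch=999-val_loss=0.0000-test_loss=0.0048.ckpt".
# The exact pattern is an assumption based on the checkpoint name in the config.
CKPT_PATTERN = re.compile(r"^epoch=(\d+)-.*\.ckpt$")


def latest_checkpoint(file_lst):
    """Return the checkpoint file with the highest epoch, ignoring all other files."""
    candidates = []
    for name in file_lst:
        match = CKPT_PATTERN.match(name)
        if match:  # skips generated_mod.pt, input.yaml, last.ckpt, etc.
            candidates.append((int(match.group(1)), name))
    if not candidates:
        raise FileNotFoundError("no epoch=N-*.ckpt files found")
    # Tuples compare by epoch first, so max() picks the latest epoch.
    return max(candidates)[1]
```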

The error

I ran the following command using this script:

>>> python scripts/torch_force_generator_multi.py pdbs/licl_frame.pdb openmm true Ha_A Ha
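The last two arguments, Ha_A and Ha, select the factors that take the model's Hartree/Angstrom outputs to OpenMM's kJ/mol and kJ/mol/nm. As a sanity check on those factors, a small sketch of the arithmetic in plain Python (no OpenMM needed), assuming the model really works in Hartree and Angstroms: 1 Ha/A = 2625.5 kJ/mol per 0.1 nm = 26255 kJ/mol/nm, which matches 2625.5 / 0.1 in FORCE_UNIT_CONVERSION.

```python
# Factors matching the script above: the model is assumed to use Hartree and
# Angstrom, while OpenMM uses kJ/mol and nm.
HARTREE_TO_KJ_MOL = 2625.5  # same value as ENERGY_UNIT_CONVERSION['Ha']
ANGSTROM_PER_NM = 10.0      # 1 nm = 10 Angstrom


def energy_to_openmm(e_hartree):
    """Hartree -> kJ/mol."""
    return e_hartree * HARTREE_TO_KJ_MOL


def force_to_openmm(f_hartree_per_angstrom):
    """Ha/A -> kJ/mol/A -> kJ/mol/nm (one nm spans 10 Angstrom, so the value grows 10x)."""
    return f_hartree_per_angstrom * HARTREE_TO_KJ_MOL * ANGSTROM_PER_NM


def positions_to_model(pos_nm):
    """OpenMM positions in nm -> model positions in Angstrom."""
    return [x * ANGSTROM_PER_NM for x in pos_nm]


print(force_to_openmm(1.0))  # same factor as 2625.5 / 0.1
```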

Generating the TorchScript module fails at the call to torch.jit.script() with the following error:

Using PBC-enabled ForceModule
Traceback (most recent call last):
  File "/pscratch/sd/f/frankhu/INXS_openmm_dynamics/scripts/torch_force_generator_multi.py", line 141, in <module>
    module = torch.jit.script(ForceModulePBC(atom_types, ENERGY_FACTOR, FORCE_FACTOR))
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_script.py", line 1324, in script
    return torch.jit._recursive.create_script_module(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 559, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_script.py", line 639, in _construct
    init_fn(script_module)
  File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 608, in init_fn
    scripted = create_script_module_impl(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_script.py", line 639, in _construct
    init_fn(script_module)
  File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 608, in init_fn
    scripted = create_script_module_impl(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_script.py", line 639, in _construct
    init_fn(script_module)
  File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 608, in init_fn
    scripted = create_script_module_impl(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_script.py", line 639, in _construct
    init_fn(script_module)
  File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 608, in init_fn
    scripted = create_script_module_impl(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 636, in create_script_module_impl
    create_methods_and_properties_from_stubs(
  File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 469, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(
RuntimeError:

get_neighbor_pairs_kernel(str strategy, Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, float cutoff_lower, float cutoff_upper, int max_num_pairs, bool loop, bool include_transpose) -> ((Tensor, Tensor, Tensor, Tensor)):
Expected a value of type 'float' for argument 'cutoff_lower' but instead found type 'int'.
:
  File "/global/cfs/cdirs/m4026/torchmd-net/torchmdnet/models/utils.py", line 263
        if batch is None:
            batch = torch.zeros(pos.shape[0], dtype=torch.long, device=pos.device)
        edge_index, edge_vec, edge_weight, num_pairs = get_neighbor_pairs_kernel(
                                                       ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            strategy=self.strategy,
            positions=pos,

Other info

Because of the MLP change introduced in commit 6694816, I cannot load my older models: their state-dict keys no longer match. However, after downgrading the repository to commit 74702da, the code above worked (albeit with an older trained model).

As always, thank you so much for your time, and any help would be greatly appreciated!

Thanks for the thorough issue!
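6694816 should not have broken old checkpoints; could you provide one so I can investigate?
For your TorchScript issue, try changing the 0 (the lower cutoff) in this call in the ZBL prior:

self.distance = OptimizedDistance(
    0, cutoff_distance, max_num_pairs=-max_num_neighbors
)

to 0.0. I have seen similar things before with TorchScript: the literal 0 is stored as an int, but get_neighbor_pairs_kernel declares cutoff_lower as a float, and TorchScript does not convert between the two.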
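The int/float mismatch behind the traceback can be illustrated with a toy module. The sketch below shows a defensive variant of the fix: cast the cutoffs with float() when storing them, so the scripted attribute is typed float even if a caller passes 0. toy_pair_distances and ToyDistance are hypothetical; they only mimic the float-typed cutoff_lower argument of get_neighbor_pairs_kernel and are not the torchmd-net implementation.

```python
import torch
from torch import Tensor, nn


def toy_pair_distances(positions: Tensor, cutoff_lower: float,
                       cutoff_upper: float) -> Tensor:
    # Hypothetical stand-in for a float-typed kernel such as get_neighbor_pairs_kernel:
    # all pair distances d (i < j) with cutoff_lower <= d < cutoff_upper.
    n = positions.shape[0]
    diff = positions.unsqueeze(0) - positions.unsqueeze(1)
    dist = torch.sqrt((diff * diff).sum(dim=-1))
    idx = torch.triu_indices(n, n, 1, device=positions.device)
    d = dist[idx[0], idx[1]]
    mask = (d >= cutoff_lower) & (d < cutoff_upper)
    return d[mask]


class ToyDistance(nn.Module):
    def __init__(self, cutoff_lower: float, cutoff_upper: float):
        super().__init__()
        # A bare literal 0 would be stored as an int attribute, which TorchScript
        # will not pass where a float is declared; casting avoids that.
        self.cutoff_lower = float(cutoff_lower)
        self.cutoff_upper = float(cutoff_upper)

    def forward(self, positions: Tensor) -> Tensor:
        return toy_pair_distances(positions, self.cutoff_lower, self.cutoff_upper)


scripted = torch.jit.script(ToyDistance(0, 2.0))  # int 0 is accepted thanks to the cast
```

Casting inside __init__ protects every caller, whereas writing 0.0 at the call site fixes only that one call; either way the attribute TorchScript sees ends up as a float.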

Thanks as always for the quick response @RaulPPelaez!

Making that change in the ZBL prior did indeed fix the issue: I was able to generate the TorchScript module and run some dynamics with it. I will keep testing to see whether I hit any other bugs.

As for the old-checkpoint loading problem, I have attached a zip file to this issue containing the checkpoint and the YAML file used to run the experiment. This model was trained without ZBL on a single A100 GPU. Loading it gives the following error:

>>> from torchmdnet.models.model import load_model
>>> model = load_model("epoch=999-val_loss=0.0000-test_loss=0.0010.ckpt", derivative=True)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/global/cfs/cdirs/m4026/torchmd-net/torchmdnet/models/model.py", line 243, in load_model
    model.load_state_dict(state_dict)
  File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for TorchMD_Net:
        Missing key(s) in state_dict: "output_model.output_network.layers.0.weight", "output_model.output_network.layers.0.bias", "output_model.output_network.layers.2.weight", "output_model.output_network.layers.2.bias".
        Unexpected key(s) in state_dict: "output_model.output_network.0.weight", "output_model.output_network.0.bias", "output_model.output_network.2.weight", "output_model.output_network.2.bias".

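Because the most recent commit (6694816) reformatted some state-dict keys, I assumed this is related: the keys the model expects differ from those in the checkpoint only by an extra layers. component (output_network.layers.0.weight vs. output_network.0.weight).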
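As a stopgap, the mismatch in the error above (an extra layers. component in the new key names) could be bridged by renaming keys before calling load_state_dict. This is a hypothetical sketch: remap_output_network_keys is not a torchmd-net function, the regex is inferred only from the key names in this error, it assumes the layer weights themselves are unchanged, and the example dict uses illustrative keys and placeholder values.

```python
import re
from typing import Dict, TypeVar

V = TypeVar("V")

# Matches old-style keys such as "output_model.output_network.0.weight"
# (optionally preceded by a prefix like "model."); new-style keys, which have
# "layers." right after "output_network.", do not match and pass through unchanged.
_OLD_MLP_KEY = re.compile(r"(output_model\.output_network\.)(\d+\.)")


def remap_output_network_keys(state_dict: Dict[str, V]) -> Dict[str, V]:
    """Hypothetical helper: insert "layers." into old output-MLP keys."""
    return {_OLD_MLP_KEY.sub(r"\1layers.\2", key): value
            for key, value in state_dict.items()}


old = {
    "output_model.output_network.0.weight": "w0",
    "output_model.output_network.2.bias": "b2",
    "representation_model.example.weight": "emb",
}
print(sorted(remap_output_network_keys(old)))
```

With a Lightning checkpoint, the input dict would typically come from torch.load(path, map_location="cpu")["state_dict"]; because the pattern is not anchored, keys carrying a "model." prefix are renamed as well.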

Thanks for looking into it!

tensornet_ckpt.zip

@FranklinHu1, I am able to load your model using #318.