Problem with TorchScript
Hello,
I am running into a problem using TorchScript to integrate a trained TensorNet model with OpenMM for dynamics. This is with the newest version of the code as of writing (commit 6694816).
The system
I am running this code on NERSC Perlmutter, which uses A100 GPUs (either 40 GB or 80 GB). My conda environment is as follows; I set it up following the install-from-source instructions in the documentation at https://torchmd-net.readthedocs.io/en/latest/installation.html:
# packages in environment at /global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net:
#
# Name Version Build Channel
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 2_kmp_llvm conda-forge
_sysroot_linux-64_curr_repodata_hack 3 h69a702a_14 conda-forge
absl-py 2.1.0 pypi_0 pypi
annotated-types 0.6.0 pyhd8ed1ab_0 conda-forge
aom 3.8.2 h59595ed_0 conda-forge
ase 3.22.1 pyhd8ed1ab_1 conda-forge
astunparse 1.6.3 pyhd8ed1ab_0 conda-forge
atk-1.0 2.38.0 hd4edc92_1 conda-forge
binutils_impl_linux-64 2.40 hf600244_0 conda-forge
binutils_linux-64 2.40 hdade7a5_3 conda-forge
blas 2.121 openblas conda-forge
blas-devel 3.9.0 21_linux64_openblas conda-forge
blinker 1.7.0 pyhd8ed1ab_0 conda-forge
blosc 1.21.5 h0f2a231_0 conda-forge
brotli 1.1.0 hd590300_1 conda-forge
brotli-bin 1.1.0 hd590300_1 conda-forge
brotli-python 1.1.0 py311hb755f60_1 conda-forge
brunsli 0.1 h9c3ff4c_0 conda-forge
bzip2 1.0.8 hd590300_5 conda-forge
c-ares 1.27.0 hd590300_0 conda-forge
c-blosc2 2.13.2 hb4ffafa_0 conda-forge
ca-certificates 2024.2.2 hbcca054_0 conda-forge
cached-property 1.5.2 hd8ed1ab_1 conda-forge
cached_property 1.5.2 pyha770c72_1 conda-forge
cairo 1.18.0 h3faef2a_0 conda-forge
captum 0.6.0 pyhd8ed1ab_0 conda-forge
certifi 2024.2.2 pyhd8ed1ab_0 conda-forge
charls 2.4.2 h59595ed_0 conda-forge
charset-normalizer 3.3.2 pyhd8ed1ab_0 conda-forge
click 8.1.7 unix_pyh707e725_0 conda-forge
colorama 0.4.6 pyhd8ed1ab_0 conda-forge
contourpy 1.2.0 py311h9547e67_0 conda-forge
cuda-cccl 12.0.90 ha770c72_1 conda-forge
cuda-cccl-impl 2.0.1 ha770c72_1 conda-forge
cuda-cccl_linux-64 12.0.90 ha770c72_1 conda-forge
cuda-cudart 12.0.107 h59595ed_1 conda-forge
cuda-cudart-dev 12.0.107 h59595ed_1 conda-forge
cuda-cudart-dev_linux-64 12.0.107 h59595ed_1 conda-forge
cuda-cudart-static 12.0.107 h59595ed_1 conda-forge
cuda-cudart-static_linux-64 12.0.107 h59595ed_1 conda-forge
cuda-driver-dev 12.0.107 hd3aeb46_8 conda-forge
cuda-driver-dev_linux-64 12.0.107 h59595ed_8 conda-forge
cuda-libraries-dev 12.0.0 ha770c72_1 conda-forge
cuda-nvcc 12.0.76 hba56722_12 conda-forge
cuda-nvcc-dev_linux-64 12.0.76 ha770c72_1 conda-forge
cuda-nvcc-impl 12.0.76 h59595ed_1 conda-forge
cuda-nvcc-tools 12.0.76 h59595ed_1 conda-forge
cuda-nvcc_linux-64 12.0.76 hba56722_12 conda-forge
cuda-nvrtc 12.0.76 hd3aeb46_2 conda-forge
cuda-nvrtc-dev 12.0.76 hd3aeb46_2 conda-forge
cuda-nvtx 12.0.76 h59595ed_1 conda-forge
cuda-opencl 12.0.76 h59595ed_0 conda-forge
cuda-opencl-dev 12.0.76 ha770c72_0 conda-forge
cuda-profiler-api 12.0.76 ha770c72_0 conda-forge
cuda-version 12.0 hffde075_3 conda-forge
cudnn 8.9.7.29 h092f7fd_3 conda-forge
cycler 0.12.1 pyhd8ed1ab_0 conda-forge
dav1d 1.2.1 hd590300_0 conda-forge
exceptiongroup 1.2.0 pyhd8ed1ab_2 conda-forge
expat 2.6.2 h59595ed_0 conda-forge
filelock 3.13.1 pyhd8ed1ab_0 conda-forge
flake8 7.0.0 pyhd8ed1ab_0 conda-forge
flask 3.0.2 pyhd8ed1ab_0 conda-forge
font-ttf-dejavu-sans-mono 2.37 hab24e00_0 conda-forge
font-ttf-inconsolata 3.000 h77eed37_0 conda-forge
font-ttf-source-code-pro 2.038 h77eed37_0 conda-forge
font-ttf-ubuntu 0.83 h77eed37_1 conda-forge
fontconfig 2.14.2 h14ed4e7_0 conda-forge
fonts-conda-ecosystem 1 0 conda-forge
fonts-conda-forge 1 0 conda-forge
fonttools 4.50.0 py311h459d7ec_0 conda-forge
freetype 2.12.1 h267a509_2 conda-forge
fribidi 1.0.10 h36c2ea0_0 conda-forge
fsspec 2024.3.1 pyhca7485f_0 conda-forge
gcc 11.4.0 h7dfb3fc_3 conda-forge
gcc_impl_linux-64 11.4.0 h7aa1c59_5 conda-forge
gcc_linux-64 11.4.0 h0f0c6b6_3 conda-forge
gdk-pixbuf 2.42.10 h829c605_5 conda-forge
gettext 0.21.1 h27087fc_0 conda-forge
giflib 5.2.1 h0b41bf4_3 conda-forge
gmp 6.3.0 h59595ed_1 conda-forge
gmpy2 2.1.2 py311h6a5fa03_1 conda-forge
graphite2 1.3.13 h58526e2_1001 conda-forge
graphviz 9.0.0 h78e8752_1 conda-forge
grpcio 1.62.1 pypi_0 pypi
gtk2 2.24.33 h280cfa0_4 conda-forge
gts 0.7.6 h977cf35_4 conda-forge
gxx 11.4.0 h7dfb3fc_3 conda-forge
gxx_impl_linux-64 11.4.0 h7aa1c59_5 conda-forge
gxx_linux-64 11.4.0 h2730b16_3 conda-forge
h5py 3.10.0 nompi_py311hebc2b07_101 conda-forge
harfbuzz 8.3.0 h3d44ed6_0 conda-forge
hdf5 1.14.3 nompi_h4f84152_100 conda-forge
icu 73.2 h59595ed_0 conda-forge
idna 3.6 pyhd8ed1ab_0 conda-forge
imagecodecs 2024.1.1 py311hd0e15ba_2 conda-forge
imageio 2.34.0 pyh4b66e23_0 conda-forge
importlib-metadata 7.1.0 pyha770c72_0 conda-forge
importlib_metadata 7.1.0 hd8ed1ab_0 conda-forge
iniconfig 2.0.0 pyhd8ed1ab_0 conda-forge
isodate 0.6.1 pyhd8ed1ab_0 conda-forge
itsdangerous 2.1.2 pyhd8ed1ab_0 conda-forge
jinja2 3.1.3 pyhd8ed1ab_0 conda-forge
joblib 1.3.2 pyhd8ed1ab_0 conda-forge
jxrlib 1.1 hd590300_3 conda-forge
kernel-headers_linux-64 3.10.0 h4a8ded7_14 conda-forge
keyutils 1.6.1 h166bdaf_0 conda-forge
kiwisolver 1.4.5 py311h9547e67_1 conda-forge
krb5 1.21.2 h659d440_0 conda-forge
lark-parser 0.12.0 pyhd8ed1ab_0 conda-forge
lazy_loader 0.3 pyhd8ed1ab_0 conda-forge
lcms2 2.16 hb7c19ff_0 conda-forge
ld_impl_linux-64 2.40 h41732ed_0 conda-forge
lerc 4.0.0 h27087fc_0 conda-forge
libabseil 20230802.1 cxx17_h59595ed_0 conda-forge
libaec 1.1.3 h59595ed_0 conda-forge
libavif16 1.0.4 hd9d6309_2 conda-forge
libblas 3.9.0 21_linux64_openblas conda-forge
libbrotlicommon 1.1.0 hd590300_1 conda-forge
libbrotlidec 1.1.0 hd590300_1 conda-forge
libbrotlienc 1.1.0 hd590300_1 conda-forge
libcblas 3.9.0 21_linux64_openblas conda-forge
libcublas 12.0.1.189 hd3aeb46_3 conda-forge
libcublas-dev 12.0.1.189 hd3aeb46_3 conda-forge
libcufft 11.0.0.21 hd3aeb46_2 conda-forge
libcufft-dev 11.0.0.21 hd3aeb46_2 conda-forge
libcufile 1.5.0.59 hd3aeb46_1 conda-forge
libcufile-dev 1.5.0.59 hd3aeb46_1 conda-forge
libcurand 10.3.1.50 hd3aeb46_1 conda-forge
libcurand-dev 10.3.1.50 hd3aeb46_1 conda-forge
libcurl 8.6.0 hca28451_0 conda-forge
libcusolver 11.4.2.57 hd3aeb46_2 conda-forge
libcusolver-dev 11.4.2.57 hd3aeb46_2 conda-forge
libcusparse 12.0.0.76 hd3aeb46_2 conda-forge
libcusparse-dev 12.0.0.76 hd3aeb46_2 conda-forge
libdeflate 1.19 hd590300_0 conda-forge
libedit 3.1.20191231 he28a2e2_2 conda-forge
libev 4.33 hd590300_2 conda-forge
libexpat 2.6.2 h59595ed_0 conda-forge
libffi 3.4.2 h7f98852_5 conda-forge
libgcc-devel_linux-64 11.4.0 h922705a_105 conda-forge
libgcc-ng 13.2.0 h807b86a_5 conda-forge
libgd 2.3.3 h119a65a_9 conda-forge
libgfortran-ng 13.2.0 h69a702a_5 conda-forge
libgfortran5 13.2.0 ha4646dd_5 conda-forge
libglib 2.80.0 hf2295e7_1 conda-forge
libgomp 13.2.0 h807b86a_5 conda-forge
libhwloc 2.9.3 default_h554bfaf_1009 conda-forge
libhwy 1.0.7 h00ab1b0_0 conda-forge
libiconv 1.17 hd590300_2 conda-forge
libjpeg-turbo 3.0.0 hd590300_1 conda-forge
libjxl 0.10.1 h5b01ea3_0 conda-forge
liblapack 3.9.0 21_linux64_openblas conda-forge
liblapacke 3.9.0 21_linux64_openblas conda-forge
libllvm14 14.0.6 hcd5def8_4 conda-forge
libmagma 2.7.2 h173bb3b_2 conda-forge
libmagma_sparse 2.7.2 h173bb3b_3 conda-forge
libnghttp2 1.58.0 h47da74e_1 conda-forge
libnpp 12.0.0.30 hd3aeb46_1 conda-forge
libnpp-dev 12.0.0.30 hd3aeb46_1 conda-forge
libnsl 2.0.1 hd590300_0 conda-forge
libnvjitlink 12.0.76 hd3aeb46_2 conda-forge
libnvjitlink-dev 12.0.76 hd3aeb46_2 conda-forge
libnvjpeg 12.0.0.28 h59595ed_1 conda-forge
libnvjpeg-dev 12.0.0.28 ha770c72_1 conda-forge
libopenblas 0.3.26 pthreads_h413a1c8_0 conda-forge
libpng 1.6.43 h2797004_0 conda-forge
libprotobuf 4.25.1 hf27288f_2 conda-forge
librsvg 2.56.3 he3f83f7_1 conda-forge
libsanitizer 11.4.0 h4dcbe23_5 conda-forge
libsqlite 3.45.2 h2797004_0 conda-forge
libssh2 1.11.0 h0841786_0 conda-forge
libstdcxx-devel_linux-64 11.4.0 h922705a_105 conda-forge
libstdcxx-ng 13.2.0 h7e041cc_5 conda-forge
libtiff 4.6.0 ha9c0a0a_2 conda-forge
libtorch 2.1.2 cuda120_h2aa5df7_301 conda-forge
libuuid 2.38.1 h0b41bf4_0 conda-forge
libuv 1.48.0 hd590300_0 conda-forge
libwebp 1.3.2 h658648e_1 conda-forge
libwebp-base 1.3.2 hd590300_0 conda-forge
libxcb 1.15 h0b41bf4_0 conda-forge
libxcrypt 4.4.36 hd590300_1 conda-forge
libxml2 2.12.6 h232c23b_0 conda-forge
libzlib 1.2.13 hd590300_5 conda-forge
libzopfli 1.0.3 h9c3ff4c_0 conda-forge
lightning 2.1.4 pyhd8ed1ab_0 conda-forge
lightning-utilities 0.11.0 pyhd8ed1ab_0 conda-forge
llvm-openmp 18.1.2 h4dfa4b3_0 conda-forge
llvmlite 0.42.0 py311ha6695c7_1 conda-forge
lz4-c 1.9.4 hcb278e6_0 conda-forge
lzo 2.10 h516909a_1000 conda-forge
magma 2.7.2 h51420fd_3 conda-forge
markdown 3.6 pypi_0 pypi
markupsafe 2.1.5 py311h459d7ec_0 conda-forge
matplotlib-base 3.8.3 py311h54ef318_0 conda-forge
mccabe 0.7.0 pyhd8ed1ab_0 conda-forge
mdtraj 1.9.9 py311h90fe790_1 conda-forge
mkl 2023.2.0 h84fe81f_50496 conda-forge
mpc 1.3.1 hfe3b2da_0 conda-forge
mpfr 4.2.1 h9458935_0 conda-forge
mpmath 1.3.0 pyhd8ed1ab_0 conda-forge
munkres 1.1.4 pyh9f0ad1d_0 conda-forge
nccl 2.20.5.1 h3a97aeb_0 conda-forge
ncurses 6.4.20240210 h59595ed_0 conda-forge
networkx 3.2.1 pyhd8ed1ab_0 conda-forge
nnpops 0.6 cuda120py311hcbe25e9_7 conda-forge
numba 0.59.0 py311h96b013e_1 conda-forge
numexpr 2.8.7 py311h812550d_0
numpy 1.26.4 py311h64a7726_0 conda-forge
ocl-icd 2.3.2 hd590300_1 conda-forge
ocl-icd-system 1.0.0 1 conda-forge
openblas 0.3.26 pthreads_h7a3da1a_0 conda-forge
openjpeg 2.5.2 h488ebb8_0 conda-forge
openmm 8.1.1 py311h11a6390_1 conda-forge
openmm-torch 1.4 cuda120py311h478a873_4 conda-forge
openssl 3.2.1 hd590300_1 conda-forge
opt_einsum 3.3.0 pyhc1e730c_2 conda-forge
packaging 24.0 pyhd8ed1ab_0 conda-forge
pandas 2.2.1 py311h320fe9a_0 conda-forge
pango 1.52.1 ha41ecd1_0 conda-forge
patsy 0.5.6 pyhd8ed1ab_0 conda-forge
pcre2 10.43 hcad00b1_0 conda-forge
pillow 10.2.0 py311ha6c5da5_0 conda-forge
pip 24.0 pyhd8ed1ab_0 conda-forge
pixman 0.43.2 h59595ed_0 conda-forge
pluggy 1.4.0 pyhd8ed1ab_0 conda-forge
protobuf 5.26.1 pypi_0 pypi
psutil 5.9.8 py311h459d7ec_0 conda-forge
pthread-stubs 0.4 h36c2ea0_1001 conda-forge
py-cpuinfo 9.0.0 pyhd8ed1ab_0 conda-forge
pycodestyle 2.11.1 pyhd8ed1ab_0 conda-forge
pydantic 2.6.4 pyhd8ed1ab_0 conda-forge
pydantic-core 2.16.3 py311h46250e7_0 conda-forge
pyflakes 3.2.0 pyhd8ed1ab_0 conda-forge
pynndescent 0.5.11 pyhca7485f_0 conda-forge
pyparsing 3.1.2 pyhd8ed1ab_0 conda-forge
pysocks 1.7.1 pyha2e5f31_6 conda-forge
pytables 3.9.2 py311h10c7f7f_1 conda-forge
pytest 8.1.1 pyhd8ed1ab_0 conda-forge
python 3.11.8 hab00c5b_0_cpython conda-forge
python-dateutil 2.9.0 pyhd8ed1ab_0 conda-forge
python-tzdata 2024.1 pyhd8ed1ab_0 conda-forge
python_abi 3.11 4_cp311 conda-forge
pytorch 2.1.2 cuda120_py311h25b6552_301 conda-forge
pytorch-lightning 2.2.1 pyhd8ed1ab_0 conda-forge
pytorch_geometric 2.4.0 pyhd8ed1ab_0 conda-forge
pytz 2024.1 pyhd8ed1ab_0 conda-forge
pywavelets 1.4.1 py311h1f0f07a_1 conda-forge
pyyaml 6.0.1 py311h459d7ec_1 conda-forge
rav1e 0.6.6 he8a937b_2 conda-forge
rdflib 7.0.0 pyhd8ed1ab_0 conda-forge
readline 8.2 h8228510_1 conda-forge
requests 2.31.0 pyhd8ed1ab_0 conda-forge
scikit-image 0.22.0 py311h320fe9a_2 conda-forge
scikit-learn 1.4.1.post1 py311hc009520_0 conda-forge
scipy 1.12.0 py311h64a7726_2 conda-forge
setuptools 65.3.0 pyhd8ed1ab_1 conda-forge
setuptools-scm 6.3.2 pyhd8ed1ab_0 conda-forge
setuptools_scm 6.3.2 hd8ed1ab_0 conda-forge
six 1.16.0 pyh6c4a22f_0 conda-forge
sleef 3.5.1 h9b69904_2 conda-forge
snappy 1.1.10 h9fff704_0 conda-forge
statsmodels 0.14.1 py311h1f0f07a_0 conda-forge
svt-av1 2.0.0 h59595ed_0 conda-forge
sympy 1.12 pypyh9d50eac_103 conda-forge
sysroot_linux-64 2.17 h4a8ded7_14 conda-forge
tbb 2021.11.0 h00ab1b0_1 conda-forge
tensorboard 2.16.2 pypi_0 pypi
tensorboard-data-server 0.7.2 pypi_0 pypi
threadpoolctl 3.4.0 pyhc1e730c_0 conda-forge
tifffile 2024.2.12 pyhd8ed1ab_0 conda-forge
tk 8.6.13 noxft_h4845f30_101 conda-forge
tomli 2.0.1 pyhd8ed1ab_0 conda-forge
torchani 2.2.4 cuda120py311he2766f7_3 conda-forge
torchmd-net 2.1.0 dev_0 <develop>
torchmetrics 1.3.2 pyhd8ed1ab_0 conda-forge
tqdm 4.66.2 pyhd8ed1ab_0 conda-forge
trimesh 4.2.0 pyhd8ed1ab_0 conda-forge
typing-extensions 4.10.0 hd8ed1ab_0 conda-forge
typing_extensions 4.10.0 pyha770c72_0 conda-forge
tzdata 2024a h0c530f3_0 conda-forge
urllib3 2.2.1 pyhd8ed1ab_0 conda-forge
werkzeug 3.0.1 pyhd8ed1ab_0 conda-forge
wheel 0.43.0 pyhd8ed1ab_0 conda-forge
xorg-kbproto 1.0.7 h7f98852_1002 conda-forge
xorg-libice 1.1.1 hd590300_0 conda-forge
xorg-libsm 1.2.4 h7391055_0 conda-forge
xorg-libx11 1.8.7 h8ee46fc_0 conda-forge
xorg-libxau 1.0.11 hd590300_0 conda-forge
xorg-libxdmcp 1.1.3 h7f98852_0 conda-forge
xorg-libxext 1.3.4 h0b41bf4_2 conda-forge
xorg-libxrender 0.9.11 hd590300_0 conda-forge
xorg-renderproto 0.11.1 h7f98852_1002 conda-forge
xorg-xextproto 7.3.0 h0b41bf4_1003 conda-forge
xorg-xproto 7.0.31 h7f98852_1007 conda-forge
xz 5.2.6 h166bdaf_0 conda-forge
yaml 0.2.5 h7f98852_2 conda-forge
zfp 1.0.1 h59595ed_0 conda-forge
zipp 3.17.0 pyhd8ed1ab_0 conda-forge
zlib 1.2.13 hd590300_5 conda-forge
zlib-ng 2.0.7 h0b41bf4_0 conda-forge
zstd 1.5.5 hfc55251_0 conda-forge
Setup
I trained a TensorNet model with the ZBL prior using the following configuration file; I included the ZBL prior since I am working with systems containing ions. Training was done on a single A100 GPU and was restarted from the latest checkpoint after 1000 epochs.
activation: silu
aggr: add
atom_filter: -1
attn_activation: silu
batch_size: 16
box_vecs:
- - 15.223
- 0
- 0
- - 0
- 15.223
- 0
- - 0
- 0
- 15.223
coord_files: null
cutoff_lower: 0.0
cutoff_upper: 5.0
dataset: HDF5
dataset_root: licl_1-4_merged_with_metadata.h5
derivative: true
distance_influence: both
early_stopping_patience: 95
ema_alpha_neg_dy: 1.0
ema_alpha_y: 0.0
embed_files: null
embedding_dimension: 64
energy_files: null
force_files: null
inference_batch_size: 16
load_model: LiCl_0_exp/epoch=999-val_loss=0.0000-test_loss=0.0048.ckpt
log_dir: LiCl_exp
lr: 0.001
lr_factor: 0.9
lr_min: 1.0e-07
lr_patience: 5
lr_warmup_steps: 0
max_num_neighbors: 256
max_z: 100
model: tensornet
neg_dy_weight: 1.0
neighbor_embedding: true
ngpus: -1
num_epochs: 10000
num_heads: 2
num_layers: 0
num_nodes: 1
num_rbf: 32
num_workers: 32
output_model: Scalar
precision: 64
prior_model:
- ZBL:
    cutoff_distance: 4.0
    max_num_neighbors: 50
rbf_type: expnorm
redirect: true
reduce_op: add
save_interval: 1
seed: 42
splits: licl_generated_splits.npz
standardize: false
tensorboard_use: true
test_interval: 10
test_size: 0.1
train_size: 0.8
trainable_rbf: true
val_size: 0.1
weight_decay: 0.0
y_weight: 0.0
I then used the following script to generate the force module. Since I am using periodic boundary conditions, I use the ForceModulePBC version:
import torch
import h5py
import numpy as np
from torchmdnet.models.model import load_model
from simtk.openmm import app
import simtk.openmm as mm
import mdtraj as md
from simtk import unit
from sys import stdout
import os
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('pdb_file', type=str, help='The pdb file for initializing atom types')
parser.add_argument('interpretation_method', type=str, help="Which package to use for interpreting the pdb file, either 'openmm' or 'mdtraj'. The pdb file is only used to set atom types for the model")
parser.add_argument('enable_box_vecs', type=str, help='Whether or not to enable box vectors as an input to the forward() method of the generated ForceModule()')
parser.add_argument('model_force_units', type=str, help='The units of force that the model uses internally')
parser.add_argument('model_energy_units', type=str, help='The units of energy that the model uses internally')

def determine_latest_epoch_model(file_lst):
    # Filter out the pytorch-generated torchscript modules
    good_files = [file for file in file_lst if 'generated_mod' not in file]
    # Pair the name with the epoch for sorting
    pair_lst = [(file, int(file.split("=")[1].split("-")[0])) for file in good_files]
    # Sort by the epoch number
    sorted_mods = sorted(pair_lst, key=lambda x: x[1])
    return sorted_mods[-1][0]

# Internally, openmm uses kJ/mol for energy and kJ/mol/nm for forces, so these
# dictionaries contain the relevant conversion factors.
# Multiplying by the values in the dictionary will take you to the correct units.
# Models are assumed to take coordinates in ANGSTROMS.
ENERGY_UNIT_CONVERSION = {
    'Ha' : 2625.5,
    'kJ_mol' : 1
}
FORCE_UNIT_CONVERSION = {
    'Ha_A' : (2625.5 / 0.1),
    'kJ_mol_A' : (1 / 0.1)
}

directories_to_process = [
    #'011824_H2O_cutoff_upper_5.0_y_weight_1.0_ema_alpha_y_1.0_neg_dy_weight_10.0_num_layers_0_num_heads_2_model_tensornet',
    #'011824_H2O_cutoff_upper_5.0_y_weight_1.0_ema_alpha_y_1.0_neg_dy_weight_10.0_num_layers_1_num_heads_2_model_equivariant-transformer',
    #'011824_H2O_cutoff_upper_5.0_y_weight_1.0_ema_alpha_y_1.0_neg_dy_weight_1.0_num_layers_0_num_heads_2_model_tensornet',
    #'011824_H2O_cutoff_upper_5.0_y_weight_1.0_ema_alpha_y_1.0_neg_dy_weight_1.0_num_layers_1_num_heads_2_model_equivariant-transformer',
    #'021624_H2O_cutoff_upper_5.0_y_weight_0.0_ema_alpha_y_0.0_neg_dy_weight_1.0_num_layers_0_num_heads_2_model_tensornet',
    #'021624_H2O_cutoff_upper_5.0_y_weight_0.5_ema_alpha_y_1.0_neg_dy_weight_0.5_num_layers_0_num_heads_2_model_tensornet',
    #'021624_H2O_cutoff_upper_5.0_y_weight_1.0_ema_alpha_y_1.0_neg_dy_weight_10.0_num_layers_0_num_heads_2_model_tensornet',
    #'021624_H2O_cutoff_upper_5.0_y_weight_1.0_ema_alpha_y_1.0_neg_dy_weight_1.0_num_layers_0_num_heads_2_model_tensornet',
    #'032024_H2O_cutoff_upper_5.0_y_weight_0.0_ema_alpha_y_0.0_neg_dy_weight_1.0_num_layers_0_reduce_op_mean_model_tensornet',
    #'032024_H2O_cutoff_upper_5.0_y_weight_0.0_ema_alpha_y_0.0_neg_dy_weight_1.0_num_layers_0_reduce_op_add_model_tensornet',
    #'032824_H2O_cutoff_upper_5.0_y_weight_0.0_ema_alpha_y_0.0_neg_dy_weight_1.0_num_layers_0_reduce_op_add_model_tensornet',
    #'040124_LiCl_cutoff_upper_5.0_y_weight_0.0_ema_alpha_y_0.0_neg_dy_weight_1.0_num_layers_0_reduce_op_add_model_tensornet',
    '040524_LiCl_cutoff_upper_5.0_num_layers_0_reduce_op_add_model_tensornet'
]

class ForceModule(torch.nn.Module):
    def __init__(self, atom_types, energy_factor, force_factor):
        super().__init__()
        self.z = torch.nn.Parameter(atom_types, requires_grad=False)
        self.model = model  # `model` is the globally loaded torchmd-net model
        # Store conversion factors used by the model
        self.energy_factor = energy_factor
        self.force_factor = force_factor

    def forward(self, positions):
        positions = positions.to(torch.float32)
        positions = positions * 10  # nm -> A
        energy, force = self.model.forward(self.z, positions)
        assert energy is not None
        assert force is not None
        force = force * self.force_factor     # model units -> kJ/mol/nm
        energy = energy * self.energy_factor  # model units -> kJ/mol
        return energy, force

class ForceModulePBC(torch.nn.Module):
    def __init__(self, atom_types, energy_factor, force_factor):
        super().__init__()
        self.z = torch.nn.Parameter(atom_types, requires_grad=False)
        self.model = model  # `model` is the globally loaded torchmd-net model
        # Store conversion factors used by the model
        self.energy_factor = energy_factor
        self.force_factor = force_factor

    def forward(self, positions, box_vectors):
        '''
        Takes in box vectors here for enforcing periodic boundary conditions;
        should be shape (3, 3). For orthorhombic boxes, this will be a
        diagonal matrix.
        '''
        positions = positions.to(torch.float32)
        positions = positions * 10  # nm -> A
        # Multiply box vectors by 10 to ensure unit consistency
        energy, force = self.model.forward(z=self.z,
                                           pos=positions,
                                           box=box_vectors * 10)
        assert energy is not None
        assert force is not None
        force = force * self.force_factor     # model units -> kJ/mol/nm
        energy = energy * self.energy_factor  # model units -> kJ/mol
        return energy, force

if __name__ == "__main__":
    args = parser.parse_args()
    if args.interpretation_method == 'openmm':
        pdb = app.PDBFile(args.pdb_file)
        atom_types = [atom.element.atomic_number for atom in pdb.topology.atoms()]
    elif args.interpretation_method == 'mdtraj':
        pdbtraj = md.load(args.pdb_file)
        atom_types = [atom.element.atomic_number for atom in pdbtraj.topology.atoms]
    FORCE_FACTOR = FORCE_UNIT_CONVERSION[args.model_force_units]
    ENERGY_FACTOR = ENERGY_UNIT_CONVERSION[args.model_energy_units]
    print(f"Using energy conversion {ENERGY_FACTOR}")
    print(f"Using force conversion {FORCE_FACTOR}")
    for direc in directories_to_process:
        all_model_files = os.listdir(f"models/{direc}")
        latest_mod = determine_latest_epoch_model(all_model_files)
        print(direc, latest_mod)
        print(atom_types)
        print(len(atom_types))
        model = load_model(f"models/{direc}/{latest_mod}", derivative=True)
        # This script generates TorchForce objects used by openmmtorch. If done
        # correctly, the trained force should function as a neural network
        # potential within the openmm framework. First wrap the loaded model in
        # a ForceModule class object as required by the openmmtorch interface.
        atom_types = torch.tensor(atom_types, dtype=torch.long)
        if args.enable_box_vecs.lower() == 'true':
            print("Using PBC-enabled ForceModule")
            module = torch.jit.script(ForceModulePBC(atom_types, ENERGY_FACTOR, FORCE_FACTOR))
        else:
            print("Using non-PBC ForceModule")
            module = torch.jit.script(ForceModule(atom_types, ENERGY_FACTOR, FORCE_FACTOR))
        module.save(f'models/{direc}/generated_mod.pt')
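Unit conversions like these are a common source of silent errors when coupling ML potentials to OpenMM, so it can be worth checking the factors numerically before running dynamics. A small sanity check of the values assumed in the script above (1 Ha ≈ 2625.5 kJ/mol, 1 Å = 0.1 nm):

```python
# Sanity check of the unit conversion factors used above.
# Assumed conventions: 1 Hartree ~= 2625.5 kJ/mol, 1 Angstrom = 0.1 nm,
# and OpenMM wants kJ/mol for energies and kJ/mol/nm for forces.
HARTREE_TO_KJ_MOL = 2625.5
ANGSTROM_TO_NM = 0.1

# Energies: model output in Ha -> kJ/mol
energy_factor = HARTREE_TO_KJ_MOL

# Forces: model output in Ha/A -> kJ/mol/nm (dividing by 0.1 nm/A
# is the same as multiplying by 10 A/nm)
force_factor = HARTREE_TO_KJ_MOL / ANGSTROM_TO_NM

assert abs(force_factor - 26255.0) < 1e-6
print(energy_factor, force_factor)  # 2625.5 26255.0
```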
The error
I ran the following command using this script:
$ python scripts/torch_force_generator_multi.py pdbs/licl_frame.pdb openmm true Ha_A Ha
The script fails on the call to torch.jit.script() with the following error:
Using PBC-enabled ForceModule
Traceback (most recent call last):
File "/pscratch/sd/f/frankhu/INXS_openmm_dynamics/scripts/torch_force_generator_multi.py", line 141, in <module>
module = torch.jit.script(ForceModulePBC(atom_types, ENERGY_FACTOR, FORCE_FACTOR))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_script.py", line 1324, in script
return torch.jit._recursive.create_script_module(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 559, in create_script_module
return create_script_module_impl(nn_module, concrete_type, stubs_fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_script.py", line 639, in _construct
init_fn(script_module)
File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 608, in init_fn
scripted = create_script_module_impl(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_script.py", line 639, in _construct
init_fn(script_module)
File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 608, in init_fn
scripted = create_script_module_impl(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_script.py", line 639, in _construct
init_fn(script_module)
File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 608, in init_fn
scripted = create_script_module_impl(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_script.py", line 639, in _construct
init_fn(script_module)
File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 608, in init_fn
scripted = create_script_module_impl(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 636, in create_script_module_impl
create_methods_and_properties_from_stubs(
File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/jit/_recursive.py", line 469, in create_methods_and_properties_from_stubs
concrete_type._create_methods_and_properties(
RuntimeError:
get_neighbor_pairs_kernel(str strategy, Tensor positions, Tensor batch, Tensor box_vectors, bool use_periodic, float cutoff_lower, float cutoff_upper, int max_num_pairs, bool loop, bool include_transpose) -> ((Tensor, Tensor, Tensor, Tensor)):
Expected a value of type 'float' for argument 'cutoff_lower' but instead found type 'int'.
:
File "/global/cfs/cdirs/m4026/torchmd-net/torchmdnet/models/utils.py", line 263
if batch is None:
batch = torch.zeros(pos.shape[0], dtype=torch.long, device=pos.device)
edge_index, edge_vec, edge_weight, num_pairs = get_neighbor_pairs_kernel(
~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
strategy=self.strategy,
positions=pos,
Other info
Because of the MLP change introduced in commit 6694816, I cannot load old models, since the state-dictionary keys no longer match. However, I did try downgrading the repository to commit 74702da, and the code above all worked (albeit with an older trained model).
As always, thank you so much for your time, and any help would be greatly appreciated!
Thanks for the thorough issue!
6694816 should not have broken old checkpoints, could you provide one so I can investigate?
For your TorchScript issue, try changing the 0 in torchmd-net/torchmdnet/priors/zbl.py (lines 53 to 55 at 6694816) to 0.0. I have seen similar things before with TorchScript.
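For context on why the literal matters: TorchScript assigns static types from Python literals, so 0 is typed int and 0.0 is typed float, and a value typed int will not match an operator schema declared as float (which is exactly what the get_neighbor_pairs_kernel error above reports). Plain Python shows the same type distinction:

```python
# 0 and 0.0 compare equal but carry different static types; the
# TorchScript compiler keys on the type, not the value.
assert 0 == 0.0
assert type(0) is int
assert type(0.0) is float

cutoff_lower_int = 0      # TorchScript would infer: int
cutoff_lower_float = 0.0  # TorchScript would infer: float
print(type(cutoff_lower_int).__name__, type(cutoff_lower_float).__name__)
# int float
```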
Thanks as always for the quick response @RaulPPelaez!
Making that change in the ZBL prior did indeed fix the issue, and I was able to generate the TorchScript module and run some dynamics with it. I will continue testing to see if I stumble on any other bugs.
As for the old checkpoint loading problem, I have attached to this issue a zip file with a checkpoint and the YAML file used to run the experiment. This model was trained without ZBL on one A100 GPU. I get the following error when I try to load the model:
>>> from torchmdnet.models.model import load_model
>>> model = load_model("epoch=999-val_loss=0.0000-test_loss=0.0010.ckpt", derivative=True)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/global/cfs/cdirs/m4026/torchmd-net/torchmdnet/models/model.py", line 243, in load_model
model.load_state_dict(state_dict)
File "/global/cfs/cdirs/m4026/torchmd-net/.conda/envs/torchmd-net/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2152, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for TorchMD_Net:
Missing key(s) in state_dict: "output_model.output_network.layers.0.weight", "output_model.output_network.layers.0.bias", "output_model.output_network.layers.2.weight", "output_model.output_network.layers.2.bias".
Unexpected key(s) in state_dict: "output_model.output_network.0.weight", "output_model.output_network.0.bias", "output_model.output_network.2.weight", "output_model.output_network.2.bias".
Because commit 6694816 reformatted some keys of the state dictionary, I thought this would be related to that change.
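In the meantime, a possible workaround is to rename the mismatched keys before calling load_state_dict; judging from the missing/unexpected keys in the error above, the new code expects a layers. segment inside output_network. A sketch (the helper and the rename pattern are inferred from the error message, not an official migration path):

```python
import re

def remap_output_keys(state_dict):
    """Rename old-style 'output_network.<i>.' keys to the
    'output_network.layers.<i>.' form the newer code expects."""
    pattern = re.compile(r"(output_model\.output_network\.)(\d+\.)")
    return {pattern.sub(r"\1layers.\2", key): value
            for key, value in state_dict.items()}

# Toy state dict with the old key layout from the error message
old_keys = {
    "output_model.output_network.0.weight": "w0",
    "output_model.output_network.2.bias": "b2",
    "representation_model.embedding.weight": "e",  # left untouched
}
for key in sorted(remap_output_keys(old_keys)):
    print(key)
# output_model.output_network.layers.0.weight
# output_model.output_network.layers.2.bias
# representation_model.embedding.weight
```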
Thanks for looking into it!
@FranklinHu1, I am able to load your model using #318.