"Self and source expected to have the same dtype" with compiled model in LAMMPS
I trained an Allegro model with nequip-train and compiled it with nequip-deploy.
The model was trained with the dtypes:
default_dtype: float64
model_dtype: float32
allow_tf32: true
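For context, a common reading of these two settings (an assumption here, not checked against the nequip source) is that default_dtype sets the precision of the data going into and out of the model, while model_dtype sets the precision of the weights and internal computation. The minimal PyTorch sketch below shows that float64-in / float32-compute / float64-out boundary. MixedPrecisionWrapper is a hypothetical name, and the toy torch.nn.Linear stands in for the real network.

```python
import torch


class MixedPrecisionWrapper(torch.nn.Module):
    """Hypothetical wrapper: float64 at the edges, float32 inside."""

    def __init__(self, inner: torch.nn.Module):
        super().__init__()
        # Weights are stored in float32, mirroring model_dtype: float32.
        self.inner = inner.to(torch.float32)

    def forward(self, pos: torch.Tensor) -> torch.Tensor:
        # Inputs arrive in float64 (default_dtype); cast down for the network.
        out32 = self.inner(pos.to(torch.float32))
        # Cast results back up so downstream code sees float64 again.
        return out32.to(torch.float64)


model = MixedPrecisionWrapper(torch.nn.Linear(3, 1))
pos = torch.randn(5, 3, dtype=torch.float64)
energy = model(pos)
print(energy.dtype)  # torch.float64
```

Any tensor that crosses this boundary without a cast, in either direction, is a candidate for the kind of dtype mismatch reported in this issue.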
I am training and running inference on an RTX 4090.
LAMMPS is compiled with Kokkos support and patched with the patch_lammps.sh script from pair_allegro.
Everything works fine until I hit the following error:
Exception: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
File "code/__torch__/nequip/nn/_grad_output.py", line 32, in forward
_7 = torch.append(_3, data[k])
func0 = self.func
data0 = (func0).forward(data, )
~~~~~~~~~~~~~~ <--- HERE
of = self.of
_8 = [torch.sum(data0[of])]
File "code/__torch__/nequip/nn/_graph_mixin.py", line 28, in AD_sum_backward
input1 = (radial_basis).forward(input0, )
input2 = (spharm).forward(input1, )
input3 = (allegro).forward(input2, )
~~~~~~~~~~~~~~~~ <--- HERE
input4 = (edge_eng).forward(input3, )
input5 = (edge_eng_sum).forward(input4, )
File "code/__torch__/allegro/nn/_allegro.py", line 121, in AD_logsumexp_backward
_28 = torch.unsqueeze(torch.index(cutoff_coeffs, _27), -1)
new_latents0 = torch.mul(_28, new_latents)
latents1 = torch.index_copy(latents, 0, active_edges0, new_latents0)
~~~~~~~~~~~~~~~~ <--- HERE
_29 = annotate(List[Optional[Tensor]], [active_edges0])
weights = (_01).forward(torch.index(latents1, _29), )
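The failing call is torch.index_copy, which requires the destination (self) and the source tensor to have the same dtype. The sketch below reproduces the error in isolation. The variable names mirror the traceback, but which side is float64 and which is float32 is an assumption made for illustration. It then shows one way to resolve the mismatch by casting the source to the destination's dtype.

```python
import torch

# Destination buffer in float64 (e.g. allocated with the default dtype) and
# an update computed in float32 (e.g. by float32 model layers). Assumed dtypes.
latents = torch.zeros(4, 2, dtype=torch.float64)
active_edges = torch.tensor([0, 2])  # int64 row indices into latents
new_latents = torch.ones(2, 2, dtype=torch.float32)

try:
    torch.index_copy(latents, 0, active_edges, new_latents)
    mismatch_raised = False
except RuntimeError as err:
    # PyTorch refuses to mix dtypes here instead of promoting silently.
    mismatch_raised = True
    print(err)

# One fix: cast the source to the destination's dtype
# (alternatively, allocate the destination in the source's dtype).
fixed = torch.index_copy(latents, 0, active_edges, new_latents.to(latents.dtype))
print(fixed.dtype)  # torch.float64
```

Casting locally like this only moves the mismatch if other code still expects the original dtype downstream, so each cast has to agree with whatever consumes the tensor next.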
I have tried casting the variables in question in _allegro.py and _fc.py to double or float where needed, but eventually I hit a wall where the error changes to:
Exception: expected scalar type Double but found Float
Exception raised from check_type at aten/src/ATen/core/TensorMethods.cpp:12 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xb0 (0x71d3caf9d7f0 in /mnt/ssd2/SpiceAllegro/libtorch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xfa (0x71d3caf45f7e in /mnt/ssd2/SpiceAllegro/libtorch/lib/libc10.so)
frame #2: <unknown function> + 0x392f399 (0x71d3b552f399 in /mnt/ssd2/SpiceAllegro/libtorch/lib/libtorch_cpu.so)
frame #3: double* at::TensorBase::mutable_data_ptr<double>() const + 0x43 (0x71d3b5530653 in /mnt/ssd2/SpiceAllegro/libtorch/lib/libtorch_cpu.so)
frame #4: at::TensorAccessor<double, 2ul, at::DefaultPtrTraits, long> at::TensorBase::accessor<double, 2ul>() const & + 0x4b (0x598e8f7251eb in ../mylammps/build/lmp)
frame #5: <unknown function> + 0xa19c33 (0x598e8f998c33 in ../mylammps/build/lmp)
frame #6: <unknown function> + 0x1c24b4 (0x598e8f1414b4 in ../mylammps/build/lmp)
frame #7: <unknown function> + 0x1c4103 (0x598e8f143103 in ../mylammps/build/lmp)
frame #8: <unknown function> + 0x17276f (0x598e8f0f176f in ../mylammps/build/lmp)
frame #9: <unknown function> + 0x172ace (0x598e8f0f1ace in ../mylammps/build/lmp)
frame #10: <unknown function> + 0x117041 (0x598e8f096041 in ../mylammps/build/lmp)
frame #11: <unknown function> + 0x29d90 (0x71d353a29d90 in /lib/x86_64-linux-gnu/libc.so.6)
frame #12: __libc_start_main + 0x80 (0x71d353a29e40 in /lib/x86_64-linux-gnu/libc.so.6)
frame #13: <unknown function> + 0x164975 (0x598e8f0e3975 in ../mylammps/build/lmp)
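Frames #3 and #4 show code inside the lmp binary requesting a double accessor on a tensor that is actually float32, so after the manual casts at least one tensor that LAMMPS reads as double is now float32. A small debugging sketch for finding such tensors is shown below, assuming the model outputs can be inspected as a dict of tensors in Python. find_dtype_mismatches is a hypothetical helper, and the output keys are illustrative placeholders.

```python
from typing import Dict, List

import torch


def find_dtype_mismatches(
    outputs: Dict[str, torch.Tensor], expected: torch.dtype = torch.float64
) -> List[str]:
    """Return the (sorted) keys whose tensors do not have the expected dtype."""
    return sorted(k for k, t in outputs.items() if t.dtype != expected)


# Hypothetical outputs: forces were left in float32 after manual casting.
outputs = {
    "total_energy": torch.zeros(1, 1, dtype=torch.float64),
    "forces": torch.zeros(5, 3, dtype=torch.float32),
    "atomic_energy": torch.zeros(5, 1, dtype=torch.float64),
}
bad = find_dtype_mismatches(outputs)
print(bad)  # ['forces']

# Casting at the output boundary restores what a double accessor expects.
outputs = {k: t.to(torch.float64) for k, t in outputs.items()}
print(find_dtype_mismatches(outputs))  # []
```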
Hi @OlfwayAdbayIgbay,
Can you confirm which Allegro pair_style you are using in your LAMMPS script, which branch of pair_allegro you are building, and whether you pulled its latest version?
Thanks.
Hi @Linux-cpp-lisp,
Thanks for the response.
I pulled the most recent version of the pair_allegro main branch.
The pair_style looks like this:
pair_style allegro
pair_coeff * * [my path]/deployed_model_latest_run.pth C H N O S
mass 1 1.0
mass 2 1.0
mass 3 1.0
mass 4 1.0
mass 5 1.0
Before that, in case it is relevant, I initialize the system like this:
units metal
atom_style atomic
newton on
thermo 1
# get a box defined before pair_coeff
boundary s s s
read_data structure.data