torchmd/torchmd-net

Double precision


TorchMD-Net does not really support any dtype other than float32.
For production runs this is probably fine, but I believe it would be useful to be able to pass a dtype argument to TorchMD-Net and run the full model in double precision for testing and development purposes,
for instance to check gradients with `torch.autograd.gradcheck`.
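A minimal sketch of the kind of check meant here, using a hypothetical toy energy function `toy_energy` (a harmonic pair potential, not TorchMD-Net code) in place of the real model. `gradcheck` compares analytical gradients against finite differences, and its default tolerances assume float64 inputs, which is why double precision matters for this use case.

```python
import torch
from torch.autograd import gradcheck


def toy_energy(pos: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a model's energy: harmonic pair potential."""
    n = pos.shape[0]
    # Indices of each unique pair (i < j), so no zero self-distances appear
    i, j = torch.triu_indices(n, n, offset=1)
    r = (pos[i] - pos[j]).norm(dim=-1)
    # Sum of (r - 1)^2 over all pairs; smooth as long as no two atoms coincide
    return ((r - 1.0) ** 2).sum()


torch.manual_seed(0)
# Positions in double precision; gradcheck needs requires_grad on the inputs
pos = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)
# Returns True on success and raises on a gradient mismatch by default
ok = gradcheck(toy_energy, (pos,))
print(ok)  # True
```

With a model that only runs in float32, the finite-difference estimates are too noisy for these tolerances, so this kind of test is only practical once the whole forward pass can run in float64.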

This amounts to adding a bunch of `dtype=dtype` arguments here and there.
What do you think?
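A rough sketch of what "adding `dtype=dtype` here and there" could look like, assuming a hypothetical toy module `ToyRepresentation` (not the actual TorchMD-Net architecture). The dtype is passed to every parameter-owning layer and to every constant buffer created in `__init__`.

```python
import torch
from torch import nn


class ToyRepresentation(nn.Module):
    """Hypothetical toy model (not TorchMD-Net) that takes a dtype argument."""

    def __init__(self, hidden: int = 16, num_rbf: int = 8, cutoff: float = 5.0,
                 dtype: torch.dtype = torch.float32):
        super().__init__()
        # Trainable parameters are allocated directly in the requested dtype
        self.embedding = nn.Embedding(10, hidden, dtype=dtype)
        self.linear = nn.Linear(num_rbf, hidden, dtype=dtype)
        # Constant buffers need the dtype as well, or they stay at the default
        self.register_buffer(
            "centers", torch.linspace(0.0, cutoff, num_rbf, dtype=dtype)
        )

    def forward(self, z: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        n = pos.shape[0]
        # Unique atom pairs i < j
        i, j = torch.triu_indices(n, n, offset=1, device=pos.device)
        r = (pos[i] - pos[j]).norm(dim=-1, keepdim=True)  # (pairs, 1)
        rbf = torch.exp(-((r - self.centers) ** 2))       # (pairs, num_rbf)
        msg = self.linear(rbf)                            # (pairs, hidden)
        h = self.embedding(z)                             # (atoms, hidden)
        # Scatter each pair message back onto both atoms of the pair
        h = h.index_add(0, i, msg).index_add(0, j, msg)
        return h.sum()


torch.manual_seed(0)
z = torch.randint(0, 10, (5,))
pos64 = torch.randn(5, 3, dtype=torch.float64)
model64 = ToyRepresentation(dtype=torch.float64)
print(model64(z, pos64).dtype)  # torch.float64
```

As an alternative, `model.double()` or `model.to(torch.float64)` casts existing floating-point parameters and buffers after construction. An explicit dtype argument mainly helps with tensors created without a dtype inside the model, which otherwise fall back to `torch.get_default_dtype()`.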

raimis commented

Yes, it will be useful for debugging.

Closing; opened #182 to work on this.