AaltoML/BayesNewton

Double Precision Issues

harrisonzhu508 opened this issue · 1 comment

Hi!

Many thanks for open-sourcing this package.

I've been using code from this amazing package for my research (preprint here) and have found that the default single-precision (float32) arithmetic is insufficient for the Kalman filtering and smoothing operations, causing numerical instabilities. In particular:

  • the Periodic kernel is particularly sensitive to the matrix operations in `_sequential_kf()` and `_sequential_rts()`;
  • the Matern32 kernel shows the same behaviour when the lengthscale is large (a toy illustration follows below).
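
To make the instability concrete, here is a small standalone toy example (plain NumPy, not BayesNewton code; the `matern32` helper and the 1e-8 jitter are just illustrative choices) showing how a Matern-3/2 Gram matrix with a large lengthscale can pass a Cholesky factorisation in float64 but fail in float32:

```python
import numpy as np

def matern32(x1, x2, lengthscale=100.0, variance=1.0):
    """Toy Matern-3/2 kernel matrix (illustrative, not the BayesNewton implementation)."""
    r = np.abs(x1[:, None] - x2[None, :]) / lengthscale
    return variance * (1.0 + np.sqrt(3.0) * r) * np.exp(-np.sqrt(3.0) * r)

x = np.linspace(0.0, 1.0, 200)
K = matern32(x, x) + 1e-8 * np.eye(len(x))  # tiny jitter on the diagonal

np.linalg.cholesky(K.astype(np.float64))  # succeeds: float64 still resolves the small eigenvalues
try:
    np.linalg.cholesky(K.astype(np.float32))  # may fail: float32 rounding error swamps the jitter
except np.linalg.LinAlgError as err:
    print("float32 Cholesky failed:", err)
```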

However, switching to float64 by setting `config.update("jax_enable_x64", True)` makes everything quite slow, especially when I use objax neural network modules, because doing so puts all arrays into double precision.
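
For reference, this is the global switch I mean; as far as I understand it has to be set at startup, before any JAX arrays are created, which is why it is effectively all-or-nothing:

```python
# Enable double precision globally; set this before creating any JAX arrays.
from jax import config
config.update("jax_enable_x64", True)

import jax.numpy as jnp
print(jnp.zeros(3).dtype)  # float64 once x64 is enabled, float32 otherwise
```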

Currently, my solution is to cast the neural network weights to float32 manually, convert the input arrays to float32 before they enter the network, and cast the outputs back to float64 afterwards. However, I was wondering whether there could be a more elegant solution like the one in https://github.com/thomaspinder/GPJax, where all arrays are assumed to be float64. My understanding is that their package depends on Haiku, but I'm unsure how they got around the computational scalability issue.
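
Concretely, the wrapper I use looks roughly like the sketch below (simplified; `network` stands in for my objax module, whose weights I cast to float32 beforehand):

```python
import jax.numpy as jnp

def call_network_float32(network, x):
    """Run a float32 network inside an otherwise float64 pipeline.

    `network` is assumed to be a callable (e.g. an objax module) whose
    weights have already been cast to float32.
    """
    x32 = x.astype(jnp.float32)        # downcast inputs before the network
    out32 = network(x32)               # all neural network compute stays in float32
    return out32.astype(jnp.float64)   # upcast outputs for the Kalman filter/smoother
```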

Software and hardware details:

objax==1.6.0
jax==0.3.13
jaxlib==0.3.10+cuda11.cudnn805
NVIDIA-SMI 460.56       Driver Version: 460.56       CUDA Version: 11.2
GeForce RTX 3090 GPUs

Thanks in advance.

Best,
Harrison

Hi Harrison,

It's great to hear that you've found the package useful, and thanks for sharing your paper - it looks great!

I agree that float64 is generally needed when working with GPs, whether that's using the Markov formulation or the standard one. Unfortunately I've never tried switching between float32 and float64, and I'm not aware of a more elegant solution to your problem. I'm also not aware of how (or whether) GPJax solves this issue - perhaps you could ask the authors of that package?

Sorry I couldn't be of more help.

Will