Double Precision Issues
harrisonzhu508 opened this issue · 1 comment
Hi!
Many thanks for open-sourcing this package.
I've been using code from this amazing package for my research (preprint here) and have found that the default single precision (float32) is insufficient for the Kalman filtering and smoothing operations, causing numerical instabilities. In particular:
- the Periodic kernel is rather sensitive to the matrix operations in _sequential_kf() and _sequential_rts();
- the Matern32 kernel shows the same problem when its lengthscales are too large.
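To make the large-lengthscale case concrete, here is a small NumPy sketch of one plausible mechanism. It is not taken from the package's code. For a Matern-3/2 kernel in state-space form, the discrete process noise Q = Pinf - A Pinf A^T is a difference of nearly equal matrices when the lengthscale is much larger than the time step, so float32 cancellation can wipe it out. The function names `matern32_ssm` and `process_noise`, and the values (variance 1, lengthscale 100, step 0.01), are hypothetical choices for illustration.

```python
import numpy as np

# Illustrative sketch only: names and parameter values are hypothetical.

def matern32_ssm(variance, lengthscale, dt):
    """Transition matrix A and stationary covariance Pinf (float64) of a Matern-3/2 SDE."""
    lam = np.sqrt(3.0) / lengthscale
    # Stationary state covariance: diag(variance, lam^2 * variance)
    Pinf = np.diag([variance, lam**2 * variance])
    # Closed-form expm(F * dt) for F = [[0, 1], [-lam^2, -2 lam]]
    A = np.exp(-lam * dt) * np.array([[1.0 + lam * dt, dt],
                                      [-lam**2 * dt, 1.0 - lam * dt]])
    return A, Pinf

def process_noise(A, Pinf, dtype):
    """Discrete process noise Q = Pinf - A Pinf A^T, evaluated entirely in `dtype`."""
    A = A.astype(dtype)
    Pinf = Pinf.astype(dtype)
    return Pinf - A @ Pinf @ A.T

lengthscale, dt = 100.0, 0.01
A, Pinf = matern32_ssm(variance=1.0, lengthscale=lengthscale, dt=dt)
q32 = process_noise(A, Pinf, np.float32)[0, 0]
q64 = process_noise(A, Pinf, np.float64)[0, 0]

# Leading-order analytic value of Q[0, 0] for small x = lam * dt: (4/3) x^3
x = np.sqrt(3.0) / lengthscale * dt
q_series = 4.0 / 3.0 * x**3
print(q32, q64, q_series)
```

Here q64 agrees with the leading-order value (about 7e-12). By contrast, q32 can only be 0 or a multiple of roughly 6e-8, because subtracting two float32 numbers close to 1 cannot resolve anything smaller. A filter that relies on this Q can therefore become inaccurate or lose positive definiteness.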
However, switching to float64 by setting config.update("jax_enable_x64", True) makes everything quite slow, especially when I use objax neural network modules, because doing so puts all arrays into double precision.
Currently, my workaround is to set the neural network weights to float32 manually, cast input arrays to float32 before they enter the network, and cast the outputs back to float64. However, I was wondering whether there could be a more elegant solution, as in https://github.com/thomaspinder/GPJax, where all arrays are assumed to be float64. My understanding is that their package depends on Haiku, but I'm unsure how they get around the computational cost of running in double precision.
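The manual casting workaround can be sketched in plain JAX as follows. This is a minimal sketch that uses a hand-written MLP instead of objax modules, whose API is not reproduced here. `init_mlp` and `mlp_float32` are hypothetical names. float64 is enabled globally for the GP code. The network parameters and activations stay in float32, and only the network's output is cast back to float64.

```python
import jax
jax.config.update("jax_enable_x64", True)  # GP / Kalman code runs in float64

import jax.numpy as jnp

def init_mlp(key, sizes):
    """Hypothetical MLP parameters, stored explicitly as float32."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        # A Python float scale is weakly typed, so the weights stay float32
        w = jax.random.normal(sub, (n_in, n_out), dtype=jnp.float32) / (n_in ** 0.5)
        b = jnp.zeros((n_out,), dtype=jnp.float32)
        params.append((w, b))
    return params

def mlp_float32(params, x):
    """Run the network in float32, whatever the input dtype."""
    h = x.astype(jnp.float32)            # cast inputs down before the network
    for w, b in params[:-1]:
        h = jnp.tanh(h @ w + b)          # hidden layers stay in float32
    w, b = params[-1]
    out = h @ w + b
    return out.astype(jnp.float64)       # cast outputs back up for the GP code

key = jax.random.PRNGKey(0)
params = init_mlp(key, [2, 16, 1])
x = jnp.ones((5, 2))                     # float64 by default once x64 is enabled
y = mlp_float32(params, x)
print(x.dtype, params[0][0].dtype, y.dtype)
```

The casts are ordinary array operations, so the function can be wrapped in jax.jit or differentiated with jax.grad like any other JAX function. The network's matrix multiplications then run in float32, while the surrounding filtering code stays in float64.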
Software and hardware details:
- objax==1.6.0
- jax==0.3.13
- jaxlib==0.3.10+cuda11.cudnn805
- NVIDIA driver 460.56, CUDA 11.2 (from nvidia-smi)
- GeForce RTX 3090 GPUs
Thanks in advance.
Best,
Harrison
Hi Harrison,
That's great to hear that you've found the package useful, and thanks for sharing your paper - it looks great!
I agree that float64 is generally needed when working with GPs, whether you use the Markov formulation or the standard one. Unfortunately, I've never tried switching between float32 and float64, and I'm not aware of a more elegant solution to your problem. I'm also not aware of how, or whether, GPJax solves this issue; perhaps you could ask the authors of that package?
Sorry I couldn't be of more help.
Will