AaltoML/BayesNewton

GPU training

zhangn2015 opened this issue · 1 comment

Hi,
How can I speed up training on a GPU for models such as VariationalGP?

If you have Jax installed with CUDA support, models should run on the GPU automatically. You can check which backend Jax is using with:

from jax.lib import xla_bridge
# Prints "gpu" if Jax is using a CUDA device, otherwise e.g. "cpu"
print(xla_bridge.get_backend().platform)
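The same check can also be done through the top-level `jax` API. This is a minimal sketch assuming only `jax.default_backend()` and `jax.devices()`. It is not BayesNewton-specific and also works on recent Jax releases, where `jax.lib.xla_bridge` may be deprecated.

```python
import jax

# Name of the default backend: "gpu" when a CUDA-enabled jaxlib finds a GPU, "cpu" otherwise.
backend = jax.default_backend()
print("default backend:", backend)

# Every device Jax can see on that default backend.
devices = jax.devices()
for d in devices:
    print(d.platform, d)
```

If this prints `cpu` on a machine with a GPU, the installed jaxlib is most likely a CPU-only build.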

You can install Jax with CUDA support via pip. Below are the instructions for version 0.4.2, which BayesNewton currently uses (see the Jax v0.4.2 README, since its installation instructions differ somewhat from those for the current version):

 pip install -U "jax[cuda]==0.4.2" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html jaxlib==0.4.2
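After installing, you can confirm which versions pip actually resolved. This is a minimal sketch using only the standard library's `importlib.metadata`. `EXPECTED` and the `installed` helper are illustrative names. The `+cuda` check assumes that the CUDA wheels from the index above carry a local version label such as `0.4.2+cuda11.cudnn86`, as the 0.4.x CUDA wheels did. Newer Jax releases ship CUDA support differently, so the heuristic does not apply there.

```python
from importlib.metadata import version, PackageNotFoundError

EXPECTED = "0.4.2"  # version BayesNewton uses, per the thread above


def installed(pkg):
    """Return the installed version string of pkg, or None if it is not installed."""
    try:
        return version(pkg)
    except PackageNotFoundError:
        return None


jax_v = installed("jax")
jaxlib_v = installed("jaxlib")
print("jax:", jax_v, "jaxlib:", jaxlib_v)

# Assumption: old-style CUDA jaxlib wheels have a "+cuda..." local version label.
cuda_build = jaxlib_v is not None and "+cuda" in jaxlib_v
print("CUDA build of jaxlib:", cuda_build)

# Compare the public version (the part before "+") against the expected one.
if jax_v != EXPECTED or jaxlib_v is None or jaxlib_v.split("+")[0] != EXPECTED:
    print(f"Note: expected jax and jaxlib {EXPECTED}")
```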

You can double-check that a jax.numpy array is on the correct device by calling arr.device().
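A short sketch of checking and setting an array's placement is shown here. It assumes a Jax release whose arrays provide a `devices()` method, which returns the set of devices holding the array. On newer releases the method form `arr.device()` may be deprecated in favour of this. `jax.device_put` is used to place an array explicitly. Normally this is unnecessary because Jax already uses the GPU by default when one is available.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((1000, 1000))
# Set of devices holding x; on a working CUDA setup this should contain a GPU device.
print(x.devices())

# Explicitly place an array on the first device Jax sees (the GPU when one is available).
dev = jax.devices()[0]
y = jax.device_put(x, dev)
print(y.devices())
```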