GPU training
zhangn2015 commented
Hi,
How can I speed up training of models such as VariationalGP by running them on a GPU?
DanWaxman commented
If you have JAX with CUDA support installed, computations should run on the GPU automatically. You can check which backend JAX is using with:
```python
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
```
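In recent JAX releases the `jax.lib.xla_bridge` module is deprecated, so that check may warn or fail there. The sketch below does the same check with the public `jax.default_backend()` and `jax.devices()` functions, which should also be available in 0.4.2. It prints the default backend and every visible device. With a working CUDA install, at least one entry should report the `gpu` platform.

```python
import jax

# Name of the backend JAX dispatches to by default, e.g. "cpu", "gpu" or "tpu".
backend = jax.default_backend()
print("default backend:", backend)

# Every device JAX can see; a CUDA install should list one or more GPU devices.
for d in jax.devices():
    print(d.platform, d)
```

If this prints `cpu` on a machine with a GPU, JAX is most likely using a CPU-only `jaxlib` build.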
You can install JAX with CUDA support via pip. Here are the instructions for version 0.4.2, which BayesNewton currently uses (see e.g. the JAX v0.4.2 README, as its instructions differ a bit from those for the current version):
```shell
pip install -U "jax[cuda]==0.4.2" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html jaxlib==0.4.2
```
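After installing, you can use the sketch below to confirm that the pinned versions actually ended up in the environment. It only reads package metadata through the standard library's `importlib.metadata`. The helper name `installed` is hypothetical. On GPU builds the `jaxlib` version may carry a local suffix naming the CUDA/cuDNN build.

```python
from importlib.metadata import PackageNotFoundError, version


def installed(pkg: str) -> str:
    """Return the installed version string of `pkg`, or "not installed"."""
    try:
        return version(pkg)
    except PackageNotFoundError:
        return "not installed"


jax_version = installed("jax")
jaxlib_version = installed("jaxlib")
# Both should start with 0.4.2 if the pinned install above succeeded.
print("jax:", jax_version)
print("jaxlib:", jaxlib_version)
```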
You can double-check that a `jax.numpy` array is on the correct device by calling `arr.device()`.
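Here is a short sketch of that check, assuming a JAX release with the `jax.Array` type (0.4.1 or later). In recent releases `device` is a property rather than a method. For that reason this sketch uses `arr.devices()`, which returns the set of devices holding the array. It also uses `jax.device_put` to pin an array to a chosen device explicitly.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((3, 3))

# Set of devices holding x; on a working CUDA install this should be a GPU device.
print(x.devices())

# Explicit placement: put the array on the first device JAX reports.
first = jax.devices()[0]
y = jax.device_put(x, first)
print(y.devices())
```

By default, new arrays are placed on the first device of the default backend. On a GPU machine you normally don't need `device_put` at all, but it is useful for confirming where data lives.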