Using TPUs
hsleiman1 opened this issue · 1 comment
hsleiman1 commented
Hello,
If we plan to use TPUs instead of GPUs, is it possible with the current configuration, or should we use a different one?
Thanks
beniz commented
Hi, my understanding from the PyTorch/Google TPU docs is that it requires importing XLA and creating a device. So I believe the device would be obtained with:
```python
# imports the torch_xla package
import torch_xla
import torch_xla.core.xla_model as xm
```
and
```python
device = xm.xla_device()
```
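For reference, a minimal end-to-end sketch of that pattern (assuming the standard `torch_xla` API and a toy model, not joliGEN's actual training code; `barrier=True` is what forces the lazily-traced XLA graph to execute):

```python
import torch
import torch_xla.core.xla_model as xm

# Acquire the TPU device through XLA
device = xm.xla_device()

# A toy model and batch, moved to the TPU like any other torch.device
model = torch.nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

inputs = torch.randn(8, 10, device=device)
loss = model(inputs).sum()
loss.backward()

# barrier=True flushes the lazily-built XLA graph so the
# optimizer step actually runs on the TPU
xm.optimizer_step(optimizer, barrier=True)
```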
Then change the device here: https://github.com/jolibrain/joliGEN/blob/master/models/base_model.py#L87
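Assuming that line follows the usual pattern of picking CUDA or CPU from `gpu_ids` (an assumption about the code at that link), the adapted selection could look roughly like this, with `use_tpu` as a hypothetical new config flag:

```python
import torch
import torch_xla.core.xla_model as xm

def select_device(use_tpu, gpu_ids):
    # use_tpu is an assumed new config flag; gpu_ids mirrors the
    # existing convention of a list of GPU indices
    if use_tpu:
        return xm.xla_device()
    if gpu_ids:
        return torch.device('cuda:{}'.format(gpu_ids[0]))
    return torch.device('cpu')
```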
It will also certainly be necessary to guard CUDA-specific calls behind the use_cuda config option in train.py and models/base_model.py.
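A sketch of that kind of guard (use_cuda already exists in the config; the specific calls shown here are illustrative examples, not joliGEN's actual ones):

```python
import torch

def maybe_cuda_housekeeping(use_cuda):
    # Keep CUDA-only operations out of the TPU code path;
    # the calls below are placeholders for whatever CUDA-specific
    # logic train.py / base_model.py actually perform
    if use_cuda and torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.backends.cudnn.benchmark = True
```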
We can look into it; it's a good feature to have!