jolibrain/joliGEN

Using TPUs

hsleiman1 opened this issue · 1 comment

Hello,

If we plan to use TPUs instead of GPUs, is that possible with the current configuration, or do we need a different one?

Thanks

beniz commented

Hi, my understanding from the PyTorch/Google TPU docs is that it requires importing XLA and creating an XLA device. So I believe the device would be created along these lines:

# imports the torch_xla package
import torch_xla
import torch_xla.core.xla_model as xm

and

device = xm.xla_device()
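To avoid hard-coding the TPU device, the choice could be driven by the config. Below is a minimal sketch, assuming a hypothetical `use_tpu` option next to the existing `use_cuda` config; `xla_available`, `choose_backend` and `make_device` are illustrative helpers, not part of joliGEN. torch and torch_xla are imported lazily, so CPU or GPU installs do not need torch_xla.

```python
import importlib.util


def xla_available() -> bool:
    """Return True if torch_xla is installed (checks without importing it)."""
    return importlib.util.find_spec("torch_xla") is not None


def choose_backend(use_tpu: bool, use_cuda: bool, has_xla: bool, has_cuda: bool) -> str:
    """Pick 'xla', 'cuda' or 'cpu'.

    use_tpu is a hypothetical config flag (assumption, not an existing joliGEN option).
    """
    if use_tpu:
        if not has_xla:
            raise RuntimeError("use_tpu was requested but torch_xla is not installed")
        return "xla"
    if use_cuda and has_cuda:
        return "cuda"
    return "cpu"


def make_device(use_tpu: bool, use_cuda: bool):
    """Build the torch device for the chosen backend (hypothetical helper)."""
    import torch  # imported lazily so choose_backend stays usable without torch

    backend = choose_backend(use_tpu, use_cuda, xla_available(), torch.cuda.is_available())
    if backend == "xla":
        import torch_xla.core.xla_model as xm  # only needed on TPU hosts
        return xm.xla_device()
    return torch.device(backend)
```

Keeping the decision in a pure function like `choose_backend` makes it testable on machines that have neither a GPU nor a TPU.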

Then change the device here: https://github.com/jolibrain/joliGEN/blob/master/models/base_model.py#L87
It will also certainly be necessary to guard some of the calls made under the use_cuda config in train.py and models/base_model.py, since they are CUDA-specific.
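A rough sketch of that kind of guarding follows, assuming the backend is tracked as a plain string ('cuda', 'xla' or 'cpu'). `run_if_cuda` and `optimizer_step` are hypothetical helpers, not existing joliGEN functions. On XLA the optimizer step is typically routed through `xm.optimizer_step`; with `barrier=True` it also issues the step barrier that executes the lazily recorded graph when no XLA parallel data loader is in use.

```python
from typing import Callable


def run_if_cuda(backend: str, fn: Callable[[], None]) -> bool:
    """Run a CUDA-only call only on the CUDA backend (hypothetical helper).

    Returns True if the call was executed, False if it was skipped.
    """
    if backend == "cuda":
        fn()
        return True
    return False


def optimizer_step(optimizer, backend: str) -> None:
    """Step the optimizer in a backend-aware way (hypothetical helper)."""
    if backend == "xla":
        import torch_xla.core.xla_model as xm  # only imported on TPU hosts
        # barrier=True also marks the step so the pending XLA graph is executed
        xm.optimizer_step(optimizer, barrier=True)
    else:
        optimizer.step()
```

For example, `run_if_cuda(backend, torch.cuda.empty_cache)` would skip the cache flush on a TPU host.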

We can look into it; it's a good feature to have!