run batch transforms on CPU if dataloader is using a TPU
Closed this issue · 1 comments
butchland commented
Because batch transforms run slowly on the TPU, a proposed workaround is to run the batch transforms on the CPU and move the input to the TPU only after the batch transforms have been run.
It should (most probably) be implemented as a callback in `begin_fit` that updates the dataloaders object to postpone the `to(self.device)` call until after the `after_batch` dataloader callbacks have executed... (maybe trick the dataloader's `self.device` into using the cpu, and then append an `after_batch` callback at the end to move the batch to the tpu????)
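The intended flow (keep the dataloader's device on the cpu, run the batch transforms there, then move the result to the tpu) could be sketched roughly like this. This is a pure-Python illustration using stand-in classes; none of the names below (`FakeBatch`, `CPUTransformLoader`, `affine_transform`) are real fastai or torch_xla APIs:

```python
# Sketch of the proposed workaround: batch transforms on CPU,
# device transfer deferred until after the transforms have run.
# All classes here are hypothetical stand-ins, not fastai/torch_xla objects.

class FakeBatch:
    """Stand-in for a tensor batch that records which device it is on."""
    def __init__(self, data, device="cpu"):
        self.data = data
        self.device = device

    def to(self, device):
        # Mimics tensor.to(device): returns a copy on the target device.
        return FakeBatch(self.data, device)


def affine_transform(batch):
    """Stand-in for a batch transform that is cheap on CPU, slow on TPU."""
    return FakeBatch([x * 2 for x in batch.data], batch.device)


class CPUTransformLoader:
    """Keeps batches on the CPU while batch transforms run, then moves
    each transformed batch to the target device (the tpu)."""
    def __init__(self, batches, device="tpu"):
        self.batches = batches
        self.device = device  # target device, applied only after transforms

    def __iter__(self):
        for raw in self.batches:
            batch = FakeBatch(raw, device="cpu")  # stay on the cpu...
            batch = affine_transform(batch)       # ...run batch transforms
            yield batch.to(self.device)           # ...then move to the tpu


loader = CPUTransformLoader([[1, 2], [3, 4]])
out = list(loader)
print([b.data for b in out])    # [[2, 4], [6, 8]]
print({b.device for b in out})  # {'tpu'}
```

In the real implementation the `begin_fit` callback would do the equivalent of the `CPUTransformLoader` wrapping: set the dataloader's device to the cpu and append a final `after_batch` step that performs the deferred transfer.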
We need to make this a quick fix since it's only temporary until the pytorch xla team optimizes the affine calls, which should speed them up on the TPUs.