butchland/fastai_xla_extensions

run batch transforms on CPU if dataloader is using a TPU

Closed this issue · 1 comment

Because batch transforms run slowly on the TPU, a proposed workaround is to run the batch transforms on the CPU and move the input to the TPU only after the batch transforms have been applied.

It should (most probably) be implemented as a callback in begin_fit that updates the dataloaders object so the to(self.device) call is postponed until after the after_batch dataloader callbacks have executed (e.g., trick the dataloader's self.device into using the CPU, then move the batch to the TPU afterwards); see the sketch below.
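A minimal sketch of what this could look like, assuming fastai v2 (with the older begin_* event names, as used above) and torch_xla. The callback name CPUBatchTransformsCallback is hypothetical, and moving the batch inside the Learner's begin_batch stands in for the "after_batch callback at the end" described above:

```python
# Sketch only -- CPUBatchTransformsCallback is a hypothetical name,
# not part of fastai or fastai_xla_extensions.
import torch
import torch_xla.core.xla_model as xm
from fastai.callback.core import Callback

class CPUBatchTransformsCallback(Callback):
    "Run batch transforms on the CPU, then move each batch to the TPU."
    def begin_fit(self):
        # Keep the dataloaders on the CPU so the after_batch transforms
        # (e.g. the slow affine augmentations) execute there.
        self.old_device = self.dls.device
        self.dls.device = torch.device('cpu')

    def begin_batch(self):
        # By this point the CPU-side batch transforms have already run;
        # move inputs and targets to the TPU before the model sees them.
        tpu = xm.xla_device()
        self.learn.xb = tuple(t.to(tpu) for t in self.xb)
        self.learn.yb = tuple(t.to(tpu) for t in self.yb)

    def after_fit(self):
        # Restore the original device setting on the dataloaders.
        self.dls.device = self.old_device
```

This works because fastai's DataLoader moves the batch to its device before applying the after_batch transforms, so pinning the dataloaders to the CPU keeps those transforms off the TPU. Usage would be something like `Learner(dls, model, cbs=CPUBatchTransformsCallback())`.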

We need to make this a quick fix, since it's only temporary until the PyTorch XLA team optimizes the affine calls, which should speed them up on TPUs.

Implemented in commit d0f0c33.