butchland/fastai_xla_extensions

batch transforms for vision are slow


Confirming that batch transforms are slow.
Same notebook without batch tfms: each epoch runs in 1:34 to 2:25 mins.
For exact same notebook with batch tfms

Focusing on the affine transforms (zoom, warp, rotate) plus the GPU random resized crop: they seem to cause much of the slowdown.
The Normalize and lighting (contrast and brightness) transforms don't seem to slow it down.

Will start narrowing down where the slowdown is and profile the specific tensor operations where it's slow.
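One way to do that narrowing is to time each transform on its own over the same batch and rank the results. The snippet below is only a minimal sketch of that idea using the standard library: `time_steps`, `rank_slowest`, `cheap_step` and `costly_step` are hypothetical names, and the two step functions are plain-Python stand-ins, not the actual fastai batch transforms.

```python
import time
from statistics import median


def time_steps(steps, batch, repeats=5):
    """Apply each callable in `steps` to `batch` and return {name: median seconds}."""
    timings = {}
    for name, fn in steps.items():
        samples = []
        for _ in range(repeats):
            start = time.perf_counter()
            fn(batch)
            samples.append(time.perf_counter() - start)
        timings[name] = median(samples)
    return timings


def rank_slowest(timings):
    """Return (name, seconds) pairs sorted from slowest to fastest."""
    return sorted(timings.items(), key=lambda kv: kv[1], reverse=True)


# Hypothetical stand-ins for batch transforms (assumption: not real fastai code).
def cheap_step(batch):
    return [x * 0.5 for x in batch]


def costly_step(batch):
    out = batch
    for _ in range(50):  # roughly 50x the work of cheap_step
        out = [x * 1.0001 + 0.1 for x in out]
    return out


batch = [float(i) for i in range(10_000)]
timings = time_steps({"cheap": cheap_step, "costly": costly_step}, batch)
ranking = rank_slowest(timings)
for name, secs in ranking:
    print(f"{name}: {secs * 1e3:.2f} ms")
```

In the real notebook the callables would be the individual batch transforms applied to a batch already on the XLA device. Since XLA executes lazily, each timed call would also need to force execution (for example with `xm.mark_step()` from `torch_xla.core.xla_model`), otherwise the measurement only covers graph tracing.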

Will monitor this issue filed with the pytorch-xla team, as resolving it requires an update to pytorch-xla itself.

Update: as of 2020/12/14, using the updated PyTorch 1.7 XLA and the latest fastai (2.1.8) and fastai_xla_extensions (0.0.4) packages, training with batch transforms is still slower than training without them.

Partially mitigated by this enhancement: #11

Leaving this open while awaiting final action on additional lowerings in PyTorch XLA to support batch transforms.