Optimize `fit` and `predict` loops for PyTorch Estimators
f4str opened this issue · 0 comments
Is your feature request related to a problem? Please describe.
All PyTorch estimators currently slice NumPy arrays by index in the loops of their `fit` and `predict` methods. Some even move the entire dataset to the GPU before batching, which consumes an unnecessary amount of GPU VRAM. A more efficient approach is to use the PyTorch dataloader and move only the current batch to the GPU.
Describe the solution you'd like
Modify the `fit` and `predict` methods of all PyTorch estimators to use dataloaders and move only the current batch to the GPU. This will not only optimize GPU memory usage but also simplify the batching logic.
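A minimal sketch of the proposed pattern is below. This is a hypothetical illustration, not the actual estimator code: the model, shapes, and hyperparameters are made up, but the key idea is real — wrap the NumPy arrays in a `TensorDataset`, let `DataLoader` handle batching, and call `.to(device)` on each batch inside the loop so only one batch resides in GPU memory at a time.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Falls back to CPU so the sketch runs anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dummy data standing in for the arrays passed to fit/predict.
x = np.random.rand(64, 10).astype(np.float32)
y = np.random.randint(0, 2, size=64)

# The full dataset stays on CPU; DataLoader yields batches from it.
dataset = TensorDataset(torch.from_numpy(x), torch.from_numpy(y))
loader = DataLoader(dataset, batch_size=16, shuffle=True)

model = torch.nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

# fit-style loop: move only the current batch to the device.
model.train()
for x_batch, y_batch in loader:
    x_batch, y_batch = x_batch.to(device), y_batch.to(device)
    optimizer.zero_grad()
    loss = loss_fn(model(x_batch), y_batch)
    loss.backward()
    optimizer.step()

# predict-style loop: same per-batch transfer, results gathered on CPU.
model.eval()
preds = []
with torch.no_grad():
    for x_batch, _ in DataLoader(dataset, batch_size=16):
        preds.append(model(x_batch.to(device)).cpu().numpy())
preds = np.concatenate(preds)
```

Because the arrays never live on the device as a whole, peak GPU memory is bounded by the batch size rather than the dataset size.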
Describe alternatives you've considered
N/A
Additional context
The same should be done for all TensorFlowV2 estimators. Some estimators already use the TensorFlow dataset, but others still do NumPy index slicing. This can be done either in the same PR or in a separate one.
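For the TensorFlowV2 case, the analogous pattern replaces NumPy index slicing with `tf.data.Dataset`. The snippet below is a sketch with made-up shapes and batch size, not the estimators' actual code; it only shows the batching mechanism the issue refers to.

```python
import numpy as np
import tensorflow as tf

# Dummy data standing in for the arrays passed to fit/predict.
x = np.random.rand(64, 10).astype(np.float32)
y = np.random.randint(0, 2, size=64)

# Batching via tf.data instead of manual numpy index slicing.
dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(64).batch(16)

batch_shapes = []
for x_batch, y_batch in dataset:
    # Each iteration yields one batch as tensors; a training or
    # prediction step would consume x_batch/y_batch here.
    batch_shapes.append(tuple(x_batch.shape))
```

`tf.data` also makes it easy to add prefetching or shuffling without changing the loop body.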