Trusted-AI/adversarial-robustness-toolbox

Optimize `fit` and `predict` loops for PyTorch Estimators

f4str opened this issue · 0 comments

f4str commented

Is your feature request related to a problem? Please describe.
All PyTorch estimators currently index-slice numpy arrays in their `fit` and `predict` loops. Some even move the entire dataset to the GPU before batching, which consumes an unnecessary amount of GPU VRAM. A more efficient approach is to use a PyTorch `DataLoader` and move only the current batch to the GPU.

Describe the solution you'd like
Modify the `fit` and `predict` methods of all PyTorch estimators to use `DataLoader`s and move only the current batch to the GPU.

This will not only optimize GPU memory usage but also standardize the batching logic across estimators.
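A minimal sketch of the proposed pattern, not ART's actual estimator code: a standalone `fit`/`predict` pair where the full arrays stay on CPU inside a `TensorDataset` and only each batch is transferred. The model, loss, and optimizer here are illustrative placeholders, and in a real estimator they would be instance attributes.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative model/device; an ART estimator would hold these as attributes.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(4, 3).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

def fit(x: np.ndarray, y: np.ndarray, batch_size: int = 2, nb_epochs: int = 1) -> None:
    # The full arrays stay on CPU; the DataLoader yields CPU tensors.
    dataset = TensorDataset(torch.from_numpy(x).float(), torch.from_numpy(y).long())
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model.train()
    for _ in range(nb_epochs):
        for x_batch, y_batch in loader:
            # Only the current batch is moved to the GPU.
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(x_batch), y_batch)
            loss.backward()
            optimizer.step()

def predict(x: np.ndarray, batch_size: int = 2) -> np.ndarray:
    dataset = TensorDataset(torch.from_numpy(x).float())
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    model.eval()
    outputs = []
    with torch.no_grad():
        for (x_batch,) in loader:
            # Move the batch in, compute, and bring results back to CPU.
            outputs.append(model(x_batch.to(device)).cpu().numpy())
    return np.concatenate(outputs, axis=0)
```

Peak GPU memory is then bounded by one batch plus the model, rather than the whole dataset.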

Describe alternatives you've considered
N/A

Additional context
The same should be done for all TensorFlowV2 estimators. Some already use `tf.data.Dataset`, but others still rely on numpy index slicing. This can be done either in the same PR or in a separate one.
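The equivalent TensorFlow v2 pattern, again as a hedged standalone sketch rather than ART's actual estimator code (model, loss, and optimizer are placeholders): `tf.data.Dataset` streams shuffled batches so no numpy index slicing is needed.

```python
import numpy as np
import tensorflow as tf

# Illustrative model; an ART TensorFlowV2 estimator would hold its own model.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(3)])
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

def fit(x: np.ndarray, y: np.ndarray, batch_size: int = 2, nb_epochs: int = 1) -> None:
    # tf.data handles shuffling and batching instead of manual index slicing.
    dataset = tf.data.Dataset.from_tensor_slices((x, y)).shuffle(len(x)).batch(batch_size)
    for _ in range(nb_epochs):
        for x_batch, y_batch in dataset:
            with tf.GradientTape() as tape:
                loss = loss_fn(y_batch, model(x_batch, training=True))
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

def predict(x: np.ndarray, batch_size: int = 2) -> np.ndarray:
    dataset = tf.data.Dataset.from_tensor_slices(x).batch(batch_size)
    return np.concatenate(
        [model(x_batch, training=False).numpy() for x_batch in dataset], axis=0
    )
```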