Edge case: Gradient accumulation
d-kleine opened this issue · 3 comments
About `ch06/02_bonus_additional-experiments/additional-experiments.py`, I was wondering about this experiment:

Row 13: `python additional-experiments.py --no_padding --batch_size 1 --accumulation_steps 8`
Since this run uses gradient accumulation with `accumulation_steps=8`, shouldn't the model weights in `train_classifier_simple()` be updated and the gradients reset via

```python
if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
    optimizer.step()       # Update model weights using loss gradients
    optimizer.zero_grad()  # Reset loss gradients from previous batch iteration
```

to ensure that the optimizer steps are performed correctly after the specified number of accumulation steps?
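For concreteness, here is a minimal sketch of the accumulation loop I have in mind. This is standalone PyTorch-style code, not the code from `additional-experiments.py`; `train_one_epoch` and the data-loading details are made up for illustration, and whether the loss should also be divided by `accumulation_steps` is a separate convention question:

```python
import torch

def train_one_epoch(model, train_loader, optimizer, device, accumulation_steps=8):
    # Hypothetical helper, not the repo's train_classifier_simple()
    model.train()
    optimizer.zero_grad()  # start the epoch with clean gradients

    for batch_idx, (input_batch, target_batch) in enumerate(train_loader):
        input_batch = input_batch.to(device)
        target_batch = target_batch.to(device)

        logits = model(input_batch)
        loss = torch.nn.functional.cross_entropy(logits, target_batch)

        # Scale the loss so the accumulated gradient approximates one large-batch update
        (loss / accumulation_steps).backward()

        # Step after every `accumulation_steps` micro-batches, and also on the last
        # batch so leftover gradients at the end of the epoch are not silently dropped
        if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
            optimizer.step()
            optimizer.zero_grad()
```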
In contrast, using

```python
if batch_idx % accumulation_steps == 0:
    ...
```

would incorrectly trigger an optimizer step on the very first batch (`batch_idx == 0`). It would also fail to handle the last batch correctly if the number of batches is not divisible by `accumulation_steps`. Or am I wrong?
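As a quick numeric check of the two conditions (hypothetical numbers: 10 batches per epoch, `accumulation_steps=8`):

```python
num_batches = 10          # hypothetical epoch length
accumulation_steps = 8

steps_old = [i for i in range(num_batches) if i % accumulation_steps == 0]
steps_new = [i for i in range(num_batches)
             if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches]

print(steps_old)  # [0, 8] -> steps on the very first batch; the last batch's gradients are never applied
print(steps_new)  # [7, 9] -> steps after 8 accumulated batches, plus a final step on the last batch
```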
That's a great observation. Yes, my implementation was a bit lazy there with respect to the first and last batch.
Alright, thanks! 🙂