rasbt/LLMs-from-scratch

Edge case: Gradient accumulation

d-kleine opened this issue · 3 comments

About ch06/02_bonus_additional-experiments/additional-experiments.py, I was wondering about this experiment:
Row 13: python additional-experiments.py --no_padding --batch_size 1 --accumulation_steps 8

Since this run uses gradient accumulation with accumulation_steps=8, shouldn't the model weights in train_classifier_simple() be updated and the gradients reset via

if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
    optimizer.step()  # Update model weights using loss gradients
    optimizer.zero_grad()  # Reset loss gradients from previous batch iteration

to ensure that the optimizer steps are performed correctly after the specified number of accumulation steps?
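For context, here is a minimal sketch of what such an accumulation loop could look like with that condition. The names (model, train_loader, loss_fn, device) are placeholders for illustration, not the actual code in train_classifier_simple(), and dividing the loss by accumulation_steps is a common convention rather than something the script necessarily does:

def train_one_epoch(model, train_loader, optimizer, loss_fn, device, accumulation_steps=8):
    model.train()
    optimizer.zero_grad()  # start the epoch with clean gradients
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        loss = loss_fn(model(inputs), targets)
        # Dividing by accumulation_steps keeps the accumulated gradient on the
        # same scale as a single larger batch (a common convention)
        (loss / accumulation_steps).backward()
        # Step after every `accumulation_steps` batches, and also on the last
        # batch so leftover gradients are not dropped
        if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
            optimizer.step()
            optimizer.zero_grad()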

In contrast, using

if batch_idx % accumulation_steps == 0:
...

would incorrectly trigger an optimizer step on the very first batch (batch_idx == 0). It would also fail to handle the last batch correctly if the number of batches is not divisible by accumulation_steps. Or am I wrong?
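To make that concrete, a quick check with hypothetical numbers (10 batches, accumulation_steps=4, chosen purely for illustration) shows which batch indices would trigger an optimizer step under each condition:

num_batches = 10
accumulation_steps = 4

# Proposed condition: steps after batches 4, 8, and the final (10th) batch
steps_proposed = [i for i in range(num_batches)
                  if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches]
# -> [3, 7, 9]

# Original condition: fires on the very first batch (batch_idx == 0), and the
# gradient of the last batch is never applied within the epoch
steps_original = [i for i in range(num_batches) if i % accumulation_steps == 0]
# -> [0, 4, 8]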

That's a great observation. Yes, my implementation was a bit lazy there with respect to the first and last batch.

Alright, thanks! 🙂