zero_grad for accumulation_steps = 1 not working as expected
abdur-n-tr opened this issue · 9 comments
As far as I know, in the normal execution flow we call zero_grad for each batch and then run the forward pass. But when I investigated the code, that is not what happens when accumulation_steps = 1 and batch_index = 1: the forward pass executes first, without zero_grad being called.
I tried to reproduce it, and it behaves exactly as described above.
I think we can fix this by removing the condition in tez.py on lines # 330, 331.
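For reference, what I expect for accumulation_steps = 1 is essentially the standard PyTorch loop below (a minimal sketch with a toy model and made-up data, not the actual tez training loop), where the gradients are cleared once per batch:

```python
import torch

# Minimal sketch of the expected flow for accumulation_steps = 1
# (toy model/optimizer/data for illustration, not tez internals).
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]

for batch_index, (x, y) in enumerate(data):
    optimizer.zero_grad()                             # clear grads for every batch
    loss = torch.nn.functional.mse_loss(model(x), y)  # forward pass
    loss.backward()                                   # grads come only from this batch
    optimizer.step()
```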
it should be fixed in main branch now. could you please confirm?
This issue still exists in the main branch; I did not see any fix for it.
hmm... then i might be missing something. could you please share more information/code? the zero_grad calls look fine to me. or did i miss something?
I just ran the latest tez code and the same issue is still happening.
Here you can see that when batch_index = 0, zero_grad runs first and then the forward pass, which is fine. But when batch_index = 1, this condition does not run, so the forward pass executes without zero_grad being called first, as in the snapshot below.
So, one solution is to either remove the if condition on zero_grad in _step() on line # 336, OR remove the self.batch_index == 0 check on line # 299, OR come up with some other fix that you know better.
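To make that concrete, here is a rough paraphrase of the guard as I understand it (the exact conditions in tez.py may differ; this is only to illustrate the idea):

```python
def should_zero_grad(batch_index: int, accumulation_steps: int) -> bool:
    # Suggested guard: reset at every accumulation boundary, i.e. for every
    # batch when accumulation_steps == 1 (rough paraphrase, not tez.py).
    return batch_index % accumulation_steps == 0

# With accumulation_steps = 1 this is True for every batch:
print([should_zero_grad(i, 1) for i in range(4)])    # [True, True, True, True]

# The current guard (as I read it) also requires batch_index == 0,
# so it is only True for the very first batch:
print([(i % 1 == 0 and i == 0) for i in range(4)])   # [True, False, False, False]
```

Dropping the batch_index == 0 part means the gradients get reset at every accumulation boundary, which for accumulation_steps = 1 is every single batch.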
Hope it helps.
it seems like one of us is confused. when batch_index = 1 or any value greater than zero, zero_grad is happening here:
Line 336 in fd2d85e
when batch_index = 0, the _step function does the zero_grad. this zero_grad happens before the very first forward pass.
Also, I refreshed my memory about zero_grad: it actually has nothing to do with the forward pass; it must happen before the backward pass (apologies for the earlier mistyping). So I logged the backward pass as well, and the same issue still shows up, as you can see below.
zero_grad before backward pass:
https://discuss.pytorch.org/t/should-we-perform-optimizer-zero-grad-before-we-do-forward-pass-or-after-that/14254
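A tiny self-contained demo of that point (plain autograd, nothing to do with tez itself): .grad accumulates across backward() calls unless it is cleared, so what matters is clearing before the next backward, not before the forward.

```python
import torch

# .grad accumulates across backward() calls unless it is cleared,
# which is why zero_grad must run before each backward pass.
w = torch.tensor(1.0, requires_grad=True)

(w * 2).backward()
print(w.grad)        # tensor(2.)

(w * 2).backward()   # no clearing in between -> gradients add up
print(w.grad)        # tensor(4.)

w.grad.zero_()       # clearing before the next backward gives a fresh gradient
(w * 2).backward()
print(w.grad)        # tensor(2.)
```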
ohh! thanks for the code. i got it now and I've fixed it in main branch. many thanks for looking deep into the codebase 🙏🏽
Thanks for considering the issue and for the quick fix. This repo is really great; keep updating it with awesome stuff!