abhishekkrthakur/tez

zero_grad for accumulation_steps = 1 not working as expected

abdur-n-tr opened this issue · 9 comments

As far as I know, in the normal execution flow we first call zero_grad for each batch and then do the forward pass. But I found that the code does not behave this way when accumulation_steps = 1 and batch_index = 1: the forward pass executes first, without zero_grad being called beforehand.
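For reference, this is the conventional ordering I would expect (a minimal runnable sketch with placeholder model/optimizer/data, not tez internals):

```python
import torch

# placeholder model, optimizer, and data -- illustrative only, not tez code
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()
loader = [(torch.randn(4, 10), torch.randint(0, 2, (4,))) for _ in range(3)]

for batch_index, (x, y) in enumerate(loader):
    optimizer.zero_grad()        # clear stale gradients for every batch ...
    loss = loss_fn(model(x), y)  # ... then run the forward pass
    loss.backward()
    optimizer.step()
```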

I tried to reproduce it, and it behaves exactly as I described above.

[screenshot: reproduction showing the forward pass running before zero_grad]

Also, I think we can fix this by removing the condition in tez.py on lines 330-331.

it should be fixed in the main branch now. could you please confirm?

This issue still exists in the main branch; I did not see any fix for it.

hmm... then i might be missing something. could you please share more information/code? the zero_grads look fine to me. or am i missing something?

I just ran the latest tez code and the same issue is still happening.

[screenshot: the same logs reproduced on the latest tez code]

Here, you can see that when batch_index = 0, zero_grad runs first and then the forward pass, which is fine. But when batch_index = 1, this condition does not fire, so the forward pass runs without zero_grad happening first, as in the snapshot below.

[screenshot: forward pass for batch_index = 1 running with no preceding zero_grad]

So, one solution is to either remove the if condition on zero_grad in _step() on line 336, OR remove the self.batch_index == 0 check on line 299, OR come up with some other fix that you think is better. A minimal sketch of the first option follows below.
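For illustration, a runnable sketch of zeroing unconditionally before every forward pass (the class and method names here are simplified stand-ins, not the actual tez.py code; with accumulation_steps > 1 the zeroing would instead happen once per accumulation window):

```python
import torch

# hypothetical simplified trainer -- illustrates the proposed fix only
class Trainer:
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer

    def _step(self, x, y):
        self.optimizer.zero_grad()  # unconditional: runs for every batch
        loss = torch.nn.functional.cross_entropy(self.model(x), y)  # forward
        loss.backward()             # gradients accumulate from a clean slate
        self.optimizer.step()
        return loss.item()

model = torch.nn.Linear(10, 2)
trainer = Trainer(model, torch.optim.SGD(model.parameters(), lr=0.1))
for _ in range(3):
    trainer._step(torch.randn(4, 10), torch.randint(0, 2, (4,)))
```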

Hope it helps.

it seems like one of us is confused. when batch_index = 1 or any value greater than zero, zero_grad is happening under the if self.batch_index > 0: check. This zero_grad is before the forward pass.

when batch_index = 0, the _step function does the zero_grad. this zero_grad is before the very first forward pass.
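schematically, the ordering i'm describing (a toy runnable sketch with print statements only, not the literal tez.py source):

```python
# toy sketch of the claimed ordering -- prints only, no real training
class ToyTrainer:
    def zero_grad(self):
        print(f"zero_grad (batch_index={self.batch_index})")

    def _step(self, data):
        if self.batch_index == 0:
            self.zero_grad()        # before the very first forward pass
        print(f"forward  (batch_index={self.batch_index})")
        print(f"backward (batch_index={self.batch_index})")

    def fit(self, loader):
        for self.batch_index, data in enumerate(loader):
            if self.batch_index > 0:
                self.zero_grad()    # before this batch's forward pass
            self._step(data)

ToyTrainer().fit(range(3))
# prints: zero_grad 0, forward 0, backward 0, zero_grad 1, forward 1, ...
```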

I tried to log the forward pass and zero_grad (wherever they appear in the code) like this,

[screenshots: print statements added at the forward pass and zero_grad call sites]

You can point out if I am printing the logs in the wrong way. With this logging in place, you will see logs like this:

[screenshot: resulting log output]

of course zero_grad is happening for batch_index = 1, but only after the forward pass for batch_index = 1 completes.

Also, I just refreshed my memory about zero_grad: it actually has nothing to do with the forward pass; rather, it must happen before the backward pass (apologies for the mistyping). So I logged the backward pass as well, but the same issue remains, as you can see below.

zero_grad before backward pass:
https://discuss.pytorch.org/t/should-we-perform-optimizer-zero-grad-before-we-do-forward-pass-or-after-that/14254
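That point is easy to verify in plain PyTorch; a small sketch (illustrative, not tez code):

```python
import torch

model = torch.nn.Linear(3, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(2, 3)

# ordering A: zero_grad -> forward -> backward (the usual pattern)
opt.zero_grad()
model(x).sum().backward()
opt.step()

# ordering B: forward -> zero_grad -> backward (equally valid, since
# gradients are only written during backward)
loss = model(x).sum()
opt.zero_grad()
loss.backward()
opt.step()

# NOT valid: zeroing after backward but before step, which would discard
# the gradients that step() needs
```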

[screenshot: logs with the backward pass included, showing the same ordering issue]

ohh! thanks for the code. i got it now and I've fixed it in the main branch. many thanks for looking deep into the codebase 🙏🏽

Thanks for considering the issue and for the quick fix. This repo is really great; keep updating it with awesome stuff!