Gradient caching vs Model dropout
harveyp123 opened this issue · 3 comments
harveyp123 commented
GC-DPR training has two steps:
- The first step does a full-batch forward pass without gradient tracking to compute the full-batch contrastive loss and the gradient with respect to each embedding.
- The second step runs mini-batch forward passes, attaches the cached embedding gradients, and backpropagates. The mini-batches loop through the full batch so that all parameter gradients are computed and accumulated (see the sketch below).
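A minimal sketch of the two steps as I understand them; `q_encoder`, `p_encoder`, `loss_fn`, and `chunk_size` are placeholders rather than the actual GC-DPR code:

```python
import torch

def gradient_cached_step(q_encoder, p_encoder, queries, passages, loss_fn, chunk_size):
    # Step 1: full-batch forward without gradient tracking to cache embeddings.
    with torch.no_grad():
        q_emb = torch.cat([q_encoder(c) for c in queries.split(chunk_size)])
        p_emb = torch.cat([p_encoder(c) for c in passages.split(chunk_size)])

    # Full-batch contrastive loss on the cached embeddings (treated as leaves),
    # then the gradient with respect to the embeddings themselves.
    q_emb.requires_grad_()
    p_emb.requires_grad_()
    loss = loss_fn(q_emb, p_emb)
    loss.backward()
    q_grads = q_emb.grad.split(chunk_size)
    p_grads = p_emb.grad.split(chunk_size)

    # Step 2: re-run each mini-batch with gradient tracking and push the cached
    # embedding gradients into the encoder parameters. Note that dropout is
    # sampled again here, which is the inconsistency raised below.
    for chunk, grad in zip(queries.split(chunk_size), q_grads):
        q_encoder(chunk).backward(gradient=grad)
    for chunk, grad in zip(passages.split(chunk_size), p_grads):
        p_encoder(chunk).backward(gradient=grad)

    return loss.detach()
```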
However, there may be an issue during this computation:
- The backbone model uses randomized dropout, and dropout makes steps 1 and 2 inconsistent. Step 1's dropout pattern differs from step 2's, so step 1's embedding gradients cannot be applied directly in step 2; strictly, the gradients would have to be recomputed for every mini-batch. This can be fixed with a slightly more sophisticated procedure that keeps steps 1 and 2 consistent, as illustrated below.
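To make the inconsistency concrete, a tiny CPU-only example (on GPU the CUDA RNG states would also need to be captured):

```python
import torch

torch.manual_seed(0)
enc = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Dropout(p=0.1))
x = torch.randn(4, 8)

state = torch.get_rng_state()  # snapshot the CPU RNG before the first forward
y1 = enc(x)                    # "step 1" forward, dropout mask A
y2 = enc(x)                    # naive "step 2" forward, dropout mask B
print(torch.allclose(y1, y2))  # False: cached gradients no longer match this pass

torch.set_rng_state(state)     # restore the snapshot before re-running
y3 = enc(x)                    # same dropout mask as step 1
print(torch.allclose(y1, y3))  # True: the two passes are consistent again
```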
harveyp123 commented
In short, in the second loop, for each mini-batch's query and passage loss backward, you could insert the freshly computed query and passage embeddings into the cached full-batch embeddings and recompute the loss for the current queries/passages, so the dropout behavior does not change your gradient too much (sketched below).
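A sketch of that idea for the query side (placeholder names again, not GC-DPR code); passages would be handled symmetrically:

```python
import torch

def consistent_minibatch_backward(q_encoder, queries, cached_q, cached_p, loss_fn, chunk_size):
    # cached_q / cached_p are the detached full-batch embeddings from step 1.
    for i, chunk in enumerate(queries.split(chunk_size)):
        fresh = q_encoder(chunk)  # forward with grad tracking (new dropout mask)
        start = i * chunk_size
        # Splice the fresh, grad-tracking embeddings into the cached batch: the
        # contrastive loss still sees the full batch, but gradient only flows
        # through the current chunk's forward pass, so it matches its dropout mask.
        q_all = torch.cat([cached_q[:start], fresh, cached_q[start + fresh.size(0):]])
        loss = loss_fn(q_all, cached_p)
        loss.backward()  # accumulates parameter gradients for this chunk only
```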
luyug commented
In our training code, the random states are snapshotted using the RandContext class:
Lines 53 to 69 in 79e1fe0
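Roughly, a random-state snapshot of this kind can be written with PyTorch's checkpoint helpers; this is a sketch rather than a verbatim copy of the referenced lines:

```python
import torch
from torch.utils.checkpoint import get_device_states, set_device_states

class RandContext:
    """Snapshot the CPU and per-device CUDA RNG states when the cached forward
    runs, then restore them inside a `with` block so the second, grad-enabled
    forward sees the same dropout masks."""

    def __init__(self, *tensors):
        self.fwd_cpu_state = torch.get_rng_state()
        self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors)

    def __enter__(self):
        self._fork = torch.random.fork_rng(devices=self.fwd_gpu_devices, enabled=True)
        self._fork.__enter__()
        torch.set_rng_state(self.fwd_cpu_state)
        set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._fork.__exit__(exc_type, exc_val, exc_tb)
        self._fork = None
```

Usage: construct one context per sub-batch during the no-grad forward in step 1, then re-enter that context around the corresponding grad-enabled forward in step 2.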
harveyp123 commented
Oh, okay. I was using DeepSpeed + gradient caching; the model is wrapped into a DeepSpeed-defined object, and RandContext doesn't work on my side. But it's good to learn from your code :)