luyug/GC-DPR

Gradient caching vs Model dropout

harveyp123 opened this issue · 3 comments

GC-DPR has two steps (a sketch follows this list):

  1. The first step runs a full-batch forward pass without gradients to obtain the full-batch contrastive loss and the corresponding embedding gradients.
  2. The second step runs mini-batch forward passes, assigns the cached embedding gradients, and then runs backward. The mini-batches loop through the full batch so that all parameter gradients are computed and accumulated.
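For concreteness, here is a minimal sketch of those two steps as I understand them. encoder, loss_fn, batch, and sub_batch_size are placeholder names rather than the actual GC-DPR code, and the sketch deliberately ignores the dropout question raised below.

import torch

def gradient_cached_step(encoder, loss_fn, batch, sub_batch_size):
    sub_batches = batch.split(sub_batch_size)

    # Step 1: no-grad forward over the full batch to build the embedding cache.
    with torch.no_grad():
        cached = [encoder(sb) for sb in sub_batches]

    # Re-attach the cache so the contrastive loss yields gradients with
    # respect to the representations only (no encoder graph involved).
    reps = torch.cat(cached).requires_grad_()
    loss = loss_fn(reps)
    loss.backward()
    rep_grads = reps.grad.split(sub_batch_size)

    # Step 2: grad-enabled forward per sub-batch, then backward from the
    # cached representation gradients; encoder gradients accumulate.
    for sb, g in zip(sub_batches, rep_grads):
        sub_reps = encoder(sb)
        sub_reps.backward(gradient=g)

    # The caller steps the optimizer afterwards.
    return loss.detach()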

However, during this computation there might be one issue:

  1. The backbone model uses randomized dropout, which makes steps 1 and 2 inconsistent: step 1's dropout masks differ from step 2's, so step 1's embedding gradients cannot be applied directly to step 2, and step 2's gradients would have to be recomputed for every mini-batch. This could be fixed with a somewhat more sophisticated procedure that keeps 1 and 2 consistent; the snippet below demonstrates the inconsistency.
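To make the randomness concrete, here is a small self-contained illustration (the toy model and sizes are arbitrary): with dropout active, re-running the same forward pass produces different activations unless the RNG state of the first pass is restored.

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.1))
model.train()
x = torch.randn(4, 8)

# Two plain forward passes draw two different dropout masks.
y1 = model(x)
y2 = model(x)
print(torch.allclose(y1, y2))  # False: the passes are inconsistent

# Restoring the RNG state before the second pass reproduces the first.
state = torch.get_rng_state()
y1 = model(x)
torch.set_rng_state(state)
y2 = model(x)
print(torch.allclose(y1, y2))  # True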

In short, in the second loop, for every mini-batch query/passage backward pass, you could put the freshly computed query and passage embeddings back into the original full batch and recompute the loss and gradient for the current query/passage, so that the dropout behavior doesn't change your gradient too much.
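If I read that proposal correctly, it amounts to something like the rough sketch below (placeholder names, not code from this repo): each sub-batch is re-encoded with gradients enabled, spliced into the otherwise detached full-batch representation matrix, and the contrastive loss is recomputed, so the gradient always matches the current dropout draw.

import torch

def recompute_per_sub_batch(encoder, loss_fn, batch, sub_batch_size):
    sub_batches = batch.split(sub_batch_size)

    # Detached full-batch cache of representations.
    with torch.no_grad():
        full_reps = torch.cat([encoder(sb) for sb in sub_batches])

    for i, sb in enumerate(sub_batches):
        live = encoder(sb)                        # fresh forward pass, grad enabled
        reps = full_reps.clone()
        start = i * sub_batch_size
        reps[start:start + live.size(0)] = live   # splice live reps into the cache
        loss = loss_fn(reps)
        loss.backward()                           # gradients flow only through `live`

Compared with gradient caching, this recomputes the loss once per sub-batch, but it never reuses a gradient that was computed under a different dropout mask.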

In our training code, the random states are snapshotted using the RandContext class:

import torch
from torch.utils.checkpoint import get_device_states, set_device_states

class RandContext:
    def __init__(self, *tensors):
        # Snapshot the CPU RNG state and the RNG states of the GPUs
        # that hold the given tensors.
        self.fwd_cpu_state = torch.get_rng_state()
        self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors)

    def __enter__(self):
        # Fork the global RNG so restoring the snapshot does not leak
        # outside the `with` block, then restore the saved states.
        self._fork = torch.random.fork_rng(
            devices=self.fwd_gpu_devices,
            enabled=True
        )
        self._fork.__enter__()
        torch.set_rng_state(self.fwd_cpu_state)
        set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._fork.__exit__(exc_type, exc_val, exc_tb)
        self._fork = None
They are captured in the first forward pass and restored at the beginning of the second, so what you described shouldn't be a problem.
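Roughly how the snapshots are used, as a hedged paraphrase of the training loop rather than the exact code (placeholder names as in the earlier sketches): one RandContext is captured per sub-batch during the no-grad pass and re-entered before the corresponding grad-enabled pass, so both passes draw identical dropout masks.

def gradient_cached_step_with_rng(encoder, loss_fn, batch, sub_batch_size):
    sub_batches = batch.split(sub_batch_size)

    # Pass 1: snapshot the RNG state just before each no-grad forward.
    rnd_states, cached = [], []
    with torch.no_grad():
        for sb in sub_batches:
            rnd_states.append(RandContext(sb))
            cached.append(encoder(sb))

    reps = torch.cat(cached).requires_grad_()
    loss = loss_fn(reps)
    loss.backward()  # gradients w.r.t. the representations only
    rep_grads = reps.grad.split(sub_batch_size)

    # Pass 2: restore each snapshot so dropout draws the same masks as in pass 1.
    for sb, state, g in zip(sub_batches, rnd_states, rep_grads):
        with state:
            sub_reps = encoder(sb)
        sub_reps.backward(gradient=g)
    return loss.detach()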

Oh, okay. I was using DeepSpeed + gradient caching; the model is wrapped in a DeepSpeed-defined object, and RandContext doesn't work on my side. But it's good to learn from your code :)