locuslab/deq

Two slightly different processes for DEQ


Hi Shaojie,

I found two slightly different forward-backward processes for DEQ. The first is in Chapter 4 (Deep Equilibrium) of the tutorial:

import torch
from torch import nn, autograd

class DEQFixedPoint(nn.Module):
    def __init__(self, f, solver, **kwargs):
        super().__init__()
        self.f = f
        self.solver = solver
        self.kwargs = kwargs

    def forward(self, x):
        # compute forward pass and re-engage autograd tape
        with torch.no_grad():
            z, self.forward_res = self.solver(lambda z: self.f(z, x), torch.zeros_like(x), **self.kwargs)
        z = self.f(z, x)

        # set up Jacobian vector product (without additional forward calls)
        z0 = z.clone().detach().requires_grad_()
        f0 = self.f(z0, x)
        def backward_hook(grad):
            # solve g = J^T g + grad; f0/z0 form a separate graph, so this hook is not re-triggered
            g, self.backward_res = self.solver(lambda y: autograd.grad(f0, z0, y, retain_graph=True)[0] + grad,
                                               grad, **self.kwargs)
            return g

        z.register_hook(backward_hook)
        return z
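To see method 1 pass gradcheck end to end on CPU, here is a minimal, self-contained sketch. It assumes a plain forward-iteration solver (`fixed_point_iter`, hypothetical, standing in for the tutorial's Anderson solver) and a hypothetical contractive layer `ContractiveTanh` whose weight is scaled so that both the forward and the backward fixed-point iterations converge. The key property is that the hook is never removed, so every backward pass gradcheck makes through the same graph goes through the backward solver.

```python
import torch
from torch import nn, autograd


def fixed_point_iter(g, z_init, tol=1e-12, max_iter=200):
    """Hypothetical solver: plain forward iteration z <- g(z).
    Stands in for the tutorial's Anderson solver; returns (solution, residual)."""
    z, res = z_init, float("inf")
    for _ in range(max_iter):
        z_next = g(z)
        res = (z_next - z).norm().item()
        z = z_next
        if res < tol:
            break
    return z, res


class ContractiveTanh(nn.Module):
    """Hypothetical layer f(z, x) = tanh(z W^T + x) with ||W||_2 <= 0.5,
    so f is a contraction in z and both fixed-point iterations converge."""

    def __init__(self, n):
        super().__init__()
        w = torch.randn(n, n, dtype=torch.double)
        # Spectral norm <= Frobenius norm, so this caps ||W||_2 at 0.5
        self.W = nn.Parameter(0.5 * w / w.norm())

    def forward(self, z, x):
        return torch.tanh(z @ self.W.T + x)


class DEQFixedPoint(nn.Module):
    """Same structure as method 1 above."""

    def __init__(self, f, solver, **kwargs):
        super().__init__()
        self.f = f
        self.solver = solver
        self.kwargs = kwargs

    def forward(self, x):
        with torch.no_grad():
            z, self.forward_res = self.solver(lambda z: self.f(z, x), torch.zeros_like(x), **self.kwargs)
        z = self.f(z, x)  # one differentiable step at the fixed point
        # Separate graph: differentiating f0 w.r.t. z0 never reaches the hook on z
        z0 = z.clone().detach().requires_grad_()
        f0 = self.f(z0, x)

        def backward_hook(grad):
            # Solve g = J^T g + grad, where J = df/dz at the fixed point
            g, self.backward_res = self.solver(
                lambda y: autograd.grad(f0, z0, y, retain_graph=True)[0] + grad, grad, **self.kwargs)
            return g

        z.register_hook(backward_hook)  # never removed, so repeated backward passes all see it
        return z


torch.manual_seed(0)
deq = DEQFixedPoint(ContractiveTanh(3), fixed_point_iter, tol=1e-12, max_iter=200)
x = torch.randn(2, 3, dtype=torch.double, requires_grad=True)
print(autograd.gradcheck(deq, (x,), check_undefined_grad=False))
```

gradcheck raises an exception on a mismatch, so the print shows `True` only when the implicit gradients agree with finite differences. The tight `tol` matters here: the numerical Jacobian divides the forward solver's error by `eps`, so a loose forward solve could fail the check for reasons unrelated to the hook.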

The second one is in this repo:

with torch.no_grad():
    result = self.f_solver(lambda z: self.func(z, *func_args), z1s, threshold=f_thres, stop_mode=self.stop_mode)
    z1s = result['result']
new_z1s = z1s
if (not self.training) and spectral_radius_mode:
    with torch.enable_grad():
        z1s.requires_grad_()
        new_z1s = self.func(z1s, *func_args)
    _, sradius = power_method(new_z1s, z1s, n_iters=150)
if self.training:
    z1s.requires_grad_()
    new_z1s = self.func(z1s, *func_args)
    if compute_jac_loss:
        jac_loss = jac_loss_estimate(new_z1s, z1s, vecs=1)
    def backward_hook(grad):
        if self.hook is not None:
            # To avoid infinite loop
            self.hook.remove()
            torch.cuda.synchronize()
        new_grad = self.b_solver(lambda y: autograd.grad(new_z1s, z1s, y, retain_graph=True)[0] + grad, \
                                 torch.zeros_like(grad), threshold=b_thres)['result']
        return new_grad
    self.hook = new_z1s.register_hook(backward_hook)
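The `# To avoid infinite loop` comment can be illustrated on a toy graph. In method 2 the hook is registered on `new_z1s` and itself calls `autograd.grad(new_z1s, z1s, ...)`; that backward pass starts at `new_z1s`, so the same hook fires again, and inside the solver this recursion would never end. In the sketch below, `x` and `y = tanh(x)` are hypothetical stand-ins for `z1s` and `new_z1s`, and a depth guard is added only so the demo terminates.

```python
import torch
from torch import autograd

# Hypothetical stand-ins: x plays z1s, y plays new_z1s
x = torch.ones(3, requires_grad=True)
y = torch.tanh(x)
calls = [0]

def hook(grad):
    calls[0] += 1
    # Like method 2's hook without self.hook.remove(): differentiate y itself,
    # which re-enters this same hook. The guard only exists so the demo stops.
    if calls[0] < 5:
        autograd.grad(y, x, grad, retain_graph=True)
    return grad

y.register_hook(hook)
y.sum().backward()
print(calls[0])  # 5
```

Method 2 breaks the recursion by removing the hook on its first call; method 1 avoids it by differentiating a detached copy (`f0`, `z0`) that lives in a separate graph.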

I ran torch.autograd.gradcheck on both methods in Colab, using exactly the same setup as in Chapter 4:
gradcheck(deq, torch.randn(1,2,3,3).cuda().double().requires_grad_(), check_undefined_grad=False)

Interestingly, only method 1 passes. Method 2 breaks my Colab session.
Here is my experiment code: https://colab.research.google.com/drive/19vGpV16nbF5HRRKlFGScO-N1Js3NC4hj#scrollTo=kg2UmSW1x1R3

I also tried it on my workstation. There, method 2 slowly consumed all GPU memory and eventually failed with this message:
SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f65d099e770> returned NULL without setting an error.
I think I triggered an infinite loop in the backward solver, even though I already call torch.cuda.synchronize() in the backward_hook function.

I couldn't find any gradient-checking code in this repo, and method 2 is the one used in your Transformer-XL examples. I wonder whether this means the memory issue rarely happens in practical cases, such as training a transformer.

Thanks :-)

My experiment environment:
Workstation:

  • python: 3.8
  • pytorch: 1.9
  • cuda: 11.1
  • GPU: NVIDIA RTX 3090

Google Colab:
default environment with a GPU runtime.

Hi @SamChen,

Thanks for the question. The gradcheck fails because gradcheck backpropagates through the same computation graph multiple times (e.g., to build the analytical Jacobian, it backpropagates a one-hot vector for each entry of the output), whereas self.hook.remove() already removes the hook during the first backward call. Every later backward pass through that graph therefore skips the backward solver and returns the wrong gradient. So the code is correct for DEQ-Transformer training (where each iteration has exactly ONE backward pass through the DEQ), but incorrect for repeated backward passes (which is what gradcheck does).
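The effect of a self-removing hook on repeated backward passes can be reproduced on a toy graph. In this hypothetical sketch, `10 * grad` stands in for the implicit-gradient correction that `b_solver` computes: the first backward call sees the hook, while the second call through the same retained graph gets the plain gradient. gradcheck performs this kind of repeated backward, so only its first backward pass through the DEQ gets the corrected gradient.

```python
import torch
from torch import autograd

x = torch.ones(3, requires_grad=True)
y = 2 * x
handle = None

def hook(grad):
    handle.remove()      # self-removal, as in method 2's backward_hook
    return 10 * grad     # hypothetical stand-in for the implicit-gradient correction

handle = y.register_hook(hook)
out = y.sum()
g1, = autograd.grad(out, x, retain_graph=True)  # hook fires, then removes itself
g2, = autograd.grad(out, x)                     # same graph, hook is gone
print(g1, g2)  # tensor([20., 20., 20.]) tensor([2., 2., 2.])
```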

The memory leak, I believe, is a PyTorch-related issue. I'm not entirely sure about its source, but PyTorch 1.6 and 1.7 should both work well (i.e., no memory leak). If you encounter the SystemError and do not want to downgrade PyTorch, you can also use the tutorial implementation, basically replacing the current L372-380 with:

z1s_copy = z1s.clone().detach().requires_grad_()
new_z1s_copy = self.func(z1s_copy, *func_args)      # Spend one more NFE in training forward
def backward_hook(grad):
    # Differentiate the copy's graph, so this hook is not re-triggered and never needs removing
    new_grad = self.b_solver(lambda y: autograd.grad(new_z1s_copy, z1s_copy, y, retain_graph=True)[0] + grad,
                             torch.zeros_like(grad), threshold=b_thres)['result']
    return new_grad
new_z1s.register_hook(backward_hook)
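A toy version of this detached-copy pattern shows why it tolerates repeated backward passes. The names mirror the snippet above, but `f = tanh` is a hypothetical stand-in for `self.func`, and a single `grad + J^T grad` step replaces the `b_solver` call. Because the hook differentiates `new_z1s_copy`, which lives in a separate graph, it fires exactly once per backward pass and every pass gets the same corrected gradient.

```python
import torch
from torch import autograd

def f(z):  # hypothetical stand-in for self.func
    return torch.tanh(z)

z1s = torch.ones(3, requires_grad=True)
new_z1s = f(z1s)
z1s_copy = z1s.clone().detach().requires_grad_()
new_z1s_copy = f(z1s_copy)  # the extra NFE in the forward pass
calls = [0]

def backward_hook(grad):
    calls[0] += 1
    # Differentiates the copy's graph, so this hook is not re-entered
    jtg = autograd.grad(new_z1s_copy, z1s_copy, grad, retain_graph=True)[0]
    return grad + jtg  # one step of g = J^T g + grad, in place of b_solver

new_z1s.register_hook(backward_hook)
out = new_z1s.sum()
g1, = autograd.grad(out, z1s, retain_graph=True)
g2, = autograd.grad(out, z1s)
print(calls[0], torch.equal(g1, g2))  # 2 True
```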

Of course, this means you have to spend one more NFE (function evaluation) in the forward pass of training, which costs slightly more memory and computation (exactly what the current implementation hoped to avoid). But this should help avoid the memory leak.

Let me know if this helps!

Thanks for the clear explanation. :-)
Your explanation of gradcheck explains why I saw it call the DNN function over and over. And, of course, that is not related to the infinite loop.
