Two slightly different processes for DEQ
Hi Shaojie,
I found that there are two slightly different forward-backward processes for DEQ. One is in Chapter 4 ("Deep Equilibrium") of the tutorial:
import torch
import torch.nn as nn
from torch import autograd

class DEQFixedPoint(nn.Module):
    def __init__(self, f, solver, **kwargs):
        super().__init__()
        self.f = f
        self.solver = solver
        self.kwargs = kwargs

    def forward(self, x):
        # compute forward pass and re-engage autograd tape
        with torch.no_grad():
            z, self.forward_res = self.solver(lambda z: self.f(z, x), torch.zeros_like(x), **self.kwargs)
        z = self.f(z, x)

        # set up Jacobian vector product (without additional forward calls)
        z0 = z.clone().detach().requires_grad_()
        f0 = self.f(z0, x)
        def backward_hook(grad):
            g, self.backward_res = self.solver(lambda y: autograd.grad(f0, z0, y, retain_graph=True)[0] + grad,
                                               grad, **self.kwargs)
            return g
        z.register_hook(backward_hook)
        return z
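For reference, the solver passed into this class just needs the signature solver(f, z0, **kwargs) -> (z, residuals), matching how it is unpacked above. Here is a minimal naive forward-iteration sketch that can stand in for it (the name forward_iteration and its defaults are placeholders, not the tutorial's own solver):

import torch

def forward_iteration(f, z0, max_iter=50, tol=1e-6):
    # Naive fixed-point iteration: repeatedly apply f until the relative residual is small.
    # Returns (approximate fixed point, list of residuals), the pair DEQFixedPoint unpacks.
    z, residuals = z0, []
    for _ in range(max_iter):
        z_next = f(z)
        residuals.append((z_next - z).norm().item() / (1e-8 + z_next.norm().item()))
        z = z_next
        if residuals[-1] < tol:
            break
    return z, residuals

A small double-precision cell f(z, x) can then be wrapped as DEQFixedPoint(f, forward_iteration, max_iter=500, tol=1e-10) and passed to the gradcheck call below.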
And the second one is in this repo:
deq/DEQ-Sequence/models/deq_transformer.py, lines 355 to 380 (commit c161644)
I tried torch.autograd.gradcheck on both methods, using the exact same process from Chapter 4, on Colab:
gradcheck(deq, torch.randn(1,2,3,3).cuda().double().requires_grad_(), check_undefined_grad=False)
Interestingly, only method 1 works properly. The second method breaks my experiment session.
Here is my experiment code https://colab.research.google.com/drive/19vGpV16nbF5HRRKlFGScO-N1Js3NC4hj#scrollTo=kg2UmSW1x1R3
I also tried it on my workstation. I found that method 2 slowly ate all of the GPU memory and eventually returned this message: SystemError: <built-in method run_backward of torch._C._EngineBase object at 0x7f65d099e770> returned NULL without setting an error.
I think I triggered an infinite loop in the backward solver, although I already called torch.cuda.synchronize() in the backward_hook function.
In this repo, I do not find similar code related to gradient checking. Moreover, method 2 is used in your Transformer-XL examples. I wonder whether this means the memory-hunger issue rarely happens in practical cases, like training a transformer.
Thanks :-)
My experiment environment:
workstation:
- python: 3.8
- pytorch: 1.9
- cuda: 11.1
- GPU: Nvidia 3090
Google Colab:
- default environment with GPU
Hi @SamChen,
Thanks for the question. The gradcheck fails because gradcheck works by backpropagating through the same computation graph multiple times (e.g., by adding eps to each entry of the vector output and then backpropagating); whereas the self.hook.remove() already removed the hook upon the first backward call. Therefore, the code is correct for DEQ-Transformer training (where each iteration has exactly ONE backward pass through the DEQ), but is incorrect for repetitive backward passes (which is what gradcheck does).
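For readers who don't have the repo open, the pattern being described looks roughly like this (a simplified sketch using the variable names quoted elsewhere in this thread, not the repo's exact code):

def backward_hook(grad):
    if self.hook is not None:
        # Detach this hook before differentiating through new_z1s again; otherwise the
        # hook would fire inside its own autograd.grad call (the infinite loop mentioned above).
        self.hook.remove()
        torch.cuda.synchronize()
    new_grad = self.b_solver(lambda y: autograd.grad(new_z1s, z1s, y, retain_graph=True)[0] + grad,
                             torch.zeros_like(grad), threshold=b_thres)['result']
    return new_grad
self.hook = new_z1s.register_hook(backward_hook)

Once the hook has removed itself, any later backward pass through the same graph skips the DEQ backward solve entirely, which is exactly what gradcheck trips over.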
The memory leak, I believe, is a pytorch-related issue. I'm not entirely sure about the source of this problem, but pytorch 1.6 and 1.7 should both work well (i.e., no memory leak). If you encounter the SystemError and do not want to downgrade pytorch, then you can also use the tutorial implementation, basically replacing the current L372-380 with:
z1s_copy = z1s.clone().detach().requires_grad_()
new_z1s_copy = self.func(z1s_copy, *func_args)   # Spend one more NFE in training forward

def backward_hook(grad):
    # Solve the backward fixed-point equation for the corrected gradient, differentiating
    # through the graph of new_z1s_copy rather than new_z1s.
    new_grad = self.b_solver(lambda y: autograd.grad(new_z1s_copy, z1s_copy, y, retain_graph=True)[0] + grad,
                             torch.zeros_like(grad), threshold=b_thres)['result']
    return new_grad

new_z1s.register_hook(backward_hook)
Of course, this means you have to spend one more NFE in the forward pass of training, which means slightly more memory and computation (which is what the current implementation hoped to avoid). But this should help avoid the memory leak.
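Note that in this variant the inner autograd.grad differentiates through the graph of new_z1s_copy, which has no hook attached to it, so nothing has to be removed after the first backward call; that is also why this version tolerates the repeated backward passes that gradcheck performs.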
Let me know if this helps!
Thanks for the clear explanation. :-)
Your point about gradcheck explains why I saw it calling the DNN function over and over. And, of course, that is not related to the infinite loop at
deq/DEQ-Sequence/models/deq_transformer.py, line 374 (commit c161644).