Segmentation Fault when Loss Backward CIFAR cls_mdeq_LARGE_reg
Hi, I encounter a Segmentation Fault (core dumped) when training `cls_mdeq_LARGE_reg`.
The bug happens at epoch 61, iteration 79, right before the line `(loss + factor*jac_loss).backward()`.
I'm following this suggestion to trace back the errors using gdb, and here is the error:

```
#4  0x000055555568e989 in PyObject_GetAttrString () at /tmp/build/80754af9/python-split_1628000493704/work/Objects/object.c:846
#5  0x00005555555ce5ab in PyObject_HasAttrString (v=<optimised out>, name=<optimised out>) at /tmp/build/80754af9/python-split_1628000493704/work/Objects/object.c:854
#6  0x00007ffff4ebb42b in hook_name(_object*) () from /home/hieu/anaconda3/envs/deq/lib/python3.8/site-packages/torch/lib/libtorch_python.so
#7  0x00007ffff4ebb84e in check_single_result(_object*, _object*, _object*) ()
```
I guess the `hook` function does not like something here?
I'm using:
Python v3.8.11
PyTorch v1.7.1+cu110
CUDA v11.1
RTX 3090 graphics cards
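(Aside, not something used in this thread: besides running the process under gdb, Python's built-in `faulthandler` module can dump the Python-level traceback when a segfault happens, which is often enough to see which line triggered it. A minimal sketch:)

```python
# Hedged aside (not from the original report): enable faulthandler at the top of
# the training script so a crash such as SIGSEGV prints the Python traceback of
# every thread to stderr before the process dies.
import faulthandler
faulthandler.enable()
```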
Hi @HieuPhan33 ,
Thanks for the feedback and question! Are you able to try PyTorch 1.6 and let me know if things work properly? I suspect this is a PyTorch version issue (the backward hook seems problematic with PyTorch 1.7 and above) and will push a fix to the current implementation soon (based on PyTorch's custom backward support). I should be able to do this in the next few days and will let you know!
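For reference, a minimal sketch of what a fix "based on PyTorch's custom backward support" could look like: wrap the adjoint solve in a `torch.autograd.Function` so that no tensor hook needs to be registered or removed at all. This is an illustration only, not necessarily the implementation that was eventually pushed; `func`, `self.b_solver`, `b_thres`, and the `{'result': ...}` return convention are taken from the snippets later in this thread.

```python
# A minimal sketch of the "custom backward" idea (an assumption, not the actual
# fix pushed to the repo): the adjoint solve lives in a torch.autograd.Function
# instead of a tensor hook, so there is no hook to register or remove.
import torch
from torch import autograd

class DEQBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, z_new, z, solver, threshold):
        # z: a detached leaf copy of the fixed point (requires_grad=True);
        # z_new = func(z), computed with autograd enabled before calling .apply().
        ctx.solver, ctx.threshold = solver, threshold
        ctx.save_for_backward(z_new, z)
        return z_new.clone()

    @staticmethod
    def backward(ctx, grad):
        z_new, z = ctx.saved_tensors
        # Solve g = g * J_f(z*) + grad by fixed-point iteration, exactly what the
        # backward_hook in this thread does, but without any hook machinery.
        grad_func = lambda y: autograd.grad(z_new, z, y, retain_graph=True)[0] + grad
        g = ctx.solver(grad_func, torch.zeros_like(grad), threshold=ctx.threshold)['result']
        # Returning g as the gradient w.r.t. z_new lets autograd push it through
        # func's graph once, which yields the parameter/input gradients.
        return g, None, None, None
```

A hypothetical call site, replacing the hook registration:

```python
z1_cp = z1.clone().detach().requires_grad_()
new_z1 = DEQBackward.apply(func(z1_cp), z1_cp, self.b_solver, b_thres)
```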
Hi @jerrybai1995,
Thanks for your reply. Unfortunately, the RTX 3090 is not compatible with PyTorch 1.6.
Please let me know when you fix the backward hook problem with torch 1.7.
Appreciate your work!
Hi, I can now train with torch 1.7.0 (instead of 1.7.1)!
Heyo, just chiming in here to say that I'm experiencing a similar segmentation fault. I'm on 1.9 and it occurs when removing a hook during the `backward_hook` function.
Hi @tesfaldet,
Could you try the following implementation?
Replace L453-460 of `deq/MDEQ-Vision/lib/models/mdeq_core.py` (the current implementation, as of commit c161644) with:
```python
z1_cp = z1.clone().detach().requires_grad_()
new_z1_cp = func(z1_cp)

def backward_hook(grad):
    result = self.b_solver(lambda y: autograd.grad(new_z1_cp, z1_cp, y, retain_graph=True)[0] + grad,
                           torch.zeros_like(grad), threshold=b_thres, stop_mode=self.stop_mode, name="backward")
    return result['result']

new_z1.register_hook(backward_hook)  # Notice that it's new_z1 here, not new_z1_cp!
```
This should probably resolve the issue with PyTorch 1.9, but on the other hand it pays the cost of an additional layer (in order to produce `new_z1_cp`). Please let me know if this resolves the issue.
(@HieuPhan33, cc you in case you run into this issue in the future when using PyTorch >1.7.0 😄 ).
I tried this with PyTorch 1.6, 1.7, and 1.9 and would get an out-of-memory error during the backward pass:
```python
def forward(self, x):
    # setup
    _, c, h, w = x.shape
    x0 = tensor2vec(x)
    func = lambda y: tensor2vec(self.f(vec2tensor(y, (c, h, w))))

    # Forward pass
    with torch.no_grad():
        x_star = self.solver(func, x0, threshold=30)['result']

    if self.training:
        # re-engage autograd tape
        x_star_new = func(x_star.requires_grad_())

        # set up Jacobian-vector product for backward pass
        def backward_hook(grad):
            # Compute the fixed point of yJ + grad, where J = J_f is the Jacobian of f at z_star
            grad_func = lambda y: autograd.grad(x_star_new, x_star, y, retain_graph=True)[0] + grad
            new_grad = self.solver(grad_func, torch.zeros_like(grad), threshold=40)['result']
            print('grad, new_grad', grad, new_grad)
            return new_grad

        if self.hook is not None:
            self.hook.remove()
            torch.cuda.synchronize()
        self.hook = x_star_new.register_hook(backward_hook)

    return vec2tensor(x_star_new, (c, h, w))
```
I'll try what you recommended.
@tesfaldet Oh, what you posted here will definitely hit an out-of-memory (OOM) error: whenever the backward pass goes through `x_star_new`, PyTorch runs the `backward_hook` function (as that's what a hook does), which calls the backward pass on `x_star_new` again (via the `autograd.grad` call), which again triggers the `backward_hook` function. So there'll be an infinite recursion. This is the fundamental reason why I originally put `self.hook.remove()` and `torch.cuda.synchronize()` inside the `backward_hook` function.
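For concreteness, the pattern being described is roughly the following; this is a reconstruction from the discussion, not a verbatim copy of the repo's code, and `self.solver`, `x_star`, and `x_star_new` follow the snippet above.

```python
# Reconstructed from the discussion (not verbatim repo code): the hook removes
# itself on its first invocation, so the autograd.grad call inside it cannot
# re-trigger the hook and recurse.
def backward_hook(grad):
    if self.hook is not None:
        self.hook.remove()        # detach this hook before autograd.grad touches x_star_new again
        torch.cuda.synchronize()  # make sure the removal has fully taken effect
    grad_func = lambda y: autograd.grad(x_star_new, x_star, y, retain_graph=True)[0] + grad
    return self.solver(grad_func, torch.zeros_like(grad), threshold=40)['result']

self.hook = x_star_new.register_hook(backward_hook)
```

Per this thread, removing a hook while it is executing is also what triggers the segfault on PyTorch 1.7.1/1.9, which the `z1_cp` workaround above sidesteps: its hook only ever calls `autograd.grad` on the separate `new_z1_cp` graph, so it never re-enters itself and never has to remove itself.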
Ohhhhhhh, I see. So the `autograd.grad` call calls the `backward_hook` function again. I completely didn't realize! The reason I moved the `self.hook.remove()` block out of the `backward_hook` function to begin with was that it was causing a segmentation fault otherwise :( This was on all the versions of PyTorch I listed above. Your suggested change seems to forego removing hooks, but I worry that it would slowly eat up memory over time.
I don't think it would lead to a memory leak, because Python's GC will clean up the old hook once its reference count goes to zero and a new training iteration comes in. The major drawback of the suggested change is that we have to pay some additional speed and memory cost to compute `z1_cp`.
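A tiny, hedged illustration of that garbage-collection argument, using a toy tensor rather than the repo's code: once nothing references the old output tensor or its hook handle any more, the hook and its closure are freed together with the graph.

```python
# Toy check (not from the repo) that an unreferenced hook does not linger:
# once the tensor, its graph, and the handle are dropped, everything is freed.
import gc
import weakref
import torch

t = torch.randn(3, requires_grad=True)
y = t * 2                               # y owns the graph node the hook attaches to
handle = y.register_hook(lambda g: g)   # stand-in for backward_hook
y_ref = weakref.ref(y)

del y, handle                           # what effectively happens when the next iteration overwrites them
gc.collect()
print(y_ref() is None)                  # True: tensor, hook, and closure have been collected
```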