RuntimeError: `Trying to backward through the graph a second time` when setting `opt_mode` to `fisher_diag`
a1trl9 opened this issue
Hi Yuhang,
Thank you for open-sourcing this project.
As noted in the paper, the diagonal Fisher information matrix is applied to replace the pre-activation Hessian, so we tried setting `opt_mode` to `fisher_diag` instead of `mse` for reconstruction. However, a runtime error is thrown:
File "xxxx/quant/data_utils.py", line 184, in __call__
loss.backward()
File "xxxx/lib/python3.6/site-packages/torch/tensor.py", line 221, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "xxxx/lib/python3.6/site-packages/torch/autograd/__init__.py", line 132, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.
It seems to occur during the backward pass used to save the gradients:
```python
handle = self.layer.register_backward_hook(self.data_saver)
with torch.enable_grad():
    try:
        self.model.zero_grad()
        inputs = model_input.to(self.device)
        self.model.set_quant_state(False, False)
        out_fp = self.model(inputs)
        quantize_model_till(self.model, self.layer, self.act_quant)
        out_q = self.model(inputs)
        loss = F.kl_div(F.log_softmax(out_q, dim=1), F.softmax(out_fp, dim=1), reduction='batchmean')
        # the loss.backward() below is where the RuntimeError is raised
        loss.backward()
    except StopForwardException:
        pass
```
As indicated by the error message, the first backward succeeds but a second backward through the same graph fails.
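For reference, the message itself is easy to reproduce in isolation once a graph's saved intermediates have been freed; the toy snippet below only illustrates the mechanism and has nothing to do with the project code:

```python
import torch

w = torch.randn(3, requires_grad=True)
h = torch.tanh(w)      # tanh saves its output for the backward pass
loss = h.sum()
loss.backward()        # succeeds, then frees the saved intermediate results
loss.backward()        # raises the same "backward through the graph a second time" error
```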
We created a very simple network to reproduce the problem, and the error keeps showing up:
```python
import torch.nn as nn
import torch.nn.functional as F

class DummyNet(nn.Module):
    def __init__(self):
        super(DummyNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 3)
        self.conv2 = nn.Conv2d(32, 32, 3, 3)
        self.conv3 = nn.Conv2d(32, 1, 3, 3)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        output = F.log_softmax(x, dim=0)
        return output
```
The `recon_model` function is the same as the one in `main_imagenet`:
```python
def recon_model(model: nn.Module):
    """
    Block reconstruction. For the first and last layers, we can only apply layer reconstruction.
    """
    for name, module in model.named_children():
        if isinstance(module, QuantModule):
            if module.ignore_reconstruction is True:
                print('Ignore reconstruction of layer {}'.format(name))
                continue
            else:
                layer_reconstruction(qnn, module, **kwargs)
        elif isinstance(module, BaseQuantBlock):
            if module.ignore_reconstruction is True:
                print('Ignore reconstruction of block {}'.format(name))
                continue
            else:
                print('Reconstruction for block {}'.format(name))
                block_reconstruction(qnn, module, **kwargs)
        else:
            recon_model(module)
```
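For completeness, this is roughly how we drive the reproduction. The quantizer parameter dictionaries and the reconstruction kwargs below are written from memory of the `main_imagenet` defaults, so please treat the exact keys and values as illustrative rather than the precise ones we ran:

```python
import torch
# QuantModel / QuantModule / BaseQuantBlock and the *_reconstruction functions are
# imported from the repo's quant package, the same way main_imagenet imports them.

# random "calibration" data for the dummy net (shape chosen arbitrarily)
cali_data = torch.randn(32, 3, 64, 64)

# illustrative quantizer settings, not the exact values we used
wq_params = {'n_bits': 4, 'channel_wise': True, 'scale_method': 'mse'}
aq_params = {'n_bits': 8, 'channel_wise': False, 'scale_method': 'mse', 'leaf_param': False}

qnn = QuantModel(model=DummyNet(), weight_quant_params=wq_params, act_quant_params=aq_params)
qnn.eval()

# same reconstruction settings as a default run, except opt_mode
kwargs = dict(cali_data=cali_data, iters=1000, weight=0.01, asym=True,
              b_range=(20, 2), warmup=0.2, act_quant=False,
              opt_mode='fisher_diag')   # 'mse' runs fine; 'fisher_diag' triggers the error

recon_model(qnn)
```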
We are not quite sure why PyTorch complains here, since `backward` is only called once per batch... But we also noticed that after calling `save_grad_data`, the gradients are cached for the later loss calculation:
```python
# in block_reconstruction
err = loss_func(out_quant, cur_out, cur_grad)
```
Are those intermediate gradients still available at this point, given that `backward` has already been called? In our case, even if we work around the first error inside `save_grad_data`, we hit the same error here (i.e. backward through the graph a second time).
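For clarity, the workaround we mean is simply the one the error message suggests, applied at the call site quoted earlier (a sketch, not a proposed fix):

```python
# inside the __call__ shown above (quant/data_utils.py): keep the graph alive so
# the first backward no longer frees the saved intermediate results
loss.backward(retain_graph=True)
```

With that change the first call site passes, but the same message reappears later, as described above.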
Environment: Ubuntu 16.04 / Python 3.6.8 / PyTorch 1.7.1 / CUDA 10.1
Any advice would be appreciated.