Stop-gradients in skip-connections
kazuto1011 opened this issue · 2 comments
Thank you for sharing your code. I found that all backbones call detach()
in their skip connections. For example:
Could you tell me where this idea comes from? I cannot find a corresponding part in the official SqueezeSeg/SqueezeSegV2.
Besides, I'm concerned that the first detach()
in SqueezeSegV2 does not do what is intended.
lidar-bonnetal/train/backbones/squeezesegV2.py
Lines 170 to 174 in 5a5f4b1
skip_in
is detached and never referenced afterwards, so the self.conv1b
layer never receives gradients and its parameters are never updated.
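To make the effect concrete, here is a minimal toy example (the module and layer names are purely illustrative, not the actual repo code): a layer whose output is only ever used through a detached copy ends up with grad=None.

import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1a = nn.Conv2d(5, 64, 3, padding=1)   # main path
        self.conv1b = nn.Conv2d(5, 64, 1)              # skip branch

    def forward(self, x):
        skip_in = self.conv1b(x)
        # skip_in is only ever used via this detached copy, so no gradient
        # from the loss can reach conv1b
        skip = skip_in.detach()
        return self.conv1a(x) + skip

toy = ToyEncoder()
toy(torch.randn(1, 5, 64, 512)).sum().backward()
print(toy.conv1b.weight.grad)   # prints None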
Here is the quick check I did.
# from squeezesegV2.py
import torch

# Backbone / Decoder are the SqueezeSegV2 encoder and decoder modules from the repo;
# encoder_params / decoder_params are their config dicts.
encoder = Backbone(encoder_params)
decoder = Decoder(decoder_params, None)

x = torch.randn(1, 5, 64, 512)   # dummy (B, C, H, W) range image
y = decoder(*encoder(x))         # feed the encoder outputs straight into the decoder
y.sum().backward()

# any parameter whose grad is still None never received a gradient
for name, p in encoder.named_parameters():
    if p.grad is None:
        print(name, "is None")
The above snippet gives:
conv1b.0.weight is None
conv1b.0.bias is None
conv1b.1.weight is None
conv1b.1.bias is None
Oops, this issue is quite old. Sorry.
As far as I remember, detaching or not might not make a big difference in performance. However, Andres mentioned that he had better results when he added the detach, because then the whole network is learned (instead of gradients just skipping the middle part). But I'm not sure if there is a better explanation.
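Roughly, the intended effect is something like the following toy sketch (placeholder layers, not the actual repo code): only the copy handed to the skip connection is detached, while the feature itself keeps flowing down the deep path, so the encoder is still trained end-to-end and the shortcut simply carries no gradient.

import torch
import torch.nn as nn

conv_enc = nn.Conv2d(5, 64, 3, stride=2, padding=1)   # encoder stage
conv_dec = nn.Conv2d(64, 64, 3, padding=1)            # stands in for the deeper layers

x = torch.randn(1, 5, 64, 512)
feat = conv_enc(x)
skip = feat.detach()        # shortcut copy: carries no gradient
deep = conv_dec(feat)       # deep path: gradients still reach conv_enc
(deep + skip).sum().backward()
print(conv_enc.weight.grad is not None)   # True: still learned via the deep path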
Thank you for your response. Okay, I understand that you adopted detach() empirically. I had assumed that stopping the gradient flow in the skip connections might cause inefficiency or instability.