PRBonn/lidar-bonnetal

Stop-gradients in skip-connections

kazuto1011 opened this issue · 2 comments

Thank you for sharing your code. I noticed that all backbones call detach() in their skip connections. For example:

skips[os] = x.detach()

Could you tell me where this idea comes from? I cannot find a corresponding part in the official SqueezeSeg/SqueezeSegV2.
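For context, here is a minimal, self-contained sketch of what such a detached skip does in isolation (toy layers and shapes of my own choosing, not code from the repo): gradients still reach the encoder block through the main path, but nothing flows back through the stored skip.

import torch
import torch.nn as nn

block = nn.Conv2d(5, 8, 3, padding=1)   # stands in for an encoder stage
head = nn.Conv2d(8, 8, 3, padding=1)    # stands in for the rest of the network

skips = {}
x = torch.randn(1, 5, 64, 512)

x = block(x)
skips[1] = x.detach()        # skip stored as a stop-gradient copy
y = head(x) + skips[1]       # later stage consumes the detached skip

y.sum().backward()
# `block` still receives gradients through the main path head(x);
# the detached skip contributes no gradient back into the encoder.
print(block.weight.grad is not None)   # True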
Besides, I'm concerned that the first detach() in SqueezeSegV2 does not do what is intended.

# encoder
skip_in = self.conv1b(x)
x = self.conv1a(x)
# first skip done manually
skips[1] = skip_in.detach()

skip_in is detached and never referenced again, so the self.conv1b layer never receives any gradients to update its parameters.
Here is a quick check I did:

# Backbone and Decoder are the SqueezeSegV2 classes from squeezesegV2.py in this repo
import torch
encoder = Backbone(encoder_params)
decoder = Decoder(decoder_params, None)

x = torch.randn(1, 5, 64, 512)
y = decoder(*encoder(x))
y.sum().backward()

# list parameters that received no gradient from the backward pass
for name, p in encoder.named_parameters():
    if p.grad is None:
        print(name, "is None")

The above snippet gives:

conv1b.0.weight is None
conv1b.0.bias is None
conv1b.1.weight is None
conv1b.1.bias is None
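The same behaviour can be reproduced without the repo's classes. A toy module pair mirroring the conv1a/conv1b pattern (made-up layers, for illustration only) shows that a branch whose output is used only in detached form ends up with no gradients at all:

import torch
import torch.nn as nn

conv1a = nn.Conv2d(5, 8, 3, padding=1)   # main path
conv1b = nn.Conv2d(5, 8, 3, padding=1)   # parallel branch, used only for the skip
head = nn.Conv2d(8, 8, 3, padding=1)     # stands in for the rest of the network

skips = {}
inp = torch.randn(1, 5, 64, 512)

skip_in = conv1b(inp)
x = conv1a(inp)
skips[1] = skip_in.detach()      # the only use of conv1b's output is detached

y = head(x) + skips[1]
y.sum().backward()

print(conv1a.weight.grad is None)   # False: conv1a is trained via the main path
print(conv1b.weight.grad is None)   # True:  conv1b never receives a gradient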

Oops, this issue is quite old. Sorry.

As far as I remember, detaching or not did not make a big difference in performance. However, Andres mentioned that he had better results when the detach is added, because the whole network is then forced to learn instead of the gradient just skipping the middle part. But I'm not sure if there is a better explanation.
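As a rough illustration of that intuition (toy layers, not the actual code): with the detach, the only gradient an encoder stage receives is the one that travelled through the deeper part of the network; without it, part of the signal shortcuts back through the skip.

import torch
import torch.nn as nn

def run(detach_skip: bool):
    torch.manual_seed(0)
    early = nn.Conv2d(5, 8, 3, padding=1)    # encoder stage producing the skip
    middle = nn.Conv2d(8, 8, 3, padding=1)   # deeper part that the skip can bypass
    x = torch.randn(1, 5, 16, 32)

    feat = early(x)
    skip = feat.detach() if detach_skip else feat
    out = middle(feat) + skip                # decoder-style fusion of skip and deep path
    out.sum().backward()
    return early.weight.grad.norm().item()

print(run(detach_skip=True))    # gradient of `early` comes only through `middle`
print(run(detach_skip=False))   # additional gradient shortcuts through the skip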

Thank you for your response. Okay, I understand that detach() was adopted empirically. I had thought that stopping the gradient flow in the skip connections might introduce inefficiency or instability.