youzhonghui/gate-decorator-pruning

The last layer is a convolutional layer; how should I handle it? Thanks

cvJie opened this issue · 0 comments

cvJie commented

Hello, the last layer of my network is a convolutional layer. Is there any problem with handling the last layer in the following way?
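import torch
import torch.nn as nn

# Meltable is the base class provided by this repository.
class FinalConvLayerObserver(Meltable):
    """Observer for a final Conv2d layer; only its input channels are masked."""

    def __init__(self, conv2d):
        super(FinalConvLayerObserver, self).__init__()
        assert isinstance(conv2d, nn.Conv2d)
        # The channel slicing in melt() assumes an ungrouped convolution.
        assert conv2d.groups == 1
        self.conv2d = conv2d

        # One accumulator per input channel. Note that
        # nn.init.constant_(conv2d.weight, 0) would zero the layer's weights in place.
        self.in_mask = torch.zeros(conv2d.in_channels).to('cpu')
        self.f_hook = conv2d.register_forward_hook(self._forward_hook)

    def _forward_hook(self, m, _in, _out):
        x = _in[0]  # shape (N, C, H, W)
        # Reduce over batch and spatial dimensions so the result has shape (C,).
        self.in_mask += x.data.abs().cpu().sum(dim=(0, 2, 3))

    def forward(self, x):
        return self.conv2d(x)

    def melt(self):
        with torch.no_grad():
            keep = self.in_mask != 0
            old = self.conv2d
            # Rebuild the layer with fewer input channels and the same hyperparameters.
            replacer = nn.Conv2d(int(keep.sum()), old.out_channels, old.kernel_size,
                                 stride=old.stride, padding=old.padding,
                                 dilation=old.dilation,
                                 bias=old.bias is not None).to(old.weight.device)

            replacer.weight.copy_(old.weight[:, keep])

            if old.bias is not None:
                replacer.bias.copy_(old.bias)
        return replacer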
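
One way to sanity-check this kind of observer is a small standalone experiment that does not depend on the repository. The sketch below is only an illustration under stated assumptions: `channel_activity` and `prune_conv_inputs` are hypothetical helper names (not part of gate-decorator-pruning), the convolution is ungrouped, and the "pruned" channels are simulated by zeroing them in the input. It shows that the per-channel statistic has to be reduced over the batch and spatial dimensions, and that a rebuilt layer holding only the kept input channels gives the same output as the original.

```python
import torch
import torch.nn as nn


def channel_activity(x):
    """Sum of absolute activations per channel for an (N, C, H, W) tensor."""
    return x.detach().abs().sum(dim=(0, 2, 3)).cpu()


def prune_conv_inputs(conv, keep):
    """Copy `conv`, keeping only the input channels where `keep` is True."""
    assert conv.groups == 1, "this sketch assumes an ungrouped convolution"
    new_conv = nn.Conv2d(
        int(keep.sum()), conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=conv.padding, dilation=conv.dilation,
        bias=conv.bias is not None,
    ).to(conv.weight.device)
    with torch.no_grad():
        # Weight layout is (out_channels, in_channels, kH, kW): slice dim 1.
        new_conv.weight.copy_(conv.weight[:, keep])
        if conv.bias is not None:
            new_conv.bias.copy_(conv.bias)
    return new_conv


torch.manual_seed(0)
conv = nn.Conv2d(6, 3, kernel_size=1)
x = torch.randn(4, 6, 5, 5)
x[:, [1, 4]] = 0.0  # simulate channels 1 and 4 being gated off upstream

keep = channel_activity(x) != 0
pruned = prune_conv_inputs(conv, keep)

with torch.no_grad():
    y_full = conv(x)               # original layer on the full input
    y_pruned = pruned(x[:, keep])  # rebuilt layer on the kept channels only

print(pruned)
print(torch.allclose(y_full, y_pruned, atol=1e-5))
```

Because the dropped channels contribute nothing to the output, the two results should agree up to floating-point error. If they do not, the mask shape or the weight slicing is the first place to look.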