最后一层是卷积层,该如何去操作?谢谢
cvJie opened this issue · 0 comments
cvJie commented
您好,网络的最后一层是卷积层,请问如果按照下面这种方式去处理最后一层,有什么问题吗?
class FinalConvLayerObserver(Meltable):
def __init__(self, conv2d):
super(FinalLayerObserver, self).__init__()
assert isinstance(conv2d, nn.Conv2d)
self.conv2d = conv2d
self.in_mask =nn.init.constant_(conv2d.weight, 0).to('cpu')# torch.zeros(conv2d.weight,0).to('cpu')
self.f_hook = conv2d.register_forward_hook(self._forward_hook)
def _forward_hook(self, m, _in, _out):
x = _in[0]
self.in_mask += x.data.abs().cpu().sum(0, keepdim=True).view(-1)
def forward(self, x):
return self.conv2d(x)
def melt(self):
with torch.no_grad():
replacer = nn.conv2d(int((self.in_mask != 0).sum()), self.conv2d.weight).to(
self.conv2d.weight.device)
replacer.weight.set_(self.conv2d.weight[:, self.in_mask != 0])
replacer.bias.set_(self.conv2d.bias)
return replacer