fxmeng/filter-grafting

您好,关于grafting.py中的grafting函数有个问题

Closed this issue · 2 comments

我看代码的缩进是这样的:

def grafting(net, epoch):
    while True:
        try:
            checkpoint = torch.load('%s/ckpt%d_%d.t7' % (args.s, args.i - 1, epoch))['net']
            break
        except:
            time.sleep(10)
    model = collections.OrderedDict()
    for i, (key, u) in enumerate(net.state_dict().items()):
        if 'conv' in key:
            w = round(args.a / np.pi * np.arctan(args.c * (entropy(u) - entropy(checkpoint[key]))) + 0.5, 2)
        model[key] = u * w + checkpoint[key] * (1 - w)
    net.load_state_dict(model)

这里w是嫁接系数α
所以是所有层都参与嫁接?
但是这些不是卷积层的层,它们的嫁接系数是通过其上面的一个卷积层来计算的?
谢谢!

是所有层都参与嫁接,每个卷积层算出来的参数决定他后面的层嫁接的比例。例如BN层每个weight,bias和每个filter具有对应关系,所以用相同的比例是比较合理的做法。

@fxmeng 明白了,感谢!