您好,关于grafting.py中的grafting函数有个问题
Closed this issue · 2 comments
tjmannn commented
我看代码的缩进是这样的:
def grafting(net, epoch):
while True:
try:
checkpoint = torch.load('%s/ckpt%d_%d.t7' % (args.s, args.i - 1, epoch))['net']
break
except:
time.sleep(10)
model = collections.OrderedDict()
for i, (key, u) in enumerate(net.state_dict().items()):
if 'conv' in key:
w = round(args.a / np.pi * np.arctan(args.c * (entropy(u) - entropy(checkpoint[key]))) + 0.5, 2)
model[key] = u * w + checkpoint[key] * (1 - w)
net.load_state_dict(model)
这里w是嫁接系数α
所以是所有层都参与嫁接?
但是这些不是卷积层的层,它们的嫁接系数是通过其上面的一个卷积层来计算的?
谢谢!
fxmeng commented
是所有层都参与嫁接,每个卷积层算出来的参数决定他后面的层嫁接的比例。例如BN层每个weight,bias和每个filter具有对应关系,所以用相同的比例是比较合理的做法。