你好,使用VGG模型运行grafting_cifar时,出现错误
Opened this issue · 3 comments
linsencc commented
models的vgg代码中构造网络时使用了 _make_layers构造卷积层,在运行grafting.py时会出现UnboundLocalError: local variable 'w' referenced before assignment
fxmeng commented
你在函数开始的时候给w一个初始值就好了
linsencc commented
可能不行,代码中根据‘conv’ 字符串查找卷积层,但是这里vgg用了_make_layers将参数层加入到features序列中 ,导致key中没有‘conv’, 只有‘features’
···
def grafting(net, epoch):
……
for i, (key, u) in enumerate(net.state_dict().items()):
if 'conv' in key:
w = round(args.a / np.pi * np.arctan(args.c * (entropy(u) - entropy(checkpoint[key]))) + 0.5, 2)
model[key] = u * w + checkpoint[key] * (1 - w)
net.load_state_dict(model)
···
fxmeng commented
哦哦,那这里需要改一下判断语句了,比如:if len(n.shape)==4