How to use the pretrained model?
I downloaded the pretrained model you provided
and loaded it with:
net = resnext50_32x4d()
net.load_state_dict(torch.load('ResNext50_checkpoint_best.pth.tar'))
Then I got this error:
'unexpected key "state_dict" in state_dict'
Could you please give an example of how to use the pretrained model?
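The error suggests that the `.pth.tar` file is a training checkpoint rather than a bare `state_dict`. It is a dictionary whose `"state_dict"` entry holds the weights, possibly next to metadata such as the epoch. Below is a minimal sketch of unwrapping such a file. It assumes that layout. `unwrap_checkpoint` is a hypothetical helper. Plain strings stand in for tensors, and the checkpoint contents are made up for illustration.

```python
from collections.abc import Mapping


def unwrap_checkpoint(obj, key="state_dict"):
    """Return the weight dict stored inside a training checkpoint.

    Assumes a layout like {"state_dict": {...}, "epoch": ..., ...}; if `key`
    is absent, `obj` is assumed to already be a plain state_dict.
    """
    if isinstance(obj, Mapping) and key in obj:
        return obj[key]
    return obj


# Made-up stand-in for what torch.load() might return for a checkpoint
checkpoint = {"epoch": 1, "state_dict": {"module.conv1.weight": "w"}}
print(list(unwrap_checkpoint(checkpoint)))  # ['module.conv1.weight']
```

With real files this would be used as `net.load_state_dict(unwrap_checkpoint(torch.load(path)))`. The unwrapped keys may still carry a `module.` prefix, which also has to be removed before the names match the model.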
Can you share your code? Let me give it a try.
I think I've solved it:
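import torch
from collections import OrderedDict

net = resnext50_32x4d()  # ResNeXt-50 (32x4d) definition from this repo
# map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only machine
checkpoint = torch.load('ResNext50_checkpoint_best.pth.tar', map_location='cpu')
state_dict = checkpoint['state_dict']  # the weights are nested under this key
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # keys saved from nn.DataParallel start with "module."; strip it if present
    name = k[7:] if k.startswith('module.') else k
    new_state_dict[name] = v
net.load_state_dict(new_state_dict)
net.eval()  # switch BatchNorm/Dropout to inference behavior
data = torch.ones(1, 3, 224, 224)  # dummy batch; Variable is unnecessary in PyTorch >= 0.4
with torch.no_grad():
    output = net(data)
print(output)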
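The key-renaming loop can be packaged as a reusable function. Slicing with `k[7:]` assumes every key begins with `module.`, which is the prefix `nn.DataParallel` adds. The sketch below strips the prefix only where it is present. `strip_prefix` is a hypothetical name, and strings stand in for tensors.

```python
from collections import OrderedDict


def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that have it.

    Keys without the prefix are kept unchanged, so the helper is also safe on
    checkpoints saved from a model that was not wrapped in nn.DataParallel.
    """
    out = OrderedDict()
    for k, v in state_dict.items():
        out[k[len(prefix):] if k.startswith(prefix) else k] = v
    return out


# Strings stand in for tensors here; the key handling is the same
wrapped = OrderedDict([("module.conv1.weight", "w"), ("module.fc.bias", "b")])
print(list(strip_prefix(wrapped)))  # ['conv1.weight', 'fc.bias']
print(list(strip_prefix(OrderedDict([("fc.bias", "b")]))))  # ['fc.bias']
```

An alternative is to leave the keys untouched and wrap the model instead. After `net = torch.nn.DataParallel(net)`, the parameter names also start with `module.`, so `net.load_state_dict(checkpoint['state_dict'])` matches directly.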
Yes, you are right!