Cannot load the pretrained models
yaoweilee opened this issue · 3 comments
Hi, I ran into a problem when I tried to load the pretrained ResNet-50 model. It seems that the keys in the pretrained checkpoint and the keys in the torchvision ResNet-50 are not the same. The same problem appears when I try to load the other models listed in the Model Zoo table. Could you please help me with this issue? Thanks.
Here is my code:
```python
import torch, torchvision
model = torchvision.models.resnet50()
checkpoint = torch.load('.user/swav_800ep_pretrain.pth.tar')
model.load_state_dict(checkpoint, strict=False)
```
When I set strict=False, the model does not load any weights and behaves like a randomly initialized model.
When I set strict=True, it raises the following error:
```
RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict:
"conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var",
......
"layer4.2.bn2.running_mean", "layer4.2.bn2.running_var", "layer4.2.conv3.weight", "layer4.2.bn3.weight", "layer4.2.bn3.bias", "layer4.2.bn3.running_mean", "layer4.2.bn3.running_var", "fc.weight", "fc.bias".
Unexpected key(s) in state_dict:
"module.conv1.weight", "module.bn1.weight", "module.bn1.bias", "module.bn1.running_mean", "module.bn1.running_var", "module.bn1.num_batches_tracked", "module.layer1.0.conv1.weight", "module.layer1.0.bn1.weight", "module.layer1.0.bn1.bias", "module.layer1.0.bn1.running_mean",
......
"module.projection_head.0.weight", "module.projection_head.0.bias", "module.projection_head.1.weight", "module.projection_head.1.bias", "module.projection_head.1.running_mean", "module.projection_head.1.running_var", "module.projection_head.1.num_batches_tracked", "module.projection_head.3.weight", "module.projection_head.3.bias", "module.prototypes.weight".
```
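With strict=False, load_state_dict copies only the entries whose names match the model's exactly and skips unmatched names without raising. That explains the "random" model: every checkpoint key carries an extra "module." prefix, so no name matches. The plain-Python sketch below reproduces the comparison behind the "Missing" and "Unexpected" lists on a few key names from the error; `diff_keys` is a hypothetical helper for illustration, not a PyTorch function.

```python
from typing import Iterable, List, Tuple


def diff_keys(
    checkpoint_keys: Iterable[str], model_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected), the two kinds of mismatch PyTorch reports.

    missing:    keys the model defines but the checkpoint lacks
    unexpected: keys the checkpoint contains but the model does not define
    """
    ckpt, model = set(checkpoint_keys), set(model_keys)
    return sorted(model - ckpt), sorted(ckpt - model)


# Toy key lists shaped like the error above (heavily truncated, illustrative only).
model_keys = ["conv1.weight", "bn1.weight", "fc.weight", "fc.bias"]
checkpoint_keys = [
    "module.conv1.weight",
    "module.bn1.weight",
    "module.prototypes.weight",
]

missing, unexpected = diff_keys(checkpoint_keys, model_keys)
matched = set(model_keys) & set(checkpoint_keys)
print(f"matched={len(matched)} missing={len(missing)} unexpected={len(unexpected)}")
# matched=0 missing=4 unexpected=3
```

On a real model, `load_state_dict(..., strict=False)` returns a named tuple with `missing_keys` and `unexpected_keys` fields, so printing that return value right after loading shows at once whether any weights were actually copied.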
Hi @yaoweilee
You need to remove the prefix "module." from the keys of my checkpoints. See line 157 in commit b4aa051.
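A "module." prefix is typically what a model wrapped in `torch.nn.parallel.DistributedDataParallel` (or `torch.nn.DataParallel`) produces when its state_dict is saved, since the wrapper stores the network as an attribute named `module`. Below is a minimal sketch of removing it, assuming (as the error message suggests) that the loaded file is a flat dict mapping parameter names to tensors. `strip_prefix` is a hypothetical helper, not the repository's code.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def strip_prefix(state_dict: Dict[str, V], prefix: str = "module.") -> Dict[str, V]:
    """Return a copy of state_dict with a leading `prefix` removed from each key.

    Keys that do not start with the prefix are kept unchanged, and an occurrence
    of the prefix in the middle of a key is left alone.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# In practice the values are tensors; short strings stand in for them here.
checkpoint = {
    "module.conv1.weight": "w0",
    "module.layer1.0.bn1.running_mean": "w1",
    "bn1.weight": "w2",
}
cleaned = strip_prefix(checkpoint)
print(list(cleaned))
# ['conv1.weight', 'layer1.0.bn1.running_mean', 'bn1.weight']
```

With PyTorch installed, this would be used roughly as `model.load_state_dict(strip_prefix(torch.load(path, map_location="cpu")), strict=False)`. Removing only a leading prefix is slightly stricter than a blanket `str.replace("module.", "")`, which would also rewrite the substring if it appeared elsewhere in a key.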
Also, the pretrained models have additional keys corresponding to the projection head and the prototype layer. You can either delete these manually or use strict=False when loading the model.
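Once the prefix is gone, the remaining extra entries are the SwAV-specific layers from the "Unexpected key(s)" list (`projection_head.*` and `prototypes.*`). The sketch below deletes them manually. `drop_keys` and `HEAD_PREFIXES` are hypothetical names, and because the error above is truncated, the prefix list is an assumption that may not be exhaustive.

```python
from typing import Dict, Iterable, TypeVar

V = TypeVar("V")

# Prefixes of the SwAV-specific layers seen at the end of the "Unexpected key(s)"
# list above, written as they look after "module." has been stripped (assumed list).
HEAD_PREFIXES = ("projection_head.", "prototypes.")


def drop_keys(
    state_dict: Dict[str, V], prefixes: Iterable[str] = HEAD_PREFIXES
) -> Dict[str, V]:
    """Return a copy of state_dict without keys starting with any of `prefixes`."""
    prefix_tuple = tuple(prefixes)  # str.startswith accepts a tuple of prefixes
    return {k: v for k, v in state_dict.items() if not k.startswith(prefix_tuple)}


# Strings stand in for tensors in this toy state_dict.
backbone = drop_keys({
    "conv1.weight": "w",
    "layer4.2.bn3.running_var": "v",
    "projection_head.0.weight": "p",
    "prototypes.weight": "q",
})
print(sorted(backbone))  # ['conv1.weight', 'layer4.2.bn3.running_var']
```

Even with these keys removed, the checkpoint has no `fc.weight` or `fc.bias` (both appear under "Missing key(s)" in the error), so `strict=True` would still fail against `torchvision.models.resnet50()`. You can keep `strict=False` and check that the returned `missing_keys` lists only the two `fc` entries. Alternatively, if only backbone features are needed, one option is to set `model.fc = torch.nn.Identity()` before loading so the model no longer expects those keys.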
Note that all of this is done automatically if you use:
```python
import torch
model = torch.hub.load('facebookresearch/swav', 'resnet50')
```
Thank you so much for your help. SwAV is great work, BTW!
@yaoweilee, when you load the model with `model = torch.hub.load('facebookresearch/swav', 'resnet50')`, does it contain the prototypes or the projection head for you?
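One way to check this for any loaded model is to list the top-level submodule names in its state_dict and look for `projection_head` or `prototypes`. A sketch follows; `top_level_modules` is a hypothetical helper, and the toy keys only illustrate a torchvision-style ResNet-50. They say nothing definitive about what the hub entry point returns.

```python
from typing import Iterable, List


def top_level_modules(keys: Iterable[str]) -> List[str]:
    """Return the distinct first components of dotted state_dict keys, in first-seen order."""
    return list(dict.fromkeys(key.split(".", 1)[0] for key in keys))


# Toy keys shaped like a torchvision ResNet-50 state_dict (truncated, illustrative).
resnet_keys = [
    "conv1.weight",
    "bn1.weight",
    "bn1.bias",
    "layer1.0.conv1.weight",
    "layer4.2.bn3.running_var",
    "fc.weight",
    "fc.bias",
]
names = top_level_modules(resnet_keys)
print(names)  # ['conv1', 'bn1', 'layer1', 'layer4', 'fc']
print("projection_head" in names, "prototypes" in names)  # False False
```

On the hub model, `top_level_modules(model.state_dict())` would produce this list, since iterating a dict yields its keys. A raw checkpoint that still carries the "module." prefix would report only `module`, so strip that prefix first.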