facebookresearch/swav

Cannot load the pretrained models

yaoweilee opened this issue · 3 comments

Hi, I ran into a problem when I tried to load the pretrained ResNet-50 model. The keys in the pretrained checkpoint do not match the keys of the torchvision ResNet-50. The same problem occurs with the other models listed in the Model Zoo table. Could you please help me with this issue? Thanks.

Here is my code:
```python
import torch, torchvision
model = torchvision.models.resnet50()
checkpoint = torch.load('.user/swav_800ep_pretrain.pth.tar')
model.load_state_dict(checkpoint, strict=False)
```

When I set strict=False, the model does not load any weights and acts like a randomly initialized model.
When I set strict=True, it raises the following error:

```
RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict:
"conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var",
......
"layer4.2.bn2.running_mean", "layer4.2.bn2.running_var", "layer4.2.conv3.weight", "layer4.2.bn3.weight", "layer4.2.bn3.bias", "layer4.2.bn3.running_mean", "layer4.2.bn3.running_var", "fc.weight", "fc.bias".
Unexpected key(s) in state_dict:
"module.conv1.weight", "module.bn1.weight", "module.bn1.bias", "module.bn1.running_mean", "module.bn1.running_var", "module.bn1.num_batches_tracked", "module.layer1.0.conv1.weight", "module.layer1.0.bn1.weight", "module.layer1.0.bn1.bias", "module.layer1.0.bn1.running_mean",
......
"module.projection_head.0.weight", "module.projection_head.0.bias", "module.projection_head.1.weight", "module.projection_head.1.bias", "module.projection_head.1.running_mean", "module.projection_head.1.running_var", "module.projection_head.1.num_batches_tracked", "module.projection_head.3.weight", "module.projection_head.3.bias", "module.prototypes.weight".
```
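Reading the two lists side by side, most unexpected keys look like a missing key with a leading `module.` added (the attribute name that `torch.nn.DataParallel` and `DistributedDataParallel` use for the wrapped model, so it typically shows up in checkpoints saved from a wrapped model). As a minimal sketch, a hypothetical helper `infer_key_prefix` can infer such a prefix from plain lists of key strings like the ones printed in this error. It is a diagnostic aid only, not part of PyTorch or the SwAV repo:

```python
from collections import Counter


def infer_key_prefix(missing, unexpected):
    """Guess a prefix that, once stripped from unexpected keys, turns them
    into missing keys. Returns None if no such prefix is found.

    Hypothetical diagnostic helper working on plain lists of key strings.
    """
    missing_set = set(missing)
    votes = Counter()
    for key in unexpected:
        parts = key.split(".")
        # Try every dotted prefix of this key: "module.", "module.layer1.", ...
        for i in range(1, len(parts)):
            prefix = ".".join(parts[:i]) + "."
            if key[len(prefix):] in missing_set:
                votes[prefix] += 1
    if not votes:
        return None
    # The prefix that explains the most mismatched keys wins.
    return votes.most_common(1)[0][0]


missing = ["conv1.weight", "bn1.weight", "fc.weight", "fc.bias"]
unexpected = ["module.conv1.weight", "module.bn1.weight", "module.prototypes.weight"]
print(infer_key_prefix(missing, unexpected))  # prints: module.
```

Keys such as `module.prototypes.weight` get no vote because they have no counterpart in the torchvision model at all.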

Hi @yaoweilee

You need to remove the prefix "module." from the keys of my checkpoints. The following line shows an example of how to do that:

```python
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
```
Also, the pretrained models have additional keys corresponding to the projection head and the prototypes layer. You can either delete these manually or use strict=False when loading the model.
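Both steps (stripping the prefix and dropping the extra head keys) can be combined in one pass. The sketch below is a hypothetical helper, `clean_state_dict`, not code from the SwAV repo; it works on any mapping, so the example uses placeholder strings instead of tensors. Unlike the one-line `k.replace("module.", "")`, it only strips the prefix at the start of a key, which makes no difference for these checkpoints but leaves any inner `module.` substring alone:

```python
def clean_state_dict(state_dict, model_keys, prefix="module."):
    """Strip a leading prefix from checkpoint keys and keep only the keys
    the target model knows about (hypothetical helper).

    Returns (kept, dropped, missing):
      kept    -- entries to pass to model.load_state_dict
      dropped -- checkpoint keys with no counterpart in the model
                 (e.g. projection_head.*, prototypes.*)
      missing -- model keys the checkpoint does not provide (e.g. fc.*)
    """
    model_keys = set(model_keys)
    renamed = {}
    for key, value in state_dict.items():
        # Strip the wrapper prefix only when it appears at the start.
        if key.startswith(prefix):
            key = key[len(prefix):]
        renamed[key] = value
    kept = {k: v for k, v in renamed.items() if k in model_keys}
    dropped = sorted(k for k in renamed if k not in model_keys)
    missing = sorted(model_keys - set(kept))
    return kept, dropped, missing


# Placeholder values stand in for tensors.
checkpoint = {"module.conv1.weight": "w1", "module.bn1.weight": "w2",
              "module.prototypes.weight": "p"}
model_keys = ["conv1.weight", "bn1.weight", "fc.weight", "fc.bias"]
kept, dropped, missing = clean_state_dict(checkpoint, model_keys)
print(sorted(kept))  # ['bn1.weight', 'conv1.weight']
print(dropped)       # ['prototypes.weight']
print(missing)       # ['fc.bias', 'fc.weight']
```

With PyTorch, `model_keys` would be `model.state_dict().keys()` for a `torchvision.models.resnet50()` instance, and `kept` would go to `model.load_state_dict(kept, strict=False)`. `strict=False` is still needed there because `fc.weight` and `fc.bias` are not in the checkpoint (they show up in `missing`) and stay randomly initialized.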

Note that all of this is done automatically if you use:

```python
import torch
model = torch.hub.load('facebookresearch/swav', 'resnet50')
```

Thank you so much for your help. SwAV is great work, BTW!

@yaoweilee when you load the model with `model = torch.hub.load('facebookresearch/swav', 'resnet50')`, does it contain the projection head or the prototypes for you?
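One way to check this locally is to list the loaded model's state-dict keys and look for the head names that appear in the checkpoint error earlier in this thread (`projection_head.*` and `prototypes.weight`). A minimal sketch of that filter, assuming the keys are passed in as plain strings (with PyTorch, `model.state_dict().keys()`); `find_head_keys` is a hypothetical helper and says nothing about what the hub model actually contains:

```python
# Head key prefixes as they appear in the SwAV checkpoint error above.
HEAD_PREFIXES = ("projection_head.", "prototypes.")


def find_head_keys(keys, prefix="module."):
    """Return the keys that belong to the projection head or the
    prototypes layer, ignoring an optional wrapper prefix."""
    found = []
    for key in keys:
        bare = key[len(prefix):] if key.startswith(prefix) else key
        if bare.startswith(HEAD_PREFIXES):
            found.append(key)
    return found


print(find_head_keys(["conv1.weight", "fc.bias"]))  # []
print(find_head_keys(["module.conv1.weight", "module.prototypes.weight"]))
# ['module.prototypes.weight']
```

An empty result for the hub model's keys would mean it was loaded as a plain backbone without either head.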