facebookresearch/swav

Question about loading the pre-trained model in eval_linear.py and eval_semisup.py

zhawhjw opened this issue · 0 comments

Hi, I am using eval_linear.py and eval_semisup.py to run tasks on my own dataset. When I try to load the pre-trained model from here, eval_semisup.py gives the following messages:

INFO - 08/28/22 11:48:22 - 0:00:00 - Building data done with 8765 images loaded  
INFO - 08/28/22 11:48:22 - 0:00:00 - key "projection_head.weight" could not be found in provided state dict  
INFO - 08/28/22 11:48:22 - 0:00:00 - key "projection_head.bias" could not be found in provided state dict
INFO - 08/28/22 11:48:22 - 0:00:00 - Load pretrained model with msg: _IncompatibleKeys(missing_keys=['projection_head.weight', 'projection_head.bias'], unexpected_keys=['prototypes.weight', 'projection_head.0.weight', 'projection_head.0.bias', 'projection_head.1.weight', 'projection_head.1.bias', 'projection_head.1.running_mean', 'projection_head.1.running_var', 'projection_head.1.num_batches_tracked', 'projection_head.3.weight', 'projection_head.3.bias'])

When I load the pre-trained model in eval_linear.py, the output is:

INFO - 08/28/22 14:02:22 - 0:00:02 - Load pretrained model with msg: _IncompatibleKeys(missing_keys=[], unexpected_keys=['projection_head.0.weight', 'projection_head.0.bias', 'projection_head.1.weight', 'projection_head.1.bias', 'projection_head.1.running_mean', 'projection_head.1.running_var', 'projection_head.1.num_batches_tracked', 'projection_head.3.weight', 'projection_head.3.bias', 'prototypes.weight'])

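It seems the pretrained model is missing the projection_head weight and bias. So I printed the keys in the saved model and found that the projection head is there, at the end of the key list, but its keys carry an extra layer index (for example projection_head.0.weight rather than projection_head.weight), and every key has a module. prefix: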

dict_keys(['module.conv1.weight', 'module.bn1.weight', 'module.bn1.bias', 'module.bn1.running_mean', 'module.bn1.running_var', 'module.bn1.num_batches_tracked', 'module.layer1.0.conv1.weight', 'module.layer1.0.bn1.weight', 'module.layer1.0.bn1.bias', 'module.layer1.0.bn1.running_mean', 'module.layer1.0.bn1.running_var', 'module.layer1.0.bn1.num_batches_tracked', 'module.layer1.0.conv2.weight', 'module.layer1.0.bn2.weight', 'module.layer1.0.bn2.bias', 'module.layer1.0.bn2.running_mean', 'module.layer1.0.bn2.running_var', 'module.layer1.0.bn2.num_batches_tracked', 'module.layer1.0.conv3.weight', 'module.layer1.0.bn3.weight', 'module.layer1.0.bn3.bias', 'module.layer1.0.bn3.running_mean', 'module.layer1.0.bn3.running_var', 'module.layer1.0.bn3.num_batches_tracked', 'module.layer1.0.downsample.0.weight', 'module.layer1.0.downsample.1.weight', 'module.layer1.0.downsample.1.bias', 'module.layer1.0.downsample.1.running_mean', 'module.layer1.0.downsample.1.running_var', 'module.layer1.0.downsample.1.num_batches_tracked', 'module.layer1.1.conv1.weight', 'module.layer1.1.bn1.weight', 'module.layer1.1.bn1.bias', 'module.layer1.1.bn1.running_mean', 'module.layer1.1.bn1.running_var', 'module.layer1.1.bn1.num_batches_tracked', 'module.layer1.1.conv2.weight', 'module.layer1.1.bn2.weight', 'module.layer1.1.bn2.bias', 'module.layer1.1.bn2.running_mean', 'module.layer1.1.bn2.running_var', 'module.layer1.1.bn2.num_batches_tracked', 'module.layer1.1.conv3.weight', 'module.layer1.1.bn3.weight', 'module.layer1.1.bn3.bias', 'module.layer1.1.bn3.running_mean', 'module.layer1.1.bn3.running_var', 'module.layer1.1.bn3.num_batches_tracked', 'module.layer1.2.conv1.weight', 'module.layer1.2.bn1.weight', 'module.layer1.2.bn1.bias', 'module.layer1.2.bn1.running_mean', 'module.layer1.2.bn1.running_var', 'module.layer1.2.bn1.num_batches_tracked', 'module.layer1.2.conv2.weight', 'module.layer1.2.bn2.weight', 'module.layer1.2.bn2.bias', 'module.layer1.2.bn2.running_mean', 
'module.layer1.2.bn2.running_var', 'module.layer1.2.bn2.num_batches_tracked', 'module.layer1.2.conv3.weight', 'module.layer1.2.bn3.weight', 'module.layer1.2.bn3.bias', 'module.layer1.2.bn3.running_mean', 'module.layer1.2.bn3.running_var', 'module.layer1.2.bn3.num_batches_tracked', 'module.layer2.0.conv1.weight', 'module.layer2.0.bn1.weight', 'module.layer2.0.bn1.bias', 'module.layer2.0.bn1.running_mean', 'module.layer2.0.bn1.running_var', 'module.layer2.0.bn1.num_batches_tracked', 'module.layer2.0.conv2.weight', 'module.layer2.0.bn2.weight', 'module.layer2.0.bn2.bias', 'module.layer2.0.bn2.running_mean', 'module.layer2.0.bn2.running_var', 'module.layer2.0.bn2.num_batches_tracked', 'module.layer2.0.conv3.weight', 'module.layer2.0.bn3.weight', 'module.layer2.0.bn3.bias', 'module.layer2.0.bn3.running_mean', 'module.layer2.0.bn3.running_var', 'module.layer2.0.bn3.num_batches_tracked', 'module.layer2.0.downsample.0.weight', 'module.layer2.0.downsample.1.weight', 'module.layer2.0.downsample.1.bias', 'module.layer2.0.downsample.1.running_mean', 'module.layer2.0.downsample.1.running_var', 'module.layer2.0.downsample.1.num_batches_tracked', 'module.layer2.1.conv1.weight', 'module.layer2.1.bn1.weight', 'module.layer2.1.bn1.bias', 'module.layer2.1.bn1.running_mean', 'module.layer2.1.bn1.running_var', 'module.layer2.1.bn1.num_batches_tracked', 'module.layer2.1.conv2.weight', 'module.layer2.1.bn2.weight', 'module.layer2.1.bn2.bias', 'module.layer2.1.bn2.running_mean', 'module.layer2.1.bn2.running_var', 'module.layer2.1.bn2.num_batches_tracked', 'module.layer2.1.conv3.weight', 'module.layer2.1.bn3.weight', 'module.layer2.1.bn3.bias', 'module.layer2.1.bn3.running_mean', 'module.layer2.1.bn3.running_var', 'module.layer2.1.bn3.num_batches_tracked', 'module.layer2.2.conv1.weight', 'module.layer2.2.bn1.weight', 'module.layer2.2.bn1.bias', 'module.layer2.2.bn1.running_mean', 'module.layer2.2.bn1.running_var', 'module.layer2.2.bn1.num_batches_tracked', 'module.layer2.2.conv2.weight', 
'module.layer2.2.bn2.weight', 'module.layer2.2.bn2.bias', 'module.layer2.2.bn2.running_mean', 'module.layer2.2.bn2.running_var', 'module.layer2.2.bn2.num_batches_tracked', 'module.layer2.2.conv3.weight', 'module.layer2.2.bn3.weight', 'module.layer2.2.bn3.bias', 'module.layer2.2.bn3.running_mean', 'module.layer2.2.bn3.running_var', 'module.layer2.2.bn3.num_batches_tracked', 'module.layer2.3.conv1.weight', 'module.layer2.3.bn1.weight', 'module.layer2.3.bn1.bias', 'module.layer2.3.bn1.running_mean', 'module.layer2.3.bn1.running_var', 'module.layer2.3.bn1.num_batches_tracked', 'module.layer2.3.conv2.weight', 'module.layer2.3.bn2.weight', 'module.layer2.3.bn2.bias', 'module.layer2.3.bn2.running_mean', 'module.layer2.3.bn2.running_var', 'module.layer2.3.bn2.num_batches_tracked', 'module.layer2.3.conv3.weight', 'module.layer2.3.bn3.weight', 'module.layer2.3.bn3.bias', 'module.layer2.3.bn3.running_mean', 'module.layer2.3.bn3.running_var', 'module.layer2.3.bn3.num_batches_tracked', 'module.layer3.0.conv1.weight', 'module.layer3.0.bn1.weight', 'module.layer3.0.bn1.bias', 'module.layer3.0.bn1.running_mean', 'module.layer3.0.bn1.running_var', 'module.layer3.0.bn1.num_batches_tracked', 'module.layer3.0.conv2.weight', 'module.layer3.0.bn2.weight', 'module.layer3.0.bn2.bias', 'module.layer3.0.bn2.running_mean', 'module.layer3.0.bn2.running_var', 'module.layer3.0.bn2.num_batches_tracked', 'module.layer3.0.conv3.weight', 'module.layer3.0.bn3.weight', 'module.layer3.0.bn3.bias', 'module.layer3.0.bn3.running_mean', 'module.layer3.0.bn3.running_var', 'module.layer3.0.bn3.num_batches_tracked', 'module.layer3.0.downsample.0.weight', 'module.layer3.0.downsample.1.weight', 'module.layer3.0.downsample.1.bias', 'module.layer3.0.downsample.1.running_mean', 'module.layer3.0.downsample.1.running_var', 'module.layer3.0.downsample.1.num_batches_tracked', 'module.layer3.1.conv1.weight', 'module.layer3.1.bn1.weight', 'module.layer3.1.bn1.bias', 'module.layer3.1.bn1.running_mean', 
'module.layer3.1.bn1.running_var', 'module.layer3.1.bn1.num_batches_tracked', 'module.layer3.1.conv2.weight', 'module.layer3.1.bn2.weight', 'module.layer3.1.bn2.bias', 'module.layer3.1.bn2.running_mean', 'module.layer3.1.bn2.running_var', 'module.layer3.1.bn2.num_batches_tracked', 'module.layer3.1.conv3.weight', 'module.layer3.1.bn3.weight', 'module.layer3.1.bn3.bias', 'module.layer3.1.bn3.running_mean', 'module.layer3.1.bn3.running_var', 'module.layer3.1.bn3.num_batches_tracked', 'module.layer3.2.conv1.weight', 'module.layer3.2.bn1.weight', 'module.layer3.2.bn1.bias', 'module.layer3.2.bn1.running_mean', 'module.layer3.2.bn1.running_var', 'module.layer3.2.bn1.num_batches_tracked', 'module.layer3.2.conv2.weight', 'module.layer3.2.bn2.weight', 'module.layer3.2.bn2.bias', 'module.layer3.2.bn2.running_mean', 'module.layer3.2.bn2.running_var', 'module.layer3.2.bn2.num_batches_tracked', 'module.layer3.2.conv3.weight', 'module.layer3.2.bn3.weight', 'module.layer3.2.bn3.bias', 'module.layer3.2.bn3.running_mean', 'module.layer3.2.bn3.running_var', 'module.layer3.2.bn3.num_batches_tracked', 'module.layer3.3.conv1.weight', 'module.layer3.3.bn1.weight', 'module.layer3.3.bn1.bias', 'module.layer3.3.bn1.running_mean', 'module.layer3.3.bn1.running_var', 'module.layer3.3.bn1.num_batches_tracked', 'module.layer3.3.conv2.weight', 'module.layer3.3.bn2.weight', 'module.layer3.3.bn2.bias', 'module.layer3.3.bn2.running_mean', 'module.layer3.3.bn2.running_var', 'module.layer3.3.bn2.num_batches_tracked', 'module.layer3.3.conv3.weight', 'module.layer3.3.bn3.weight', 'module.layer3.3.bn3.bias', 'module.layer3.3.bn3.running_mean', 'module.layer3.3.bn3.running_var', 'module.layer3.3.bn3.num_batches_tracked', 'module.layer3.4.conv1.weight', 'module.layer3.4.bn1.weight', 'module.layer3.4.bn1.bias', 'module.layer3.4.bn1.running_mean', 'module.layer3.4.bn1.running_var', 'module.layer3.4.bn1.num_batches_tracked', 'module.layer3.4.conv2.weight', 'module.layer3.4.bn2.weight', 
'module.layer3.4.bn2.bias', 'module.layer3.4.bn2.running_mean', 'module.layer3.4.bn2.running_var', 'module.layer3.4.bn2.num_batches_tracked', 'module.layer3.4.conv3.weight', 'module.layer3.4.bn3.weight', 'module.layer3.4.bn3.bias', 'module.layer3.4.bn3.running_mean', 'module.layer3.4.bn3.running_var', 'module.layer3.4.bn3.num_batches_tracked', 'module.layer3.5.conv1.weight', 'module.layer3.5.bn1.weight', 'module.layer3.5.bn1.bias', 'module.layer3.5.bn1.running_mean', 'module.layer3.5.bn1.running_var', 'module.layer3.5.bn1.num_batches_tracked', 'module.layer3.5.conv2.weight', 'module.layer3.5.bn2.weight', 'module.layer3.5.bn2.bias', 'module.layer3.5.bn2.running_mean', 'module.layer3.5.bn2.running_var', 'module.layer3.5.bn2.num_batches_tracked', 'module.layer3.5.conv3.weight', 'module.layer3.5.bn3.weight', 'module.layer3.5.bn3.bias', 'module.layer3.5.bn3.running_mean', 'module.layer3.5.bn3.running_var', 'module.layer3.5.bn3.num_batches_tracked', 'module.layer4.0.conv1.weight', 'module.layer4.0.bn1.weight', 'module.layer4.0.bn1.bias', 'module.layer4.0.bn1.running_mean', 'module.layer4.0.bn1.running_var', 'module.layer4.0.bn1.num_batches_tracked', 'module.layer4.0.conv2.weight', 'module.layer4.0.bn2.weight', 'module.layer4.0.bn2.bias', 'module.layer4.0.bn2.running_mean', 'module.layer4.0.bn2.running_var', 'module.layer4.0.bn2.num_batches_tracked', 'module.layer4.0.conv3.weight', 'module.layer4.0.bn3.weight', 'module.layer4.0.bn3.bias', 'module.layer4.0.bn3.running_mean', 'module.layer4.0.bn3.running_var', 'module.layer4.0.bn3.num_batches_tracked', 'module.layer4.0.downsample.0.weight', 'module.layer4.0.downsample.1.weight', 'module.layer4.0.downsample.1.bias', 'module.layer4.0.downsample.1.running_mean', 'module.layer4.0.downsample.1.running_var', 'module.layer4.0.downsample.1.num_batches_tracked', 'module.layer4.1.conv1.weight', 'module.layer4.1.bn1.weight', 'module.layer4.1.bn1.bias', 'module.layer4.1.bn1.running_mean', 'module.layer4.1.bn1.running_var', 
'module.layer4.1.bn1.num_batches_tracked', 'module.layer4.1.conv2.weight', 'module.layer4.1.bn2.weight', 'module.layer4.1.bn2.bias', 'module.layer4.1.bn2.running_mean', 'module.layer4.1.bn2.running_var', 'module.layer4.1.bn2.num_batches_tracked', 'module.layer4.1.conv3.weight', 'module.layer4.1.bn3.weight', 'module.layer4.1.bn3.bias', 'module.layer4.1.bn3.running_mean', 'module.layer4.1.bn3.running_var', 'module.layer4.1.bn3.num_batches_tracked', 'module.layer4.2.conv1.weight', 'module.layer4.2.bn1.weight', 'module.layer4.2.bn1.bias', 'module.layer4.2.bn1.running_mean', 'module.layer4.2.bn1.running_var', 'module.layer4.2.bn1.num_batches_tracked', 'module.layer4.2.conv2.weight', 'module.layer4.2.bn2.weight', 'module.layer4.2.bn2.bias', 'module.layer4.2.bn2.running_mean', 'module.layer4.2.bn2.running_var', 'module.layer4.2.bn2.num_batches_tracked', 'module.layer4.2.conv3.weight', 'module.layer4.2.bn3.weight', 'module.layer4.2.bn3.bias', 'module.layer4.2.bn3.running_mean', 'module.layer4.2.bn3.running_var', 'module.layer4.2.bn3.num_batches_tracked', 'module.projection_head.0.weight', 'module.projection_head.0.bias', 'module.projection_head.1.weight', 'module.projection_head.1.bias', 'module.projection_head.1.running_mean', 'module.projection_head.1.running_var', 'module.projection_head.1.num_batches_tracked', 'module.projection_head.3.weight', 'module.projection_head.3.bias', 'module.prototypes.weight'])
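
To make the mismatch easier to see, here is a minimal sketch that compares key names by hand. It works on plain lists of key strings, so it does not need the checkpoint or PyTorch. The helper names strip_prefix and diff_keys are hypothetical (they are not part of the SwAV code), the checkpoint keys are a small subset of the output above, and the model keys are an assumption based on the missing_keys in the eval_semisup.py log.

```python
def strip_prefix(keys, prefix="module."):
    """Drop a leading prefix from each key, if present.

    A "module." prefix is typically what nn.DataParallel or
    DistributedDataParallel wrappers add to state dict keys.
    """
    return [k[len(prefix):] if k.startswith(prefix) else k for k in keys]


def diff_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected), in the same sense as the
    _IncompatibleKeys message printed by load_state_dict(strict=False)."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    missing = [k for k in model_keys if k not in ckpt_set]
    unexpected = [k for k in checkpoint_keys if k not in model_set]
    return missing, unexpected


# Subset of the checkpoint keys printed above.
checkpoint_keys = [
    "module.conv1.weight",
    "module.projection_head.0.weight",
    "module.projection_head.0.bias",
    "module.projection_head.3.weight",
    "module.projection_head.3.bias",
    "module.prototypes.weight",
]

# Assumed subset of the keys the eval_semisup.py model asks for,
# inferred from the missing_keys in the log (not read from the code).
model_keys = ["conv1.weight", "projection_head.weight", "projection_head.bias"]

missing, unexpected = diff_keys(model_keys, strip_prefix(checkpoint_keys))
print("missing:", missing)
print("unexpected:", unexpected)
```

On this subset, removing the module. prefix is enough to match the backbone key, but the projection head keys still differ because the checkpoint names them with a layer index (projection_head.0, projection_head.3) while the model expects a single projection_head.weight and projection_head.bias.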

How should I handle this problem if I want to load the projection head in eval_semisup.py?

Thanks