Problem: 3D weights cannot be loaded
cqustlym opened this issue · 1 comment
cqustlym commented
Thank you for providing so many pretrained 3D weights. When I tried to load the VC3D_kenshohara weights with the following code, all 318 keys were reported as non-pretrained ("Non-Pretrained keys: 318"). Is something wrong?
import torch
from wama_modules.thirdparty_lib.VC3D_kenshohara.wide_resnet import generate_model
from wama_modules.utils import load_weights
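m = generate_model()  # build the VC3D_kenshohara wide_resnet model
pretrain_path = r"wideresnet-50-kinetics.pth"
# the checkpoint stores the weights under its 'state_dict' entry
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights)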
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
keys ( Current model,C ) 318 odict_keys(['conv1.weight', 'bn1.weight', 'bn1.bias', 's_tracked'])
keys ( Pre-trained ,P ) 267 dict_keys(['module.conv1.weight', 'module.bn1.weight...er4.2.convle.fc.bias'])
keys ( In C & In P ) 0 dict_keys([])
keys ( NoIn C & In P ) 267 dict_keys(['module.co...odule.layer4.2.bas'])
keys ( In C & NoIn P ) 318 dict_keys(['conv1.weig...es_tracked'])
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Pretrained keys : 0 dict_keys([])
Non-Pretrained keys: 318 dict_keys(['conv1.weight', 'bn1.we...tches_tracked'])
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
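The log compares the model's keys (C) with the checkpoint's keys (P). A minimal pure-Python sketch of the same set comparison, using a few illustrative key names rather than the real wide_resnet key list, shows why the intersection is empty: every pretrained key carries an extra leading prefix. `compare_keys` is a hypothetical helper, not part of wama_modules.

```python
def compare_keys(model_keys, pretrained_keys):
    """Split two key collections into the groups shown in the load_weights log.

    Hypothetical helper for illustration only.
    """
    c, p = set(model_keys), set(pretrained_keys)
    return {
        "in_C_and_in_P": sorted(c & p),    # keys that could be loaded
        "not_in_C_in_P": sorted(p - c),    # checkpoint keys the model does not have
        "in_C_not_in_P": sorted(c - p),    # model keys left without pretrained values
    }


# Illustrative keys only; a real state_dict has hundreds of entries.
model_keys = ["conv1.weight", "bn1.weight", "bn1.bias", "bn1.num_batches_tracked"]
pretrained_keys = ["module.conv1.weight", "module.bn1.weight", "module.bn1.bias"]

report = compare_keys(model_keys, pretrained_keys)
for name, keys in report.items():
    print(name, len(keys))
```

With the prefixed keys nothing lands in `in_C_and_in_P`, which matches the "In C & In P 0" line in the log.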
WAMAWAMA commented
Because the keys in the pretrained state_dict start with the prefix 'module.' (typically added when a model is saved while wrapped in nn.DataParallel), you need to set the drop_modelDOT argument of load_weights to True, like:
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
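For reference, here is a minimal sketch of doing the same fix by hand, assuming the only mismatch is a leading 'module.' prefix as in the log above. `strip_prefix` is a hypothetical helper, not part of wama_modules, and this is not necessarily how drop_modelDOT is implemented internally.

```python
def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it.

    Hypothetical helper; keys without the prefix are kept unchanged.
    """
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


# Plain integers stand in for tensors in this illustration.
weights = {"module.conv1.weight": 1, "module.bn1.bias": 2, "fc.bias": 3}
print(strip_prefix(weights))
# {'conv1.weight': 1, 'bn1.bias': 2, 'fc.bias': 3}
```

With the real checkpoint, the stripped dict could then be passed to plain PyTorch via `m.load_state_dict(strip_prefix(pretrain_weights), strict=False)`. That call returns the missing and unexpected keys, so any remaining mismatch stays visible. Examples include `num_batches_tracked` buffers, which older checkpoints may lack.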