Error while Loading vit_small weights
shubhaminnani opened this issue · 6 comments
Hi @Xiyue-Wang ,
Thank you for the amazing repo.
I am trying to load the pretrained MoCo v3 vit_small weights.
model = moco.builder_infence.MoCo_ViT(partial(vits.__dict__['vit_small'], stop_grad_conv1=True))
pretext_model = torch.load('TransPath/vit_small.pth.tar')['state_dict']
model = nn.DataParallel(model).cuda()
model.load_state_dict(pretext_model, strict=False)
I checked the keys of the model and of the weights and they seem to be the same, but I still get the error.
Looking forward to your reply.
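One way to make the key check explicit is to diff the two key sets rather than compare them by eye. The sketch below is only an illustration: `diff_state_dict_keys` is a hypothetical helper, and the key names are made up, not taken from the actual checkpoint. It works on any collection of key strings, so no model needs to be built to try it.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists for a model/checkpoint pair.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys the checkpoint has but the model does not
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Hypothetical key names, for illustration only.
model_keys = ["module.base_encoder.cls_token", "module.base_encoder.head.weight"]
checkpoint_keys = ["module.base_encoder.cls_token", "module.predictor.0.weight"]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

With the real objects, the inputs would be `model.state_dict().keys()` (taken after the `nn.DataParallel` wrap, since wrapping adds a `module.` prefix) and `pretext_model.keys()`. Note that `strict=False` skips mismatched keys without raising, so two empty lists are the only sign that everything was actually loaded.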
Thanks!
You can try:
model = moco.builder_infence.MoCo_ViT(
    partial(vits.__dict__[args.arch], stop_grad_conv1=True))
pretext_model = torch.load(r'./vit_small.pth.tar')['state_dict']
model = nn.DataParallel(model).cuda()
model.load_state_dict(pretext_model, strict=True)
Many people use it without any problem, so I do not know why you are getting an error.
Do I need to use something like model.module.online_encoder.net.head = nn.Identity(),
similar to the code in TransPath, to extract the features?
Hi @Xiyue-Wang ,
I think I found the problem. The model was trained with moco.builder,
but for inference the module is built from moco.builder_infence,
where part of the code in the forward function is commented out. The training weights have keys from the former module, and they don't match the keys in moco.builder_infence.
As I understand the code, feature extraction should use only the base encoder rather than the complete architecture, so a solution is needed for that. It may be better to load the complete weights and truncate the model afterwards. What's your view?
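An alternative to truncating the model after loading is to filter the checkpoint down to the encoder's keys first. This is a minimal sketch under assumptions: `extract_submodule_state` is a hypothetical helper, and the `module.base_encoder.` prefix and the key names are illustrative, so the real prefix should be read from the checkpoint's own keys. Strings stand in for tensors here so the example runs on its own.

```python
def extract_submodule_state(state_dict, prefix):
    """Keep only entries whose key starts with `prefix`, with the prefix stripped."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Hypothetical checkpoint contents; values would be tensors in practice.
full_state = {
    "module.base_encoder.patch_embed.proj.weight": "w0",
    "module.base_encoder.head.weight": "w1",
    "module.predictor.0.weight": "w2",
}

# Assumed prefix; check the actual keys in the checkpoint before relying on it.
encoder_state = extract_submodule_state(full_state, "module.base_encoder.")
print(sorted(encoder_state))
```

The filtered dictionary could then be passed to `load_state_dict` on a bare encoder built directly from `vits`, which avoids depending on either builder module at inference time.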
Thanks!