Doing transfer learning
JJavierga opened this issue · 0 comments
I'm trying to do transfer learning with this NN. I tried loading the pretrained parameters first and then freezing the parameters of the shallowest layers, like this:
`# Get an instance of the model
enet = ENet(nc)
print('[INFO]Model Instantiated!')
# Move the model to cuda if available
enet = enet.to(device)
# Transfer learnt weights
pretrained_dict = torch.load('./datasets/CamVid/ckpt-enet.pth')['state_dict']
model_dict = enet.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
enet.load_state_dict(model_dict)
# Choose frozen layers
count = 0
for child in enet.children():
    if count < frozen_layers:
        for param in child.parameters():
            param.requires_grad = False
    count += 1
`
But `load_state_dict(model_dict)` raises an error saying that the parameter shapes do not match, not even for the very first layers:
[INFO]Defined all the hyperparameters successfully!
[INFO]Model Instantiated!
Traceback (most recent call last):
  File "Transfer_learning.py", line 153, in <module>
    train(FLAGS,27) #This 27 is the number of layers to freeze
  File "/home/javier/Documents/Segmentation/ENet/ENet-Real-Time-Semantic-Segmentation/new_train.py", line 45, in train
    enet.load_state_dict(model_dict)
  File "/home/javier/anaconda3/envs/ENet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 839, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ENet:
size mismatch for b10.batchnorm2.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([4]).
size mismatch for b10.batchnorm2.bias: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([4]).
size mismatch for b10.batchnorm2.running_mean: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([4]).
...
size mismatch for b51.batchnorm2.running_var: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([4]).
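For reference, the filter in the snippet above keeps every checkpoint entry whose key exists in the model, so entries with the same name but a different shape still reach `load_state_dict`. A minimal sketch of a shape-aware filter follows, assuming it is acceptable to leave the mismatched layers at their fresh initialization. The `make_net` toy model and the `load_matching_weights` helper are hypothetical stand-ins for illustration, not part of the ENet repository, and this does not explain why the shapes differ in the first place.

```python
import torch
import torch.nn as nn


def make_net(num_classes):
    # Hypothetical toy stand-in for ENet: a shared stem plus a class-dependent head.
    return nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.Conv2d(16, num_classes, kernel_size=1),
    )


def load_matching_weights(model, pretrained_dict):
    """Copy only entries whose key AND shape match the model; return the skipped keys."""
    model_dict = model.state_dict()
    compatible = {
        k: v for k, v in pretrained_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    skipped = sorted(set(pretrained_dict) - set(compatible))
    model_dict.update(compatible)
    model.load_state_dict(model_dict)
    return skipped


pretrained = make_net(num_classes=12)  # plays the role of the checkpoint
model = make_net(num_classes=4)        # new model with a different number of classes

skipped = load_matching_weights(model, pretrained.state_dict())
print("Skipped (left at initial values):", skipped)

# Freeze the first two children (the stem) and keep the rest trainable.
frozen_layers = 2
for count, child in enumerate(model.children()):
    if count < frozen_layers:
        for param in child.parameters():
            param.requires_grad = False
```

Printing the skipped keys shows which layers were not transferred. If that list covers most of the network, as the traceback above suggests, the model definition or its constructor arguments probably differ from those used to create the checkpoint, and filtering alone would transfer very little.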