questions for loading the pretrained_model
Opened this issue · 2 comments
mingbocui commented
def load(self, model_file, pretrain_file):
    """ load saved model or pretrained transformer (a part of model) """
    if model_file:
        print('Loading the model from', model_file)
        self.model.load_state_dict(torch.load(model_file))
    elif pretrain_file: # use pretrained transformer
        print('Loading the pretrained model from', pretrain_file)
        if pretrain_file.endswith('.ckpt'): # checkpoint file in tensorflow
            checkpoint.load_model(self.model.transformer, pretrain_file)
        elif pretrain_file.endswith('.pt'): # pretrain model file in pytorch
            self.model.transformer.load_state_dict(
                {key[12:]: value
                 for key, value in torch.load(pretrain_file).items()
                 if key.startswith('transformer')}
            ) # load only transformer parts
Could I kindly ask what the meaning of key[12:]: value is when you load a pretrained model? Do you just want to keep the last layer? Thanks, hope for your reply.
dhlee347 commented
It is because I wanted to load only the transformer part of the saved model, not the whole model. In the full model's state dict, the transformer's parameter keys are prefixed with 'transformer.' (12 characters), so key[12:] strips that prefix, leaving key names that match the transformer sub-module's own state dict.
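A minimal sketch of the prefix-stripping described above, using made-up key names (the real keys depend on the model definition). len("transformer.") is 12, which is where key[12:] comes from:

```python
# Hypothetical state dict of the whole model; only keys under the
# "transformer" sub-module should be loaded into self.model.transformer.
full_state = {
    "transformer.embed.weight": 1,   # belongs to the transformer
    "transformer.blocks.0.attn": 2,  # belongs to the transformer
    "classifier.weight": 3,          # task head, filtered out
}

# len("transformer.") == 12, so key[12:] drops that prefix,
# turning e.g. "transformer.embed.weight" into "embed.weight".
transformer_state = {
    key[12:]: value
    for key, value in full_state.items()
    if key.startswith("transformer")
}

print(transformer_state)
# {'embed.weight': 1, 'blocks.0.attn': 2}
```

The resulting keys line up with what self.model.transformer.state_dict() expects, so load_state_dict can match them parameter by parameter.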