How to load a trained model for inference?
VibhuAg opened this issue · 2 comments
VibhuAg commented
I have seen issue #13 and am trying to load the model weights to perform inference. So far I have trained Pythia-2.8B with SFT and then DPO as suggested in the README. I am trying to follow the instructions given in issue #13 to load the model weights after training, but I cannot seem to find them. In the output folder for the training run, I see three files: policy.pt, optimizer.pt, and scheduler.pt. I am trying to load the weights as follows:
import torch
import transformers
model = transformers.AutoModelForCausalLM.from_pretrained('EleutherAI/pythia-2.8b')
model.load_state_dict(torch.load('<PATH_TO_OUTPUT_DIR>/LATEST/policy.pt'))
But I run into numerous missing key errors. What am I doing wrong?
Leonnnnnn929 commented
I encountered the same problem. Did you solve it? Looking forward to your reply!
eric-mitchell commented
Sorry for the confusion here: you need to use the 'state' key in the saved dict. So:
model.load_state_dict(torch.load('<PATH_TO_OUTPUT_DIR>/LATEST/policy.pt')['state'])
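To see why the original call fails, here is a minimal, self-contained sketch of the checkpoint layout, using a tiny nn.Linear as a stand-in for the actual Pythia policy (the real policy.pt wraps the full model state_dict the same way; any extra keys besides 'state' in the real file are an assumption here):

```python
import torch
import torch.nn as nn

# Stand-in for the trained policy model.
model = nn.Linear(4, 2)

# Mimic how the trainer saves policy.pt: a dict whose 'state' entry
# holds the state_dict, rather than the state_dict itself.
torch.save({'state': model.state_dict()}, 'policy.pt')

# Passing the raw loaded dict to load_state_dict raises missing-key
# errors, because the top-level object is a wrapper, not a state_dict.
fresh = nn.Linear(4, 2)
checkpoint = torch.load('policy.pt')
fresh.load_state_dict(checkpoint['state'])  # index into 'state' first
```

The same pattern applies to the Pythia-2.8B model from the thread: load the file with torch.load, then pass only the 'state' entry to model.load_state_dict.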