eric-mitchell/direct-preference-optimization

How to load trained model for inference?

VibhuAg opened this issue · 2 comments

I have seen issue #13 and am trying to load the model weights to perform inference. So far I have trained Pythia-2.8B with SFT and then DPO as suggested in the README. I am following the instructions in issue #13 to load the weights after training, but loading fails. The output folder for the training run contains three files: policy.pt, optimizer.pt, and scheduler.pt. I am trying to load the weights as follows:

import torch
import transformers

model = transformers.AutoModelForCausalLM.from_pretrained('EleutherAI/pythia-2.8b')
model.load_state_dict(torch.load('<PATH_TO_OUTPUT_DIR>/LATEST/policy.pt'))

But I run into numerous missing-key errors. What am I doing wrong?

I encountered the same problem. Did you solve it? Looking forward to your reply!

Sorry for the confusion here: the saved file is a dict, and the model's state_dict lives under its 'state' key. So:

model.load_state_dict(torch.load('<PATH_TO_OUTPUT_DIR>/LATEST/policy.pt')['state'])
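To illustrate why the original attempt fails: the trainer saves policy.pt as a dict that nests the state_dict under a 'state' key (possibly alongside other metadata), so passing the raw loaded object to load_state_dict produces missing-key errors. A minimal sketch of the pattern, using a tiny stand-in module rather than Pythia-2.8B:

```python
import torch

# Tiny stand-in for the policy model (Pythia-2.8B in the actual issue).
model = torch.nn.Linear(4, 4)

# Mimic how the trainer writes policy.pt: the state_dict is nested under
# a 'state' key rather than stored at the top level of the file.
torch.save({'state': model.state_dict()}, 'policy.pt')

# Passing the raw dict to load_state_dict would raise missing-key errors;
# indexing into 'state' first yields the actual state_dict.
restored = torch.nn.Linear(4, 4)
restored.load_state_dict(torch.load('policy.pt')['state'])
```

The same indexing applies unchanged when the module is the AutoModelForCausalLM instance loaded from 'EleutherAI/pythia-2.8b'.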