Mismatch of the checkpoint
JinGuang-cuhksz opened this issue · 2 comments
Thank you very much for sharing the pre-trained models. However, loading them produces the following error about unexpected keys:
RuntimeError: Error(s) in loading state_dict for QGPO_Critic:
Unexpected key(s) in state_dict: "vf.v.0.weight", "vf.v.0.bias", "vf.v.2.weight", "vf.v.2.bias", "vf.v.4.weight", "vf.v.4.bias", "gaussian_policy.net.0.weight", "gaussian_policy.net.0.bias", "gaussian_policy.net.2.weight", "gaussian_policy.net.2.bias", "gaussian_policy.net.4.weight", "gaussian_policy.net.4.bias".
I also couldn't find the `vf` and `gaussian_policy` networks in `model.py`. Could you help me figure out what's wrong?
Hi, sorry for the late reply.
The `vf` and `gaussian_policy` models are ones I used for debugging; you do not need these keys. Could you try ignoring them yourself? It would be something like this (from GPT):
To load only part of a PyTorch checkpoint, such as `model1` from a checkpoint that also contains `model2` and `model3`, you can use PyTorch's `load_state_dict` method combined with a filtering step that ignores the weights of `model2` and `model3`. Here's a basic guide on how to do this:
1. **Load the entire checkpoint**: First, load the entire checkpoint file. This is typically done with the `torch.load` function.
2. **Filter the weights**: Next, filter out the weights for `model1` from the loaded checkpoint. This usually means iterating through the checkpoint dictionary and selecting the keys that belong to `model1`.
3. **Load the weights into your model**: Finally, use the `load_state_dict` method to load the filtered weights into your model. You might need the `strict=False` parameter to allow partial loading.
Here's an example in code:
```python
import torch

# Assuming your model is defined like this
class YourModel(torch.nn.Module):
    # Your model definition here
    pass

# Load the entire checkpoint
checkpoint = torch.load('path_to_your_checkpoint.ckpt')

# Filter out the part you need, e.g., keep only model1's weights
model1_state_dict = {k: v for k, v in checkpoint.items() if k.startswith('model1.')}

# Initialize your model
your_model = YourModel()

# Load the weights; strict=False tolerates missing or extra keys
your_model.load_state_dict(model1_state_dict, strict=False)
```
In this example, `your_model` is your own model definition, and we assume it contains only the parts needed for `model1`. Calling `load_state_dict(model1_state_dict, strict=False)` loads only the weights present in `model1_state_dict`, ignoring parts of the model that are not in that dictionary. This way, the weights of `model2` and `model3` are effectively ignored.
Welcome to share your code or create a PR if you can solve this. If you still have problems, I will probably be available for help after 25/11.
Thanks a lot. Replacing `load_state_dict(model1_state_dict)` with `load_state_dict(model1_state_dict, strict=False)` in the code solved it.