ChenDRAG/CEP-energy-guided-diffusion

Mismatch of the checkpoint

JinGuang-cuhksz opened this issue · 2 comments

Thank you very much for sharing the pre-trained models. However, when loading the checkpoint I get the following error about unexpected keys:

RuntimeError: Error(s) in loading state_dict for QGPO_Critic:
Unexpected key(s) in state_dict: "vf.v.0.weight", "vf.v.0.bias", "vf.v.2.weight", "vf.v.2.bias", "vf.v.4.weight", "vf.v.4.bias", "gaussian_policy.net.0.weight", "gaussian_policy.net.0.bias", "gaussian_policy.net.2.weight", "gaussian_policy.net.2.bias", "gaussian_policy.net.4.weight", "gaussian_policy.net.4.bias".

I also couldn't find the vf and gaussian_policy networks anywhere in model.py. Could you help me figure out what's wrong?

Hi, sorry for the late reply.

The vf and gaussian_policy models are ones I used for debugging; you do not need these keys.

Could you try ignoring these keys yourself? It would be something like this (from GPT):

To load only part of a PyTorch checkpoint, such as only model1 from a checkpoint that also contains model2 and model3, you can use PyTorch's load_state_dict method together with a filtering step that discards the model2 and model3 weights. Here's a basic guide on how to do this:

  1. Load the Entire Checkpoint: First, load the entire checkpoint file. This is typically done using the torch.load function.

  2. Filter the Weights: Next, select only the weights belonging to model1 from the loaded checkpoint. This usually involves iterating through the checkpoint dictionary and keeping the keys that are related to model1.

  3. Load the Weights into Your Model: Finally, use the load_state_dict method to load the filtered weights into your model. You might need to use the strict=False parameter to allow partial loading.

Here's an example in code:

import torch

# Assuming your model is defined like this
class YourModel(torch.nn.Module):
    # Your model definition here
    pass

# Load the entire checkpoint (map_location='cpu' avoids needing a GPU)
checkpoint = torch.load('path_to_your_checkpoint.ckpt', map_location='cpu')

# Keep only model1's weights, stripping the 'model1.' prefix so the keys
# match the standalone model's own parameter names
model1_state_dict = {k[len('model1.'):]: v for k, v in checkpoint.items()
                     if k.startswith('model1.')}

# Initialize your model
your_model = YourModel()

# Load the weights; strict=False tolerates any remaining mismatches
your_model.load_state_dict(model1_state_dict, strict=False)

In this example, your_model is your own model definition, and we assume it contains only the parts needed for model1. Stripping the 'model1.' prefix makes the filtered keys line up with the model's own parameter names, and load_state_dict(model1_state_dict, strict=False) loads only the weights present in model1_state_dict while tolerating any keys that are still missing or unexpected. This way, the weights of model2 and model3 are effectively ignored.

Feel free to share your code or open a PR if you solve this. If you still have problems, I will probably be available to help after 25/11.

Thanks a lot. Replacing load_state_dict(model1_state_dict) with load_state_dict(model1_state_dict, strict=False) in the code did the trick.
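For reference, the same effect can also be achieved by explicitly dropping the debug-only keys before loading. Here is a minimal sketch using plain dicts as stand-ins for tensors; the "q0.*" critic key names are hypothetical, while the debug-only key names come from the error message above:

```python
# Hypothetical checkpoint contents: ints stand in for tensors
checkpoint = {
    "q0.weight": 1, "q0.bias": 2,           # critic's own parameters (hypothetical names)
    "vf.v.0.weight": 3, "vf.v.0.bias": 4,   # debug-only value function keys
    "gaussian_policy.net.0.weight": 5,      # debug-only policy keys
}

# Drop every key belonging to the debugging models
debug_prefixes = ("vf.", "gaussian_policy.")
filtered = {k: v for k, v in checkpoint.items()
            if not k.startswith(debug_prefixes)}

# With a real torch checkpoint, critic.load_state_dict(filtered) would then
# see no unexpected keys, so strict=False is no longer needed
assert all(not k.startswith(debug_prefixes) for k in filtered)
```

The explicit filter is slightly safer than strict=False alone, since strict=False also silences genuinely missing keys that you might want to know about.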