ZhengPeng7/BiRefNet

RuntimeError: Error(s) in loading state_dict for BiRefNet

SSJang opened this issue · 2 comments

File "inference.py", line 82, in main
model.load_state_dict(state_dict)
File "/home/js/AI_run/BiRef_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for BiRefNet:
Missing key(s) in state_dict: "squeeze_module.0.dec_att.global_avg_pool.2.weight", "squeeze_module.0.dec_att.global_avg_pool.2.bias", "squeeze_module.0.dec_att.global_avg_pool.2.running_mean", "squeeze_module.0.dec_att.global_avg_pool.2.running_var", "decoder.decoder_block4.dec_att.global_avg_pool.2.weight", "decoder.decoder_block4.dec_att.global_avg_pool.2.bias", "decoder.decoder_block4.dec_att.global_avg_pool.2.running_mean", "decoder.decoder_block4.dec_att.global_avg_pool.2.running_var", "decoder.decoder_block3.dec_att.global_avg_pool.2.weight", "decoder.decoder_block3.dec_att.global_avg_pool.2.bias", "decoder.decoder_block3.dec_att.global_avg_pool.2.running_mean", "decoder.decoder_block3.dec_att.global_avg_pool.2.running_var", "decoder.decoder_block2.dec_att.global_avg_pool.2.weight", "decoder.decoder_block2.dec_att.global_avg_pool.2.bias", "decoder.decoder_block2.dec_att.global_avg_pool.2.running_mean", "decoder.decoder_block2.dec_att.global_avg_pool.2.running_var", "decoder.decoder_block1.dec_att.global_avg_pool.2.weight", "decoder.decoder_block1.dec_att.global_avg_pool.2.bias", "decoder.decoder_block1.dec_att.global_avg_pool.2.running_mean", "decoder.decoder_block1.dec_att.global_avg_pool.2.running_var".
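All of the missing keys point at index 2 of a `global_avg_pool` `nn.Sequential`, i.e. the weight, bias, and running statistics of a BatchNorm2d layer. A block of roughly this shape (illustrative only, not necessarily the repo's exact module definition) produces exactly such key names:

```python
import torch.nn as nn

# Illustrative only: a pooling block of this shape yields state_dict keys such as
# "global_avg_pool.2.weight" and "global_avg_pool.2.running_mean",
# because index 2 of the Sequential is a BatchNorm2d layer.
global_avg_pool = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),                        # index 0: no parameters
    nn.Conv2d(64, 64, kernel_size=1, bias=False),   # index 1: "global_avg_pool.1.weight"
    nn.BatchNorm2d(64),                             # index 2: weight, bias, running_mean, running_var
    nn.ReLU(inplace=True),                          # index 3: no parameters
)
```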

When attempting to run inference after training, several parameters cannot be loaded. Training was done with torch 1.12.1+cu113, while inference was run with torch 1.13.1+cu116. Could this version difference be causing the issue?

If I load the model with model.load_state_dict(state_dict, strict=False) instead, will there be a significant difference in performance?
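For reference, `strict=False` skips mismatched entries and returns them instead of raising, so the returned lists show whether anything beyond the `global_avg_pool` statistics is affected. A minimal sketch, with the checkpoint path and import path as placeholders:

```python
import torch
from models.birefnet import BiRefNet  # import path assumed; adjust to your local layout

model = BiRefNet()
state_dict = torch.load('path/to/checkpoint.pth', map_location='cpu')

# strict=False loads every key that matches and returns the rest
# instead of raising a RuntimeError.
result = model.load_state_dict(state_dict, strict=False)
print('Missing keys:', result.missing_keys)        # present in the model, absent from the checkpoint
print('Unexpected keys:', result.unexpected_keys)  # present in the checkpoint, absent from the model

# If only the global_avg_pool BatchNorm entries are listed, everything else
# loads normally and those layers simply keep their default statistics.
```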

Hi, do you mean you used different torch versions for training and inference? Different versions can indeed introduce some differences in the keys of the saved weights.
I haven't tried that myself, but since the errors show that the inconsistent keys are only the xx_pool.xx ones, I guess the weights themselves might not really differ.
However, I still recommend using the same version for both if possible. More specifically, use PyTorch 2.0.1, as mentioned in the README, so you can turn on the compile operation for around a 30% speed-up in training :)
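The compile operation mentioned above is PyTorch 2.x's `torch.compile`; a minimal sketch of how it is typically enabled (the exact switch in this repo may be a config option instead):

```python
import torch
from models.birefnet import BiRefNet  # import path assumed

model = BiRefNet()

# torch.compile is available from PyTorch 2.0 onward; it compiles the model's
# forward pass for a training speed-up (around 30% as noted above).
if hasattr(torch, 'compile'):
    model = torch.compile(model)
```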

Thank you for your comment!