sarulab-speech/UTMOSv2

load weights error

splinter21 opened this issue · 3 comments

✏️ Description

File "UTMOSv2-main/utmosv2/utils/task_dependents.py", line 84, in get_model
    model.load_state_dict(torch.load(weight_path))
File "/opt/conda/envs/py39webui/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SSLMultiSpecExtModelV2:
    Missing key(s) in state_dict: "ssl.encoder.model.encoder.pos_conv_embed.conv.weight_g", "ssl.encoder.model.encoder.pos_conv_embed.conv.weight_v".
    Unexpected key(s) in state_dict: "ssl.encoder.model.encoder.pos_conv_embed.conv.parametrizations.weight.original0", "ssl.encoder.model.encoder.pos_conv_embed.conv.parametrizations.weight.original1".

Environment

Python 3.9
torch 2.0.1
transformers 4.42.4

Expected behavior

Inference runs without errors.

Steps to reproduce

python inference.py --input_path xxx.wav --out_path 123.csv

Additional notes

Thank you for your report and PR! I will check the cause and get back to you within a few days.

@splinter21
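This error occurs because torch.nn.utils.weight_norm has been deprecated in favor of torch.nn.utils.parametrizations.weight_norm. The checkpoint stores the weight-norm parameters under the new names (parametrizations.weight.original0 and parametrizations.weight.original1), while the model built under torch 2.0.1 still expects the old names (weight_g and weight_v), so load_state_dict reports both missing and unexpected keys. If you update PyTorch to 2.3.1 or higher, as specified in pyproject.toml, the error should be resolved and the code should work correctly. I also confirmed that the issue does not occur with version 2.1.1.
Or is there a specific reason you need version 2.0.1?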
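
For anyone who cannot upgrade, the key mismatch in the error above could in principle be worked around by renaming the checkpoint keys before loading. The following is only a minimal sketch, not something from this repository: remap_weight_norm_keys is a hypothetical helper, and it assumes the tensors stored under original0 and original1 have the same shapes as the legacy weight_g (magnitude) and weight_v (direction) parameters. Upgrading PyTorch remains the supported fix.

```python
from collections import OrderedDict

# Suffix mapping: parametrization-style name -> legacy weight_norm name.
# In PyTorch's parametrized weight_norm, original0 holds the magnitude (g)
# and original1 holds the direction (v).
_NEW_TO_OLD = {
    "parametrizations.weight.original0": "weight_g",
    "parametrizations.weight.original1": "weight_v",
}


def remap_weight_norm_keys(state_dict):
    """Return a copy of state_dict with new-style weight_norm keys renamed.

    Hypothetical helper for illustration; values are passed through untouched.
    """
    remapped = OrderedDict()
    for key, value in state_dict.items():
        for new_suffix, old_suffix in _NEW_TO_OLD.items():
            if key.endswith(new_suffix):
                key = key[: -len(new_suffix)] + old_suffix
                break
        remapped[key] = value
    return remapped


# Assumed usage on an older PyTorch (untested against UTMOSv2 itself):
#   state = torch.load(weight_path, map_location="cpu")
#   model.load_state_dict(remap_weight_norm_keys(state))
```

The helper only touches keys that end in one of the two parametrization suffixes, so every other entry in the checkpoint is copied through under its original name.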

I see. I'll upgrade PyTorch. Thanks!