Error message after using a fine-tuned ASR model
Opened this issue · 4 comments
GUUser91 commented
I get this error message after using a fine-tuned ASR model:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[14], line 5
4 try:
----> 5 model[key].load_state_dict(params[key])
6 except:
File StyleTTS2/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:2189, in Module.load_state_dict(self, state_dict, strict, assign)
2188 if len(error_msgs) > 0:
-> 2189 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
2190 self.__class__.__name__, "\n\t".join(error_msgs)))
2191 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for ASRCNN:
Missing key(s) in state_dict: "to_mfcc.dct_mat", "init_cnn.conv.weight", "init_cnn.conv.bias", "cnns.0.0.blocks.0.0.conv.weight", "cnns.0.0.blocks.0.0.conv.bias", "cnns.0.0.blocks.0.2.weight", "cnns.0.0.blocks.0.2.bias", "cnns.0.0.blocks.0.4.conv.weight", "cnns.0.0.blocks.0.4.conv.bias", "cnns.0.0.blocks.1.0.conv.weight", "cnns.0.0.blocks.1.0.conv.bias", "cnns.0.0.blocks.1.2.weight", "cnns.0.0.blocks.1.2.bias", "cnns.0.0.blocks.1.4.conv.weight", "cnns.0.0.blocks.1.4.conv.bias", "cnns.0.0.blocks.2.0.conv.weight", "cnns.0.0.blocks.2.0.conv.bias", "cnns.0.0.blocks.2.2.weight", "cnns.0.0.blocks.2.2.bias", "cnns.0.0.blocks.2.4.conv.weight", "cnns.0.0.blocks.2.4.conv.bias", "cnns.0.1.weight", "cnns.0.1.bias", "cnns.1.0.blocks.0.0.conv.weight", "cnns.1.0.blocks.0.0.conv.bias", "cnns.1.0.blocks.0.2.weight", "cnns.1.0.blocks.0.2.bias", "cnns.1.0.blocks.0.4.conv.weight", "cnns.1.0.blocks.0.4.conv.bias", "cnns.1.0.blocks.1.0.conv.weight", "cnns.1.0.blocks.1.0.conv.bias", "cnns.1.0.blocks.1.2.weight", "cnns.1.0.blocks.1.2.bias", "cnns.1.0.blocks.1.4.conv.weight", "cnns.1.0.blocks.1.4.conv.bias", "cnns.1.0.blocks.2.0.conv.weight", "cnns.1.0.blocks.2.0.conv.bias", "cnns.1.0.blocks.2.2.weight", "cnns.1.0.blocks.2.2.bias", "cnns.1.0.blocks.2.4.conv.weight", "cnns.1.0.blocks.2.4.conv.bias", "cnns.1.1.weight", "cnns.1.1.bias", "cnns.2.0.blocks.0.0.conv.weight", "cnns.2.0.blocks.0.0.conv.bias", "cnns.2.0.blocks.0.2.weight", "cnns.2.0.blocks.0.2.bias", "cnns.2.0.blocks.0.4.conv.weight", "cnns.2.0.blocks.0.4.conv.bias", "cnns.2.0.blocks.1.0.conv.weight", "cnns.2.0.blocks.1.0.conv.bias", "cnns.2.0.blocks.1.2.weight", "cnns.2.0.blocks.1.2.bias", "cnns.2.0.blocks.1.4.conv.weight", "cnns.2.0.blocks.1.4.conv.bias", "cnns.2.0.blocks.2.0.conv.weight", "cnns.2.0.blocks.2.0.conv.bias", "cnns.2.0.blocks.2.2.weight", "cnns.2.0.blocks.2.2.bias", "cnns.2.0.blocks.2.4.conv.weight", "cnns.2.0.blocks.2.4.conv.bias", "cnns.2.1.weight", "cnns.2.1.bias", "cnns.3.0.blocks.0.0.conv.weight", "cnns.3.0.blocks.0.0.conv.bias", "cnns.3.0.blocks.0.2.weight", "cnns.3.0.blocks.0.2.bias", "cnns.3.0.blocks.0.4.conv.weight", "cnns.3.0.blocks.0.4.conv.bias", "cnns.3.0.blocks.1.0.conv.weight", "cnns.3.0.blocks.1.0.conv.bias", "cnns.3.0.blocks.1.2.weight", "cnns.3.0.blocks.1.2.bias", "cnns.3.0.blocks.1.4.conv.weight", "cnns.3.0.blocks.1.4.conv.bias", "cnns.3.0.blocks.2.0.conv.weight", "cnns.3.0.blocks.2.0.conv.bias", "cnns.3.0.blocks.2.2.weight", "cnns.3.0.blocks.2.2.bias", "cnns.3.0.blocks.2.4.conv.weight", "cnns.3.0.blocks.2.4.conv.bias", "cnns.3.1.weight", "cnns.3.1.bias", "cnns.4.0.blocks.0.0.conv.weight", "cnns.4.0.blocks.0.0.conv.bias", "cnns.4.0.blocks.0.2.weight", "cnns.4.0.blocks.0.2.bias", "cnns.4.0.blocks.0.4.conv.weight", "cnns.4.0.blocks.0.4.conv.bias", "cnns.4.0.blocks.1.0.conv.weight", "cnns.4.0.blocks.1.0.conv.bias", "cnns.4.0.blocks.1.2.weight", "cnns.4.0.blocks.1.2.bias", "cnns.4.0.blocks.1.4.conv.weight", "cnns.4.0.blocks.1.4.conv.bias", "cnns.4.0.blocks.2.0.conv.weight", "cnns.4.0.blocks.2.0.conv.bias", "cnns.4.0.blocks.2.2.weight", "cnns.4.0.blocks.2.2.bias", "cnns.4.0.blocks.2.4.conv.weight", "cnns.4.0.blocks.2.4.conv.bias", "cnns.4.1.weight", "cnns.4.1.bias", "cnns.5.0.blocks.0.0.conv.weight", "cnns.5.0.blocks.0.0.conv.bias", "cnns.5.0.blocks.0.2.weight", "cnns.5.0.blocks.0.2.bias", "cnns.5.0.blocks.0.4.conv.weight", "cnns.5.0.blocks.0.4.conv.bias", "cnns.5.0.blocks.1.0.conv.weight", "cnns.5.0.blocks.1.0.conv.bias", "cnns.5.0.blocks.1.2.weight", "cnns.5.0.blocks.1.2.bias", "cnns.5.0.blocks.1.4.conv.weight", 
"cnns.5.0.blocks.1.4.conv.bias", "cnns.5.0.blocks.2.0.conv.weight", "cnns.5.0.blocks.2.0.conv.bias", "cnns.5.0.blocks.2.2.weight", "cnns.5.0.blocks.2.2.bias", "cnns.5.0.blocks.2.4.conv.weight", "cnns.5.0.blocks.2.4.conv.bias", "cnns.5.1.weight", "cnns.5.1.bias", "projection.conv.weight", "projection.conv.bias", "ctc_linear.0.linear_layer.weight", "ctc_linear.0.linear_layer.bias", "ctc_linear.2.linear_layer.weight", "ctc_linear.2.linear_layer.bias", "asr_s2s.embedding.weight", "asr_s2s.project_to_n_symbols.weight", "asr_s2s.project_to_n_symbols.bias", "asr_s2s.attention_layer.query_layer.linear_layer.weight", "asr_s2s.attention_layer.memory_layer.linear_layer.weight", "asr_s2s.attention_layer.v.linear_layer.weight", "asr_s2s.attention_layer.location_layer.location_conv.conv.weight", "asr_s2s.attention_layer.location_layer.location_dense.linear_layer.weight", "asr_s2s.decoder_rnn.weight_ih", "asr_s2s.decoder_rnn.weight_hh", "asr_s2s.decoder_rnn.bias_ih", "asr_s2s.decoder_rnn.bias_hh", "asr_s2s.project_to_hidden.0.linear_layer.weight", "asr_s2s.project_to_hidden.0.linear_layer.bias".
Unexpected key(s) in state_dict: "module.to_mfcc.dct_mat", "module.init_cnn.conv.weight", "module.init_cnn.conv.bias", "module.cnns.0.0.blocks.0.0.conv.weight", "module.cnns.0.0.blocks.0.0.conv.bias", "module.cnns.0.0.blocks.0.2.weight", "module.cnns.0.0.blocks.0.2.bias", "module.cnns.0.0.blocks.0.4.conv.weight", "module.cnns.0.0.blocks.0.4.conv.bias", "module.cnns.0.0.blocks.1.0.conv.weight", "module.cnns.0.0.blocks.1.0.conv.bias", "module.cnns.0.0.blocks.1.2.weight", "module.cnns.0.0.blocks.1.2.bias", "module.cnns.0.0.blocks.1.4.conv.weight", "module.cnns.0.0.blocks.1.4.conv.bias", "module.cnns.0.0.blocks.2.0.conv.weight", "module.cnns.0.0.blocks.2.0.conv.bias", "module.cnns.0.0.blocks.2.2.weight", "module.cnns.0.0.blocks.2.2.bias", "module.cnns.0.0.blocks.2.4.conv.weight", "module.cnns.0.0.blocks.2.4.conv.bias", "module.cnns.0.1.weight", "module.cnns.0.1.bias", "module.cnns.1.0.blocks.0.0.conv.weight", "module.cnns.1.0.blocks.0.0.conv.bias", "module.cnns.1.0.blocks.0.2.weight", "module.cnns.1.0.blocks.0.2.bias", "module.cnns.1.0.blocks.0.4.conv.weight", "module.cnns.1.0.blocks.0.4.conv.bias", "module.cnns.1.0.blocks.1.0.conv.weight", "module.cnns.1.0.blocks.1.0.conv.bias", "module.cnns.1.0.blocks.1.2.weight", "module.cnns.1.0.blocks.1.2.bias", "module.cnns.1.0.blocks.1.4.conv.weight", "module.cnns.1.0.blocks.1.4.conv.bias", "module.cnns.1.0.blocks.2.0.conv.weight", "module.cnns.1.0.blocks.2.0.conv.bias", "module.cnns.1.0.blocks.2.2.weight", "module.cnns.1.0.blocks.2.2.bias", "module.cnns.1.0.blocks.2.4.conv.weight", "module.cnns.1.0.blocks.2.4.conv.bias", "module.cnns.1.1.weight", "module.cnns.1.1.bias", "module.cnns.2.0.blocks.0.0.conv.weight", "module.cnns.2.0.blocks.0.0.conv.bias", "module.cnns.2.0.blocks.0.2.weight", "module.cnns.2.0.blocks.0.2.bias", "module.cnns.2.0.blocks.0.4.conv.weight", "module.cnns.2.0.blocks.0.4.conv.bias", "module.cnns.2.0.blocks.1.0.conv.weight", "module.cnns.2.0.blocks.1.0.conv.bias", "module.cnns.2.0.blocks.1.2.weight", "module.cnns.2.0.blocks.1.2.bias", "module.cnns.2.0.blocks.1.4.conv.weight", "module.cnns.2.0.blocks.1.4.conv.bias", "module.cnns.2.0.blocks.2.0.conv.weight", "module.cnns.2.0.blocks.2.0.conv.bias", "module.cnns.2.0.blocks.2.2.weight", "module.cnns.2.0.blocks.2.2.bias", "module.cnns.2.0.blocks.2.4.conv.weight", "module.cnns.2.0.blocks.2.4.conv.bias", "module.cnns.2.1.weight", "module.cnns.2.1.bias", "module.cnns.3.0.blocks.0.0.conv.weight", "module.cnns.3.0.blocks.0.0.conv.bias", "module.cnns.3.0.blocks.0.2.weight", "module.cnns.3.0.blocks.0.2.bias", "module.cnns.3.0.blocks.0.4.conv.weight", "module.cnns.3.0.blocks.0.4.conv.bias", "module.cnns.3.0.blocks.1.0.conv.weight", "module.cnns.3.0.blocks.1.0.conv.bias", "module.cnns.3.0.blocks.1.2.weight", "module.cnns.3.0.blocks.1.2.bias", "module.cnns.3.0.blocks.1.4.conv.weight", "module.cnns.3.0.blocks.1.4.conv.bias", "module.cnns.3.0.blocks.2.0.conv.weight", "module.cnns.3.0.blocks.2.0.conv.bias", "module.cnns.3.0.blocks.2.2.weight", "module.cnns.3.0.blocks.2.2.bias", "module.cnns.3.0.blocks.2.4.conv.weight", "module.cnns.3.0.blocks.2.4.conv.bias", "module.cnns.3.1.weight", "module.cnns.3.1.bias", "module.cnns.4.0.blocks.0.0.conv.weight", "module.cnns.4.0.blocks.0.0.conv.bias", "module.cnns.4.0.blocks.0.2.weight", "module.cnns.4.0.blocks.0.2.bias", "module.cnns.4.0.blocks.0.4.conv.weight", "module.cnns.4.0.blocks.0.4.conv.bias", "module.cnns.4.0.blocks.1.0.conv.weight", "module.cnns.4.0.blocks.1.0.conv.bias", "module.cnns.4.0.blocks.1.2.weight", "module.cnns.4.0.blocks.1.2.bias", 
"module.cnns.4.0.blocks.1.4.conv.weight", "module.cnns.4.0.blocks.1.4.conv.bias", "module.cnns.4.0.blocks.2.0.conv.weight", "module.cnns.4.0.blocks.2.0.conv.bias", "module.cnns.4.0.blocks.2.2.weight", "module.cnns.4.0.blocks.2.2.bias", "module.cnns.4.0.blocks.2.4.conv.weight", "module.cnns.4.0.blocks.2.4.conv.bias", "module.cnns.4.1.weight", "module.cnns.4.1.bias", "module.cnns.5.0.blocks.0.0.conv.weight", "module.cnns.5.0.blocks.0.0.conv.bias", "module.cnns.5.0.blocks.0.2.weight", "module.cnns.5.0.blocks.0.2.bias", "module.cnns.5.0.blocks.0.4.conv.weight", "module.cnns.5.0.blocks.0.4.conv.bias", "module.cnns.5.0.blocks.1.0.conv.weight", "module.cnns.5.0.blocks.1.0.conv.bias", "module.cnns.5.0.blocks.1.2.weight", "module.cnns.5.0.blocks.1.2.bias", "module.cnns.5.0.blocks.1.4.conv.weight", "module.cnns.5.0.blocks.1.4.conv.bias", "module.cnns.5.0.blocks.2.0.conv.weight", "module.cnns.5.0.blocks.2.0.conv.bias", "module.cnns.5.0.blocks.2.2.weight", "module.cnns.5.0.blocks.2.2.bias", "module.cnns.5.0.blocks.2.4.conv.weight", "module.cnns.5.0.blocks.2.4.conv.bias", "module.cnns.5.1.weight", "module.cnns.5.1.bias", "module.projection.conv.weight", "module.projection.conv.bias", "module.ctc_linear.0.linear_layer.weight", "module.ctc_linear.0.linear_layer.bias", "module.ctc_linear.2.linear_layer.weight", "module.ctc_linear.2.linear_layer.bias", "module.asr_s2s.embedding.weight", "module.asr_s2s.project_to_n_symbols.weight", "module.asr_s2s.project_to_n_symbols.bias", "module.asr_s2s.attention_layer.query_layer.linear_layer.weight", "module.asr_s2s.attention_layer.memory_layer.linear_layer.weight", "module.asr_s2s.attention_layer.v.linear_layer.weight", "module.asr_s2s.attention_layer.location_layer.location_conv.conv.weight", "module.asr_s2s.attention_layer.location_layer.location_dense.linear_layer.weight", "module.asr_s2s.decoder_rnn.weight_ih", "module.asr_s2s.decoder_rnn.weight_hh", "module.asr_s2s.decoder_rnn.bias_ih", "module.asr_s2s.decoder_rnn.bias_hh", "module.asr_s2s.project_to_hidden.0.linear_layer.weight", "module.asr_s2s.project_to_hidden.0.linear_layer.bias".
During handling of the above exception, another exception occurred:
RuntimeError Traceback (most recent call last)
Cell In[14], line 14
12 new_state_dict[name] = v
13 # load params
---> 14 model[key].load_state_dict(new_state_dict, strict=False)
15 # except:
16 # _load(params[key], model[key])
17 _ = [model[key].eval() for key in model]
File StyleTTS2/venv/lib/python3.11/site-packages/torch/nn/modules/module.py:2189, in Module.load_state_dict(self, state_dict, strict, assign)
2184 error_msgs.insert(
2185 0, 'Missing key(s) in state_dict: {}. '.format(
2186 ', '.join(f'"{k}"' for k in missing_keys)))
2188 if len(error_msgs) > 0:
-> 2189 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
2190 self.__class__.__name__, "\n\t".join(error_msgs)))
2191 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for ASRCNN:
size mismatch for ctc_linear.2.linear_layer.weight: copying a param with shape torch.Size([178, 256]) from checkpoint, the shape in current model is torch.Size([80, 256]).
size mismatch for ctc_linear.2.linear_layer.bias: copying a param with shape torch.Size([178]) from checkpoint, the shape in current model is torch.Size([80]).
size mismatch for asr_s2s.embedding.weight: copying a param with shape torch.Size([178, 512]) from checkpoint, the shape in current model is torch.Size([80, 256]).
size mismatch for asr_s2s.project_to_n_symbols.weight: copying a param with shape torch.Size([178, 128]) from checkpoint, the shape in current model is torch.Size([80, 128]).
size mismatch for asr_s2s.project_to_n_symbols.bias: copying a param with shape torch.Size([178]) from checkpoint, the shape in current model is torch.Size([80]).
size mismatch for asr_s2s.decoder_rnn.weight_ih: copying a param with shape torch.Size([512, 640]) from checkpoint, the shape in current model is torch.Size([512, 384]).
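Reading the traceback, there seem to be two separate problems. The `module.` prefix on every unexpected key means the checkpoint was saved from a model wrapped in `nn.DataParallel`; stripping the prefix (as the second cell already does) fixes the key names. The remaining size mismatches (178 vs. 80) mean the ASRCNN in memory was built with a different `n_token` than the checkpoint was fine-tuned with, so renaming keys can't help and the configs have to agree. A minimal sketch of both checks, with hypothetical checkpoint and variable names:

```python
import torch

# Hypothetical path -- substitute your own fine-tuned checkpoint.
ckpt = torch.load("my_finetuned_asr.pth", map_location="cpu")
state_dict = ckpt.get("model", ckpt)  # some checkpoints nest weights under "model"

# Strip the "module." prefix that nn.DataParallel adds when saving.
state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}

# Report shape mismatches up front instead of letting load_state_dict raise:
# any hit here means the model was constructed with different config values
# (e.g. n_token) than the checkpoint was trained with.
model_sd = asr_model.state_dict()  # asr_model: your ASRCNN instance
for k, v in state_dict.items():
    if k in model_sd and model_sd[k].shape != v.shape:
        print(f"{k}: checkpoint {tuple(v.shape)} vs model {tuple(model_sd[k].shape)}")

asr_model.load_state_dict(state_dict)  # succeeds once the configs match
```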
MARafey commented
@GUUser91 I was able to solve the issue. It turned out that the parameters in my config file for fine-tuning didn't match the ones in the Utils folder.
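For anyone hitting the same thing, a quick way to catch this is to diff the fine-tuning config against the one shipped in Utils/ASR. A rough sketch, assuming both files keep the model settings under a `model_params` key (as AuxiliaryASR's config.yml does) and that the file names match your layout:

```python
import yaml

with open("Utils/ASR/config.yml") as f:    # config the pretrained ASR expects
    pre = yaml.safe_load(f).get("model_params", {})
with open("my_finetune_config.yml") as f:  # hypothetical fine-tuning config
    fin = yaml.safe_load(f).get("model_params", {})

# Print every model parameter that differs between the two configs;
# a differing n_token is exactly what produces the 178-vs-80 size mismatch.
for key in sorted(set(pre) | set(fin)):
    if pre.get(key) != fin.get(key):
        print(f"{key}: pretrained={pre.get(key)!r} finetune={fin.get(key)!r}")
```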
DrBrule commented
I ran into the same issue. Rather than training an ASR from scratch, I just fine-tuned from the checkpoint in the Utils folder and made sure to use the same configuration as the pre-trained networks. Note that there's a mismatch between the config in the AuxiliaryASR repo and the one in the StyleTTS2/Utils/ASR folder.
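Concretely, that means building the ASRCNN from the same config file the checkpoint was fine-tuned against. A sketch of the idea (the import path follows the StyleTTS2 repo layout; the checkpoint name is a placeholder):

```python
import yaml
import torch
from Utils.ASR.models import ASRCNN  # as laid out in the StyleTTS2 repo

# Construct the text aligner from the checkpoint's own training config so
# n_token, hidden_dim, etc. line up with the saved weights.
with open("Utils/ASR/config.yml") as f:
    model_params = yaml.safe_load(f).get("model_params", {})
text_aligner = ASRCNN(**model_params)

state = torch.load("my_finetuned_asr.pth", map_location="cpu")  # placeholder name
state = state.get("model", state)
text_aligner.load_state_dict({k.removeprefix("module."): v for k, v in state.items()})
```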