yeyupiaoling/Whisper-Finetune

Training throws an exception

Closed this issue · 7 comments

The following exception occurs when fine-tuning with the code from this project. I'm not sure what the cause is; I'm using an audiofolder dataset:

File "/usr/local/lib/python3.10/dist-packages/peft/peft_model.py", line 442, in forward
    return self.get_base_model()(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_whisper.py", line 1486, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_whisper.py", line 1346, in forward
    encoder_outputs = self.encoder(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_whisper.py", line 896, in forward
    inputs_embeds = nn.functional.gelu(self.conv1(input_features))
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 313, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 309, in _conv_forward
    return F.conv1d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [1280, 80, 3], expected input[1, 8, 3000] to have 80 channels, but got 8 channels instead

Did you change something? The channel count is wrong now.

from dataclasses import dataclass
from typing import Any

from datasets import Audio, load_dataset


@dataclass
class AudioFolderDataset:
    processor: Any

    def _prepare_dataset(self, batch):
        # load and resample audio data from 48 to 16kHz
        audio = batch["audio"]

        # compute log-Mel input features from input audio array
        batch["input_features"] = \
            self.processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

        # encode target text to label ids
        batch["labels"] = self.processor.tokenizer(batch["sentence"]).input_ids
        return batch

    def load(self):
        common_voice = load_dataset("audiofolder", data_dir="dataset")
        common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
        common_voice = common_voice.map(self._prepare_dataset, num_proc=1)
        return common_voice['train'], common_voice['test']

train_dataset, test_dataset = AudioFolderDataset(processor=processor).load()
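For reference, a minimal sketch (using numpy stand-ins rather than real log-mel features, since the real extractor needs model files) of the shapes this preprocessing produces: WhisperFeatureExtractor returns a batched array, and taking `.input_features[0]` leaves one unbatched spectrogram per example.

```python
import numpy as np

# WhisperFeatureExtractor output for one clip: shape (1, 80, 3000)
# (batch of 1, 80 mel bins, 3000 frames for 30 s of padded audio).
extractor_output = np.zeros((1, 80, 3000))

# _prepare_dataset stores .input_features[0], so each dataset example
# holds a single (80, 3000) spectrogram with no batch dimension.
per_example = extractor_output[0]
print(per_example.shape)  # (80, 3000)
```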

The error occurs when using the dataset above.

If I set per_device_train_batch_size to 80, the error above goes away, but a different one appears:

File "/opt/conda/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py", line 1486, in forward
   outputs = self.model(
 File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
   return forward_call(*args, **kwargs)
 File "/opt/conda/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py", line 1346, in forward
   encoder_outputs = self.encoder(
 File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
   return forward_call(*args, **kwargs)
 File "/opt/conda/lib/python3.10/site-packages/transformers/models/whisper/modeling_whisper.py", line 899, in forward
   inputs_embeds = inputs_embeds.permute(0, 2, 1)
RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 2 is not equal to len(dims) = 3

The dataset follows the standard approach; I have been using it for full-parameter fine-tuning all along. LoRA fine-tuning with the following method also works:
https://github.com/Vaibhavs10/fast-whisper-finetuning

I'm not sure about your setup. If you want to fine-tune with my project, please read the documentation and generate the data format the project requires.

@ILG2021 You need to change the data preprocessing code. In utils/data_utils.py:47, change:
input_features = [{"input_features": feature["input_features"][0]} for feature in features]
to:
input_features = [{"input_features": feature["input_features"]} for feature in features]
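A small sketch of why the extra `[0]` breaks here (numpy arrays as hypothetical stand-ins for the stored features): each example is already an unbatched (80, 3000) spectrogram, so indexing `[0]` again drops the mel dimension, and after stacking, the batch axis is what Conv1d sees as "channels" — matching the `expected input[1, 8, 3000] to have 80 channels, but got 8` error.

```python
import numpy as np

# Each stored example is already an unbatched (80, 3000) spectrogram.
features = [{"input_features": np.zeros((80, 3000))} for _ in range(8)]

# Buggy collator line: the extra [0] yields one (3000,) row per example,
# so the stacked batch is (8, 3000) -- Conv1d reads 8 channels, not 80.
buggy = np.stack([f["input_features"][0] for f in features])

# Fixed collator line keeps the full spectrogram per example.
fixed = np.stack([f["input_features"] for f in features])

print(buggy.shape, fixed.shape)  # (8, 3000) (8, 80, 3000)
```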