Word Error Rate increasing after training on whisper-large-v3
Swami-Abhinav opened this issue · 0 comments
Swami-Abhinav commented
I trained the whisper-large-v3 model using your pipeline, and the word error rate came out above 100%:
{'eval/wer': 100.71108101244393, 'eval/normalized_wer': 100.2449409394744}
When I commented out fp16=True, it got even worse:
wer=333.30652670786424 and normalized_wer=190.30308579800385
I used the Common Voice 11 dataset for the Hindi language. Here is my code and trainer config:
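For context, a WER above 100% is possible when the model emits many extra words: WER counts substitutions, deletions, and insertions against the reference length, so insertions alone can push it past 1.0. A minimal pure-Python sketch (an illustration only, not the metric implementation used by the pipeline):

```python
def wer(reference: str, hypothesis: str) -> float:
    """Word error rate = (substitutions + deletions + insertions) / reference words."""
    ref, hyp = reference.split(), hypothesis.split()
    # Word-level Levenshtein distance via dynamic programming.
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            sub = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, d[i - 1][j - 1] + sub)
    return d[-1][-1] / len(ref)

print(wer("the cat", "the cat sat on the mat"))  # 4 insertions / 2 reference words -> 2.0
```

This is why a model that hallucinates long transcripts (a common failure mode with a too-high learning rate or fp16 overflow) can report WER values like 333%.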
import os

from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments,
)
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

training_args = Seq2SeqTrainingArguments(
    output_dir="Abhinav28/large-v3-hi-commonvoice-11-peft-trained-adapter-withfp16",  # change to a repo name of your choice
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-3,
    warmup_steps=50,
    num_train_epochs=1,
    evaluation_strategy="steps",
    fp16=True,
    per_device_eval_batch_size=8,
    generation_max_length=128,
    # metric_for_best_model="wer",
    logging_steps=819,
    # max_steps=10,  # only for testing purposes, remove this from your final run :)
    # greater_is_better=False,
    remove_unused_columns=False,  # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
    label_names=["labels"],  # same reason as above
)
class SavePeftModelCallback(TrainerCallback):
    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        print(peft_model_path)
        kwargs["model"].save_pretrained(peft_model_path)
        # pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        # if os.path.exists(pytorch_model_path):
        #     os.remove(pytorch_model_path)
        # return control
trainer = Seq2SeqTrainer(
    args=training_args,
    model=peft_model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    # compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
    callbacks=[SavePeftModelCallback],
)
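Since `compute_metrics` is commented out in the trainer above, for reference a corpus-level WER metric function typically has this shape. This is a hedged sketch only: plain strings stand in for the decoded sequences that, in the real run, would come from `processor.tokenizer.batch_decode` on `pred.predictions` and `pred.label_ids` (after replacing -100 with the pad token id), usually scored with `evaluate.load("wer")` rather than a hand-rolled distance:

```python
def edit_distance(ref_words, hyp_words):
    # Word-level Levenshtein distance using a rolling DP row.
    prev = list(range(len(hyp_words) + 1))
    for i, r in enumerate(ref_words, 1):
        cur = [i] + [0] * len(hyp_words)
        for j, h in enumerate(hyp_words, 1):
            cur[j] = min(prev[j] + 1, cur[j - 1] + 1, prev[j - 1] + (r != h))
        prev = cur
    return prev[-1]

def compute_metrics_sketch(pred_str, label_str):
    # Corpus-level WER: total word errors over total reference words,
    # scaled to a percentage like the eval/wer value quoted above.
    errors = sum(edit_distance(l.split(), p.split()) for p, l in zip(pred_str, label_str))
    words = sum(len(l.split()) for l in label_str)
    return {"wer": 100 * errors / words}
```

Note that `predict_with_generate=True` is needed in the training arguments for generation-based metrics like this to see real transcripts at eval time.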
How do I get a better WER?