Whisper with Medusa heads

Primary language: Python. License: MIT.

Whisper Medusa

Whisper is an encoder-decoder model for speech transcription and translation. Because it is large and slow at inference, optimization strategies such as Faster-Whisper and speculative decoding have been proposed to speed it up. Our Medusa model builds on Whisper by predicting multiple tokens per iteration, which significantly improves speed with a small degradation in WER. We train and evaluate our model on the LibriSpeech dataset and observe a clear speedup with accuracy on par with the vanilla Whisper model.
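
To make "predicting multiple tokens per iteration" concrete, here is a minimal toy sketch of a draft-and-verify loop in plain Python. It is not the code in this repository: `base_next`, `propose`, and `medusa_style_decode` are hypothetical names, the "model" is any function from a token prefix to the next token, and acceptance is a simple greedy exact match. The Medusa paper describes a richer scheme (tree attention and other acceptance rules), and a real model would verify all drafted positions in one forward pass rather than in a Python loop.

```python
from typing import Callable, List, Sequence, Tuple

def accept_longest_prefix(draft: Sequence[int], verified: Sequence[int]) -> List[int]:
    """Keep drafted tokens up to the first disagreement with the base model."""
    accepted = []
    for proposed, checked in zip(draft, verified):
        if proposed != checked:
            break
        accepted.append(proposed)
    return accepted

def medusa_style_decode(
    base_next: Callable[[List[int]], int],            # hypothetical: greedy next token of the base model
    propose: Callable[[List[int], int], List[int]],   # hypothetical: guesses from the extra heads
    prompt: Sequence[int],
    max_new_tokens: int,
    num_heads: int,
) -> Tuple[List[int], int]:
    """Toy greedy decoding that may emit several tokens per iteration."""
    tokens = list(prompt)
    iterations = 0
    while len(tokens) - len(prompt) < max_new_tokens:
        iterations += 1
        draft = propose(tokens, num_heads)
        # Verification: what the base model would emit at each drafted position.
        verified, context = [], list(tokens)
        for guess in draft:
            verified.append(base_next(context))
            context.append(guess)
        tokens.extend(accept_longest_prefix(draft, verified))
        # The base model always contributes one token, so every iteration makes progress.
        tokens.append(base_next(tokens))
    return tokens[: len(prompt) + max_new_tokens], iterations
```

In this sketch the output always matches plain greedy decoding of the base function; good drafts only reduce the number of iterations. The small WER difference reported for Whisper Medusa shows that the real model does not give an exact-equivalence guarantee of this kind.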


Whisper Medusa architecture.


Training and Evaluation Details

Whisper Medusa is based on the Whisper large model with 10 Medusa heads. It was trained on the LibriSpeech dataset to perform audio transcription. The Medusa heads were optimized for English, so for the best accuracy and speedup, please use English audio only.

On average, Whisper Medusa achieves 1.5x faster generation than vanilla Whisper with on-par WER (4.2% vs. 4%, respectively).
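
For readers unfamiliar with the metric, WER (word error rate) is the word-level edit distance between the reference transcript and the model output, divided by the number of reference words. The function below is a minimal sketch of that definition; `word_error_rate` is a hypothetical helper, not the code used by the evaluation script, which may normalize text (casing, punctuation) before scoring.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level Levenshtein distance divided by the reference length."""
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j] holds the edit distance between the processed reference prefix and hyp[:j].
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = prev[j - 1] + (0 if ref_word == hyp_word else 1)
            deletion = prev[j] + 1       # reference word missing from the hypothesis
            insertion = curr[j - 1] + 1  # extra word in the hypothesis
            curr.append(min(substitution, deletion, insertion))
        prev = curr
    return prev[-1] / len(ref)
```

For example, a hypothesis that drops one word from a six-word reference has a WER of 1/6.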


Whisper Medusa speedup compared to Whisper vanilla.
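
A speedup figure like the one above is the ratio of average generation latencies. The helper below is a rough sketch of how such a ratio could be measured for any two callables; `mean_latency` and `speedup` are hypothetical names, and this is not the benchmarking code behind the reported numbers. When timing GPU models, the callables should synchronize the device (for example with `torch.cuda.synchronize()`) before returning, otherwise the measured times can be misleading.

```python
import time
from statistics import mean
from typing import Any, Callable, Sequence

def mean_latency(fn: Callable[[Any], Any], inputs: Sequence[Any], warmup: int = 1) -> float:
    """Average wall-clock seconds per call of fn over inputs."""
    if not inputs:
        raise ValueError("inputs must not be empty")
    for item in inputs[:warmup]:
        fn(item)  # warm-up calls are not timed
    timings = []
    for item in inputs:
        start = time.perf_counter()
        fn(item)
        timings.append(time.perf_counter() - start)
    return mean(timings)

def speedup(baseline_fn: Callable[[Any], Any], candidate_fn: Callable[[Any], Any],
            inputs: Sequence[Any]) -> float:
    """Ratio of baseline latency to candidate latency; values above 1 mean the candidate is faster."""
    return mean_latency(baseline_fn, inputs) / mean_latency(candidate_fn, inputs)
```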


Installation

Start by creating and activating a virtual environment:

conda create -n whisper-medusa python=3.11 -y
conda activate whisper-medusa
pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118

Then install the package:

git clone https://github.com/aiola-lab/whisper-medusa.git
cd whisper-medusa
pip install -e .

Usage

Inference can be done using the following code:

import torch
import torchaudio

from whisper_medusa import WhisperMedusaModel
from transformers import WhisperProcessor

model_name = "aiola/whisper-medusa-v1"
model = WhisperMedusaModel.from_pretrained(model_name)
processor = WhisperProcessor.from_pretrained(model_name)

path_to_audio = "path/to/audio.wav"
SAMPLING_RATE = 16000
language = "en"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

input_speech, sr = torchaudio.load(path_to_audio)
if input_speech.shape[0] > 1:  # If stereo, average the channels
    input_speech = input_speech.mean(dim=0, keepdim=True)

if sr != SAMPLING_RATE:
    input_speech = torchaudio.transforms.Resample(sr, SAMPLING_RATE)(input_speech)

input_features = processor(input_speech.squeeze(), return_tensors="pt", sampling_rate=SAMPLING_RATE).input_features
input_features = input_features.to(device)

model = model.to(device)
model_output = model.generate(
    input_features,
    language=language,
)
predict_ids = model_output[0]
pred = processor.decode(predict_ids, skip_special_tokens=True)
print(pred)

Model evaluation

The evaluation script expects a CSV file with the following columns:

  • audio: path to the audio file.
  • sentence: the corresponding transcript.
  • language: the language of the audio file.
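
A file with these three columns can be produced with Python's standard `csv` module. The sketch below assumes a header row with exactly the column names listed above; `write_manifest` is a hypothetical helper, and the audio paths and transcripts in the usage note are placeholders.

```python
import csv
from typing import Dict, Iterable

COLUMNS = ["audio", "sentence", "language"]

def write_manifest(rows: Iterable[Dict[str, str]], csv_path: str) -> None:
    """Write evaluation rows to a CSV file with audio, sentence, and language columns."""
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=COLUMNS)
        writer.writeheader()
        writer.writerows(rows)
```

For instance, `write_manifest([{"audio": "samples/utt1.wav", "sentence": "hello world", "language": "en"}], "eval.csv")` writes a header line followed by one data row.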

Then run the following command:

python whisper_medusa/eval_whisper_medusa.py \
--model-name /path/to/model \
--data-path /path/to/data \
--out-file-path /path/to/output \
--language en

Argument descriptions:

  • model-name: path to a local model or the name of a model on the Hugging Face Hub.
  • data-path: path to the data.
  • out-file-path: path to the output file.
  • language: default language fallback.

Citations

@article{cai2024medusa,
  title={Medusa: Simple {LLM} inference acceleration framework with multiple decoding heads},
  author={Cai, Tianle and Li, Yuhong and Geng, Zhengyang and Peng, Hongwu and Lee, Jason D and Chen, Deming and Dao, Tri},
  journal={arXiv preprint arXiv:2401.10774},
  year={2024}
}