Implementation of the paper WhisperNER: Unified Open Named Entity and Speech Recognition. WhisperNER is a unified model for automatic speech recognition (ASR) and named entity recognition (NER) with zero-shot capabilities: the entity types to recognize are supplied as a text prompt at inference time. WhisperNER is designed as a strong base model for the downstream task of joint ASR and NER, and it can be fine-tuned on specific datasets for improved performance.
- Paper: WhisperNER: Unified Open Named Entity and Speech Recognition (arXiv:2409.08107).
- Demo: Check out the demo here.
- Models:
- Datasets:
  - Voxpopuli-NER-EN: A dataset for zero-shot NER evaluation based on the VoxPopuli dataset. The VoxPopuli data is released under the CC0 license, with the European Parliament's legal disclaimer (see the European Parliament's legal notice for the raw data).
Start by creating and activating a virtual environment, then install PyTorch (the command below installs the CUDA 11.8 builds):

```bash
conda create -n whisper-ner python=3.10 -y
conda activate whisper-ner
pip install torch==2.2.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118
```
Then install the package:

```bash
git clone https://github.com/aiola-lab/whisper-ner.git
cd whisper-ner
pip install -e .
```
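As an optional sanity check after installation, the sketch below compares the installed `torch` and `torchaudio` versions against the 2.2.2 pins used above. It relies only on the standard library's `importlib.metadata`; the helper name `check_pins` is hypothetical and is not part of the `whisper-ner` package.

```python
from importlib import metadata

# Versions pinned by the install commands above.
PINNED = {"torch": "2.2.2", "torchaudio": "2.2.2"}


def check_pins(pins: dict[str, str]) -> dict[str, str]:
    """Compare installed distributions against pinned versions (hypothetical helper).

    Returns a mapping from package name to "ok", "missing",
    or "mismatch (<installed version>)".
    """
    report = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        # CUDA wheels report versions such as "2.2.2+cu118"; compare the part before "+".
        public = installed.split("+", 1)[0]
        report[name] = "ok" if public == wanted else f"mismatch ({installed})"
    return report


if __name__ == "__main__":
    for pkg, status in check_pins(PINNED).items():
        print(f"{pkg}: {status}")
```

Stripping the local version label (`+cu118`) lets the same check pass for both CPU and CUDA builds of the pinned release.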
Inference can be done using the following code:
```python
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from whisper_ner.utils import audio_preprocess, prompt_preprocess

model_path = "aiola/whisper-ner-v1"
audio_file_path = "path/to/audio/file"
prompt = "person, company, location"  # comma-separated entity tags

# load model and processor from pre-trained
processor = WhisperProcessor.from_pretrained(model_path)
model = WhisperForConditionalGeneration.from_pretrained(model_path)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# convert the audio file at audio_file_path into model input features
input_features = audio_preprocess(audio_file_path, processor)
input_features = input_features.to(device)

# tokenize the entity-tag prompt
prompt_ids = prompt_preprocess(prompt, processor)
prompt_ids = prompt_ids.to(device)

# generate token ids by running the model forward sequentially
with torch.no_grad():
    predicted_ids = model.generate(
        input_features,
        prompt_ids=prompt_ids,
        generation_config=model.generation_config,
        language="en",
    )

# post-process token ids to text; skipping special tokens also removes the prompt
transcription = processor.batch_decode(
    predicted_ids, skip_special_tokens=True
)[0]
print(transcription)
```
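Because the entity types are passed as a plain comma-separated string, they can be changed on every call. The sketch below shows two hypothetical helpers that are not part of `whisper_ner.utils`: `build_entity_prompt` assembles such a string from a list, and `wav_is_whisper_ready` checks that a WAV file is mono and sampled at 16 kHz, the rate Whisper's feature extractor expects. Whether `audio_preprocess` resamples or downmixes internally is not stated here, so treat the WAV check as an optional precaution under that assumption.

```python
import wave


def build_entity_prompt(entity_types: list[str]) -> str:
    """Join entity types into a comma-separated prompt (hypothetical helper).

    Strips surrounding whitespace, drops empty entries and duplicates
    (keeping the first occurrence), and joins with ", " to match the
    format of the example prompt.
    """
    tags: list[str] = []
    for raw in entity_types:
        tag = raw.strip()
        if tag and tag not in tags:
            tags.append(tag)
    if not tags:
        raise ValueError("at least one entity type is required")
    return ", ".join(tags)


def wav_is_whisper_ready(path: str, expected_rate: int = 16_000) -> bool:
    """Return True if the WAV file at `path` is mono and sampled at `expected_rate` Hz."""
    with wave.open(path, "rb") as wav:
        return wav.getnchannels() == 1 and wav.getframerate() == expected_rate
```

For example, `build_entity_prompt(["person", " company", "location", "person"])` returns `"person, company, location"`, the same prompt used in the inference example.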
If you find our work or this code useful in your research, please consider citing the following paper:
```bibtex
@article{ayache2024whisperner,
  title={WhisperNER: Unified Open Named Entity and Speech Recognition},
  author={Ayache, Gil and Pirchi, Menachem and Navon, Aviv and Shamsian, Aviv and Hetz, Gill and Keshet, Joseph},
  journal={arXiv preprint arXiv:2409.08107},
  year={2024}
}
```