Auto-AVSR: Lip-Reading Sentences Project

Update

2023-07-26: We released the implementation of Real-Time AV-ASR.

Introduction

This repository is an open-source framework for speech recognition, with a primary focus on visual speech recognition (lip-reading). It is designed for end-to-end training, and aims to deliver state-of-the-art models and enable reproducibility on audio-visual speech benchmarks.

By using this repository, you can achieve a word error rate (WER) of 20.3% for visual speech recognition (VSR) and 1.0% for audio speech recognition (ASR) on LRS3.
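
For readers new to the metric, WER is the word-level edit distance between the reference transcript and the recognised transcript, divided by the number of reference words. The function below is a minimal sketch of that definition only; the name word_error_rate is hypothetical, and the text normalisation applied by this repository's evaluation script is not described here.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Return (substitutions + deletions + insertions) / number of reference words."""
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")

    # prev[j] is the edit distance between the first i-1 reference words
    # and the first j hypothesis words.
    prev = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            cost = 0 if ref_word == hyp_word else 1
            curr.append(min(
                prev[j] + 1,         # deletion of a reference word
                curr[j - 1] + 1,     # insertion of a hypothesis word
                prev[j - 1] + cost,  # match or substitution
            ))
        prev = curr
    return prev[-1] / len(ref)


if __name__ == "__main__":
    # One substituted word out of six reference words.
    print(word_error_rate("the cat sat on the mat", "the cat sat on a mat"))
```

A corpus-level WER is usually computed by summing the edit distances over all utterances and dividing by the total number of reference words, rather than by averaging per-utterance rates.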

Setup

  1. Set up environment:
conda create -y -n auto_avsr python=3.8
conda activate auto_avsr
  2. Clone repository:
git clone https://github.com/mpc001/auto_avsr
cd auto_avsr
  3. Install fairseq within the repository:
git clone https://github.com/pytorch/fairseq
cd fairseq
pip install --editable ./
cd ..
  4. Install PyTorch (tested with PyTorch v2.0.1) and other packages:
pip install torch torchvision torchaudio
pip install pytorch-lightning==1.5.10
pip install sentencepiece
pip install av
pip install hydra-core --upgrade
  5. Prepare the dataset. See the instructions in the preparation folder.
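
After the installation steps, a quick check that every package can be found may save a failed first run. The script below is a sketch, not part of the repository: find_missing is a hypothetical helper, and the list holds the import names that the packages above normally expose (for example, hydra-core is imported as hydra).

```python
import importlib.util

# Import names assumed to correspond to the packages installed in the steps above.
REQUIRED_MODULES = [
    "torch",
    "torchvision",
    "torchaudio",
    "pytorch_lightning",
    "sentencepiece",
    "av",
    "hydra",
    "fairseq",
]


def find_missing(modules):
    """Return the top-level module names that cannot be located in this environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = find_missing(REQUIRED_MODULES)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages were found.")
```

Run it inside the activated auto_avsr environment; it only locates the modules and does not import them, so it is fast and does not touch the GPU.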

Training

python main.py exp_dir=[exp_dir] \
               exp_name=[exp_name] \
               data.modality=[modality] \
               data.dataset.root_dir=[root_dir] \
               data.dataset.train_file=[train_file] \
               trainer.num_nodes=[num_nodes]
Required arguments
  • exp_dir: Directory to save checkpoints and logs to.
  • exp_name: Experiment name. Location of checkpoints is [exp_dir]/[exp_name].
  • data.modality: Type of input modality, valid values: video and audio.
  • data.dataset.root_dir: Root directory of preprocessed dataset, default: null.
  • data.dataset.train_file: Filename of training label list, default: lrs3_train_transcript_lengths_seg24s.csv.
  • trainer.num_nodes: Number of machines used, default: 1.
Optional arguments
  • trainer.resume_from_checkpoint: Path of the checkpoint from which training is resumed, default: null.
  • data.dataset.val_file: Filename of validation label list, default: lrs3_test_transcript_lengths_seg24s.csv.
  • pretrained_model_path: Path to the pre-trained model, default: null.
  • transfer_frontend: Flag to load the weights of the front-end module; works with pretrained_model_path.
  • transfer_encoder: Flag to load the weights of the encoder; works with pretrained_model_path.
  • trainer.max_epochs: Number of epochs, default: 75.
  • trainer.gpus: Number of GPUs to train on per machine, default: -1, which uses all available GPUs.
  • data.max_frames: Maximal number of frames in a batch, default: 1800.
  • optimizer.lr: Learning rate, default: 0.001.
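
The data.max_frames option limits a batch by its total number of frames rather than by its number of utterances. How the repository forms its batches is not described here, so the following is only an illustrative sketch of one common approach, greedy grouping of length-sorted utterances; batch_by_frames is a hypothetical name and the budget counts unpadded frames.

```python
def batch_by_frames(lengths, max_frames=1800):
    """Greedily group utterance indices so that each batch holds at most max_frames frames."""
    batches, current, used = [], [], 0
    # Visiting utterances from shortest to longest keeps similar lengths together.
    for idx in sorted(range(len(lengths)), key=lambda i: lengths[i]):
        n_frames = lengths[idx]
        if n_frames > max_frames:
            raise ValueError(f"utterance {idx} has {n_frames} frames, above the budget")
        if used + n_frames > max_frames:
            # The current batch is full: close it and start a new one.
            batches.append(current)
            current, used = [], 0
        current.append(idx)
        used += n_frames
    if current:
        batches.append(current)
    return batches


if __name__ == "__main__":
    # Frame counts of five hypothetical utterances.
    print(batch_by_frames([100, 600, 900, 300, 1200]))
```

Assuming 25 fps video, the default budget of 1800 frames corresponds to 72 seconds of footage per batch.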
Note
  • For LRS3, start by training from scratch on a subset (23 h, maximum duration of 4 seconds) with a learning rate of 0.0002 (see model-zoo). Then fine-tune on the full set with a learning rate of 0.001. A script for subset creation is available here. To train on new datasets, please refer to instruction.
  • If you want to monitor the training process, customise the logger passed to pytorch_lightning.Trainer().
  • To maximize resource utilization, set data.max_frames to the largest value that fits into your GPU memory.
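
The two-stage LRS3 recipe in the first note can be written down as two sets of overrides for main.py. The sketch below only assembles the commands. All paths, experiment names and the subset file name are hypothetical, and initialising the second stage through pretrained_model_path from the first stage's averaged checkpoint is an assumption, not something this README states.

```python
import shlex


def build_command(script, overrides):
    """Return an argv list of the form: python <script> key=value ..."""
    return ["python", script] + [f"{key}={value}" for key, value in overrides.items()]


# Hypothetical locations; replace them with your own.
common = {
    "exp_dir": "./exp",
    "data.modality": "video",
    "data.dataset.root_dir": "/path/to/preprocessed/lrs3",
}

# Stage 1: train from scratch on the 23 h subset with the lower learning rate.
stage1 = build_command("main.py", {
    **common,
    "exp_name": "vsr_lrs3_23h",
    "data.dataset.train_file": "lrs3_train_subset_23h.csv",  # hypothetical file name
    "optimizer.lr": 0.0002,
})

# Stage 2: fine-tune on the full set with the default learning rate.
stage2 = build_command("main.py", {
    **common,
    "exp_name": "vsr_lrs3_full",
    "data.dataset.train_file": "lrs3_train_transcript_lengths_seg24s.csv",
    "optimizer.lr": 0.001,
    # Assumption: stage 2 starts from the averaged checkpoint of stage 1.
    "pretrained_model_path": "./exp/vsr_lrs3_23h/model_avg_10.pth",
})

print(shlex.join(stage1))
print(shlex.join(stage2))
```

Each list can be passed to subprocess.run(command, check=True) to launch the stage, or the printed lines can be pasted into a shell.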

Testing

python eval.py data.modality=[modality] \
               data.dataset.root_dir=[root_dir] \
               data.dataset.test_file=[test_file] \
               pretrained_model_path=[pretrained_model_path]
Required arguments
  • data.modality: Type of input modality, valid values: video and audio.
  • data.dataset.root_dir: Root directory of preprocessed dataset, default: null.
  • data.dataset.test_file: Filename of testing label list, default: lrs3_test_transcript_lengths_seg24s.csv.
  • pretrained_model_path: Path to the pre-trained model; set it to [exp_dir]/[exp_name]/model_avg_10.pth to evaluate a model trained with this repository. Default: null.
Optional arguments
  • decode.snr_target: Target signal-to-noise ratio (SNR) level, default: 999999.
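
To make the SNR option concrete, the sketch below shows the standard way to mix noise into speech at a target SNR. It assumes the value is expressed in decibels and says nothing about which noise source or mixing code the repository actually uses; power and mix_at_snr are hypothetical helpers written in plain Python.

```python
import math
import random


def power(samples):
    """Mean squared amplitude of a signal."""
    return sum(x * x for x in samples) / len(samples)


def mix_at_snr(speech, noise, snr_db):
    """Scale the noise so that 10*log10(P_speech / P_noise) equals snr_db, then add it."""
    if len(speech) != len(noise):
        raise ValueError("speech and noise must have the same length")
    # Amplitude gain that brings the noise power to P_speech / 10^(snr_db / 10).
    gain = math.sqrt(power(speech) / power(noise)) * 10 ** (-snr_db / 20)
    return [s + gain * n for s, n in zip(speech, noise)]


if __name__ == "__main__":
    random.seed(0)
    speech = [math.sin(0.01 * i) for i in range(1000)]
    noise = [random.uniform(-1.0, 1.0) for _ in range(1000)]
    mixed = mix_at_snr(speech, noise, snr_db=10)
    added = [m - s for m, s in zip(mixed, speech)]
    print(10 * math.log10(power(speech) / power(added)))
```

With a very large target such as the default 999999, the gain underflows to zero and the speech is left unchanged, which is consistent with that default meaning a clean evaluation.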

Demo

Want to see how our ASR/VSR model performs on your own audio or video? Just run this command:

python demo.py  data.modality=[modality] \
                pretrained_model_path=[pretrained_model_path] \
                file_path=[file_path]
Required arguments
  • data.modality: Type of input modality, valid values: video and audio.
  • pretrained_model_path: Path to the pre-trained model.
  • file_path: Path to the file for testing.

Model zoo

We provide models for LRS3, and plan to release more (including AV-ASR) soon.

LRS3
| Model                                  | Training data (h) | WER [%] | MD5   |
|----------------------------------------|-------------------|---------|-------|
| vsr_trlrs3_23h_base.pth                | 23                | 96.6    | 50c88 |
| vsr_trlrs3_base.pth                    | 438               | 36.7    | ea3ec |
| vsr_trlrs3vox2_base.pth                | 1759              | 25.0    | 0a126 |
| vsr_trlrwlrs2lrs3vox2avsp_base.pth     | 3448              | 20.3    | a896f |
| asr_trlrs3_23h_base.pth                | 23                | 72.5    | 87d45 |
| asr_trlrs3_base.pth                    | 438               | 2.04    | 4fa87 |
| asr_trlrs3vox2_base.pth                | 1759              | 1.07    | 7beab |
| asr_trlrwlrs2lrs3vox2avsp_base.pth     | 3448              | 0.99    | dc759 |
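
The MD5 column can be used to check a downloaded checkpoint. The sketch below assumes that the column shows the first five hexadecimal characters of the file's MD5 digest; md5_matches is a hypothetical helper and the path in the example is a placeholder.

```python
import hashlib


def md5_matches(path, expected_prefix, chunk_size=1 << 20):
    """Return True if the MD5 hex digest of the file starts with expected_prefix."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # Read in chunks so that large checkpoints do not have to fit in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest().startswith(expected_prefix.lower())


if __name__ == "__main__":
    import sys

    # Usage: python check_md5.py /path/to/vsr_trlrs3_base.pth ea3ec
    if len(sys.argv) == 3:
        ok = md5_matches(sys.argv[1], sys.argv[2])
        print("checksum OK" if ok else "checksum mismatch")
```

A mismatch usually points to an incomplete or corrupted download, so fetching the file again is the first thing to try.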

Citation

If you find this repository helpful, please consider citing our work:

@inproceedings{ma2023auto,
  author={Ma, Pingchuan and Haliassos, Alexandros and Fernandez-Lopez, Adriana and Chen, Honglie and Petridis, Stavros and Pantic, Maja},
  booktitle={IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  title={Auto-AVSR: Audio-Visual Speech Recognition with Automatic Labels},
  year={2023},
  pages={1-5},
  doi={10.1109/ICASSP49357.2023.10096889}
}

Acknowledgement

This repository is built using the espnet, fairseq, raven and avhubert repositories.

License

Code is Apache 2.0 licensed. The pre-trained models provided in this repository may have their own licenses or terms and conditions derived from the dataset used for training.

Contact

Contributions are welcome; feel free to create a PR or email me:

[Pingchuan Ma](pingchuan.ma16[at]imperial.ac.uk)