Multi-teacher Knowledge Distillation for CTC/Att models

This is the implementation of multi-teacher distillation methods for joint CTC-attention end-to-end ASR systems. The proposed approaches integrate the error rate metric into the teacher selection rather than relying solely on the observed losses. This way, we directly distill and optimize the student toward the metric that matters for speech recognition. For details, please refer to: https://arxiv.org/abs/2005.09310.

Installation

  1. Please install the newest version of SpeechBrain.
  2. Make sure h5py is installed. Otherwise, run: pip install h5py.
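
A quick way to confirm that both dependencies are available is the short check below (a minimal sketch, assuming a standard pip installation; the version prints are only for convenience):

    # Sanity check: both packages must import without errors.
    import speechbrain
    import h5py

    print("SpeechBrain:", speechbrain.__version__)
    print("h5py:", h5py.__version__)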

Training steps

To speed up student distillation from multiple teachers, we split the procedure into two parts: running inference with the pre-trained teacher models, then distilling the student.

  1. Run inference on all teacher models

    This part runs inference with all teacher models and stores their predictions on disk using save_teachers.py. You only need to set the tea_models_dir variable to the path of a txt file listing the paths to each teacher model.ckpt. We decided to work with a file so the setup easily scales to hundreds of teachers; a sketch for generating it programmatically follows the example command below. An example of this file is:

    results/tea0/save/model.ckpt
    results/tea1/save/model.ckpt
    results/tea2/save/model.ckpt
    results/tea3/save/model.ckpt
    results/tea4/save/model.ckpt
    results/tea5/save/model.ckpt
    results/tea6/save/model.ckpt
    results/tea7/save/model.ckpt
    results/tea8/save/model.ckpt
    results/tea9/save/model.ckpt
    

    Example:

    python save_teachers.py hparams/save_teachers.yaml --data_folder /path-to/data_folder --tea_models_dir /path-to/tea_model_paths.txt
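
    If your teachers follow the directory layout shown above, the txt file can also be generated programmatically. This is only a minimal sketch; the glob pattern and output file name are illustrative and should be adapted to your own setup:

    # Sketch: build the teacher checkpoint list expected by tea_models_dir.
    import glob

    # Matches the example layout results/tea0 ... results/tea9 shown above.
    ckpts = sorted(glob.glob("results/tea*/save/model.ckpt"))

    with open("tea_model_paths.txt", "w") as f:
        f.write("\n".join(ckpts) + "\n")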
    
  2. Student distillation

    This is the main distillation part, run with train_kd.py. Here, the pretrain variable can be used to initialize the student from a pre-trained teacher. If it is set to True, the path to the corresponding model.ckpt must be given in pretrain_st_dir. The tea_infer_dir variable is also required and must point to the directory containing the teacher inference results. Finally, note that distillation must be trained with exactly the same input CSV files generated by save_teachers.py. This ensures that the distillation stays aligned with the stored teacher predictions; diverging input CSV files may lead to incompatible shape errors! A small sanity check is sketched after the example command below.

    Example:

    python train_kd.py hparams/train_kd.yaml --data_folder /path-to/data_folder --pretrain_st_dir /path-to/model_directory --tea_infer_dir /path-to/tea_infer_directory
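
    To catch a CSV mismatch early, you can compare the utterance IDs of the CSV files seen by both scripts. This is a minimal sketch with hypothetical file names, assuming a SpeechBrain-style CSV layout with an ID column:

    # Sketch: verify that save_teachers.py and train_kd.py read the same data.
    import csv

    def utt_ids(csv_path):
        # Assumes a SpeechBrain-style CSV with an "ID" column.
        with open(csv_path, newline="") as f:
            return [row["ID"] for row in csv.DictReader(f)]

    # Hypothetical paths: point these at the CSVs actually used by each script.
    if utt_ids("teacher_inference_train.csv") != utt_ids("distillation_train.csv"):
        raise ValueError("CSV mismatch: teacher predictions and student inputs are misaligned.")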
    

Distillation strategies

There are four strategies in the current version, selected with the strategy option in hparams/train_kd.yaml; a simplified sketch of how they combine the per-teacher losses is given after the list.

  • average: average the distillation losses over all teachers.
  • top-1: choose the single best teacher based on WER.
  • top-k: choose the k best teachers based on WER when several teachers share the same WER.
  • weighted: assign weights to the teachers' losses based on their WER.
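
The exact selection and weighting rules (including how ties in WER are handled for top-k) are implemented in train_kd.py and described in the paper; the sketch below is only a simplified illustration of how per-teacher distillation losses could be combined given one loss and one WER per teacher:

    # Simplified illustration of the four strategies; not the recipe's exact code.
    import torch

    def combine_losses(losses, wers, strategy="average", k=3):
        # losses: (num_teachers,) per-teacher distillation losses
        # wers:   (num_teachers,) per-teacher word error rates
        if strategy == "average":
            return losses.mean()
        if strategy == "top-1":
            return losses[torch.argmin(wers)]
        if strategy == "top-k":
            idx = torch.topk(-wers, k).indices      # k teachers with the lowest WER
            return losses[idx].mean()
        if strategy == "weighted":
            weights = torch.softmax(-wers, dim=0)   # lower WER -> larger weight
            return (weights * losses).sum()
        raise ValueError(f"Unknown strategy: {strategy}")

    # Example with three teachers:
    print(combine_losses(torch.tensor([1.2, 0.9, 1.5]),
                         torch.tensor([0.12, 0.08, 0.20]), strategy="weighted"))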