This is the implementation of multi-teacher distillation methods for joint CTC-attention end-to-end ASR systems. The proposed approaches integrate the error rate metric into the teacher selection rather than relying solely on the observed losses. In this way, the student is distilled and optimized directly toward the metric that matters for speech recognition. For details, please refer to: https://arxiv.org/abs/2005.09310.
- Please install the newest version of SpeechBrain.
- Make sure h5py is installed; otherwise, run `pip install h5py` (a quick import check is sketched below).
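As a quick way to confirm both dependencies are importable, here is a minimal, hypothetical check (not part of the recipe):

```python
# Illustrative dependency check only: confirm SpeechBrain and h5py can be imported.
import importlib

for pkg in ("speechbrain", "h5py"):
    try:
        mod = importlib.import_module(pkg)
        print(f"{pkg} {getattr(mod, '__version__', 'unknown')} is available")
    except ImportError:
        print(f"{pkg} is missing; please install it before running the recipe")
```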
To speed up student distillation from multiple teachers, we split the whole procedure into two parts: running inference with the pre-trained teacher models, and distilling the student.
1. Run inference on all teacher models
This part runs inference on all teacher models and stores the results on disk using `save_teachers.py`. The only required setup is to point the `tea_models_dir` variable to a txt file listing the path to each teacher `model.ckpt`, one per line. We decided to work with a file so that the setup easily scales to hundreds of teachers. An example of this file is:

```
results/tea0/save/model.ckpt
results/tea1/save/model.ckpt
results/tea2/save/model.ckpt
results/tea3/save/model.ckpt
results/tea4/save/model.ckpt
results/tea5/save/model.ckpt
results/tea6/save/model.ckpt
results/tea7/save/model.ckpt
results/tea8/save/model.ckpt
results/tea9/save/model.ckpt
```
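For illustration only, such a list file could be parsed as follows; this is a hypothetical helper, and `save_teachers.py` handles the file with its own loading code:

```python
# Hypothetical helper (not part of the recipe): parse the teacher list file
# into checkpoint paths, one per non-empty line, and check that they exist.
from pathlib import Path

def read_teacher_list(tea_models_dir: str) -> list:
    with open(tea_models_dir) as f:
        paths = [line.strip() for line in f if line.strip()]
    missing = [p for p in paths if not Path(p).is_file()]
    if missing:
        raise FileNotFoundError(f"Missing teacher checkpoints: {missing}")
    return paths
```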
Example:
python save_teachers.py hparams/save_teachers.yaml --data_folder /path-to/data_folder --tea_models_dir /path-to/tea_model_paths.txt
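Since the teacher predictions are cached on disk (this is why h5py is required), the sketch below shows one possible way to store per-teacher outputs in an HDF5 file, keyed by utterance id. The layout, group names, and function are assumptions for illustration; the actual format written by `save_teachers.py` may differ:

```python
# Illustrative only: cache one teacher's CTC/attention outputs in HDF5,
# keyed by utterance id. The real layout used by save_teachers.py may differ.
import h5py
import numpy as np

def save_teacher_outputs(out_path, teacher_id, utt_ids, ctc_logits, attn_logits):
    with h5py.File(out_path, "a") as f:
        teacher_grp = f.require_group(f"teacher_{teacher_id}")
        for uid, ctc, att in zip(utt_ids, ctc_logits, attn_logits):
            utt_grp = teacher_grp.require_group(uid)
            utt_grp.create_dataset("ctc", data=np.asarray(ctc, dtype=np.float32))
            utt_grp.create_dataset("attn", data=np.asarray(att, dtype=np.float32))
```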
2. Student distillation
This is the main part of the distillation and uses `train_kd.py`. Here, the `pretrain` variable can be used to initialize the student from a pre-trained teacher. Note that if it is set to `True`, the path to the corresponding `model.ckpt` must be given in `pretrain_st_dir`. The `tea_infer_dir` variable is also required and must point to the directory containing the teacher inference results. Finally, note that the distillation must be trained with the exact same input CSV files that were generated by `save_teachers.py`. This ensures that the distillation stays perfectly aligned with the stored teacher predictions; diverging input CSV files may lead to incompatible shape errors!

Example:
python train_kd.py hparams/train_kd.yaml --data_folder /path-to/data_folder --pretrain_st_dir /path-to/model_directory --tea_infer_dir /path-to/tea_infer_directory
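To make the CSV consistency requirement concrete, here is a small, hypothetical sanity check; it assumes the HDF5 layout sketched above and a SpeechBrain-style `ID` column in the CSV, neither of which is guaranteed to match the recipe exactly:

```python
# Illustrative sanity check (not part of the recipe): every utterance in the
# distillation CSV should also appear in the cached teacher inference results.
# Assumes the hypothetical HDF5 layout from the earlier sketch and an "ID" column.
import csv
import h5py

def check_csv_matches_cache(csv_path, infer_path, teacher_id=0):
    with open(csv_path) as f:
        csv_ids = {row["ID"] for row in csv.DictReader(f)}
    with h5py.File(infer_path, "r") as f:
        cached_ids = set(f[f"teacher_{teacher_id}"].keys())
    missing = csv_ids - cached_ids
    if missing:
        raise ValueError(f"{len(missing)} utterances are missing from the teacher cache")
```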
The current version provides four strategies, selected with the `strategy` option in `hparams/train_kd.yaml` (a sketch of how they could weight the per-teacher losses follows the list):
- average: average losses of teachers when doing distillation.
- top-1: choosing the best teacher based on WER.
- top-k: choosing the k best teachers based on WER when several teachers share the same WER.
- weighted: assigning weights to teachers based on WER.
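The following is a minimal sketch, under assumed names and tensor shapes, of how these strategies could turn per-teacher WERs into weights over the per-teacher distillation losses; see `train_kd.py` for the recipe's actual logic, including how WER ties are handled:

```python
# Sketch only: map per-teacher WERs to weights over per-teacher KD losses.
import torch

def teacher_weights(wers: torch.Tensor, strategy: str, k: int = 3) -> torch.Tensor:
    """wers: shape [num_teachers]; returns one weight per teacher (summing to 1)."""
    n = wers.numel()
    if strategy == "average":
        return torch.full((n,), 1.0 / n)                  # uniform average of all teacher losses
    if strategy == "top-1":
        w = torch.zeros(n)
        w[torch.argmin(wers)] = 1.0                       # single best teacher by WER
        return w
    if strategy == "top-k":
        w = torch.zeros(n)
        idx = torch.topk(wers, k, largest=False).indices  # k lowest-WER teachers
        w[idx] = 1.0 / k
        return w
    if strategy == "weighted":
        return torch.softmax(-wers, dim=0)                # lower WER -> larger weight
    raise ValueError(f"Unknown strategy: {strategy}")

# Usage: loss = (teacher_weights(batch_wers, "weighted") * per_teacher_kd_losses).sum()
```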