Recycle-and-Distill: Universal Compression Strategy for Transformer-based Speech SSL Models with Attention Map Reusing and Masking Distillation, INTERSPEECH 2023.
Kangwook Jang*,
Sungnyun Kim*,
Se-Young Yun, Hoirin Kim
* equal contribution
- Attention Map Reusing: Reuse the previous layer's attention map to remove the key and query parameters in the Transformer (a minimal sketch of both ideas follows this list)
- Masking Distillation: A distillation objective that treats masked frames and unmasked frames separately
- Parameters and MACs of ARMHuBERT have decreased to 28% and 30% of the teacher, HuBERT Base, respectively.
- ARMHuBERT achieves PER of 7.72%, WER of 9.96% on the SUPERB benchmark in an E2E distillation manner.
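Conceptually, the two ideas above can be sketched in a few lines of PyTorch. This is only a minimal illustration, not the repository's implementation: the function names, projection layers, loss weights, and the L1 distance are assumptions, and the actual reuse pattern and distillation objective follow the paper and `train.py`.

```python
import torch
import torch.nn.functional as F

def reuse_attention(x, prev_attn, v_proj, out_proj):
    """Attention-map reusing (sketch): apply the previous layer's attention
    map to this layer's values, so no query/key projections are needed here.
    x: (batch, time, dim), prev_attn: (batch, time, time)."""
    values = v_proj(x)                       # value projection only
    context = torch.bmm(prev_attn, values)   # reuse the (time x time) map
    return out_proj(context)

def masking_distillation_loss(student, teacher, mask, w_masked=1.0, w_unmasked=1.0):
    """Masking distillation (sketch): distill masked and unmasked frames
    with separate loss terms. mask: boolean (batch, time) of masked frames."""
    loss_masked = F.l1_loss(student[mask], teacher[mask])
    loss_unmasked = F.l1_loss(student[~mask], teacher[~mask])
    return w_masked * loss_masked + w_unmasked * loss_unmasked

# Illustrative usage with dummy tensors
B, T, D = 2, 50, 768
x = torch.randn(B, T, D)
prev_attn = torch.softmax(torch.randn(B, T, T), dim=-1)
v_proj, out_proj = torch.nn.Linear(D, D), torch.nn.Linear(D, D)
out = reuse_attention(x, prev_attn, v_proj, out_proj)            # (2, 50, 768)
mask = torch.rand(B, T) < 0.5
loss = masking_distillation_loss(out, torch.randn(B, T, D), mask)
```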
📌 Check out our model's performance in SUPERB Leaderboard!
For our model's checkpoints, go check this link!
| Model name | Parameters | Teacher | Training dataset | Link |
|---|---|---|---|---|
| ARMHuBERT-960h | 26.45M | HuBERT | LibriSpeech-960h | HF Model |
| ARMHuBERT-S-100h | 22.39M | HuBERT | LibriSpeech-100h | HF Model |
| ARMHuBERT-S-960h | 22.39M | HuBERT | LibriSpeech-960h | HF Model |
| ARMwavLM-S-100h | 22.39M | wavLM | LibriSpeech-100h | HF Model |
| ARMwavLM-S-960h | 22.39M | wavLM | LibriSpeech-960h | HF Model |
| MaskHuBERT-960h | 26.64M | HuBERT | LibriSpeech-960h | HF Model |
Install the necessary packages with:
$ pip install -r requirements.txt
- Download the teacher model checkpoint to perform knowledge distillation, and place it under the root path, `./`.
- Download the LibriSpeech dataset.
  - For 100h distillation, download `train-clean-100`.
  - For 960h distillation, download the whole dataset: `train-clean-100`, `train-clean-360`, and `train-other-500`.
  - For validation, download `dev-clean`.
  - You can also validate your model on the test set. In this case, please download `test-clean`, and modify `self.eval_data` in `train.py`.
- Modify the configuration file in `./conf/[model_name]/[config].yaml`.
  - For example, the configuration file `./conf/armhubert/armhubert-960.yaml` contains all the settings for reproducing ARMHuBERT on the LibriSpeech 960h dataset.
  - Set the path to the teacher model checkpoint at `teacher_model`, and the root path to the LibriSpeech dataset at `libri_root`. (A small sanity-check script for these paths is sketched after these steps.)
- Then, run the following command:

  ```
  python train.py -c ./conf/[model_name]/[config].yaml
  ```

  For ARMHuBERT:

  ```
  python train.py -c ./conf/armhubert/armhubert-960.yaml
  ```

  After training, the model checkpoints and the corresponding configuration file will be created at `./results/pretrain/`.
- If you don't feel like training your model, feel free to use our checkpoints.
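Before launching `train.py`, it can be handy to check that the config's `teacher_model` and `libri_root` paths exist and that the expected LibriSpeech splits are in place, as mentioned in the steps above. The following is a small, hypothetical helper and not part of the repository; it only assumes the `teacher_model` and `libri_root` keys named above and the standard LibriSpeech directory layout.

```python
from pathlib import Path
import yaml  # pip install pyyaml

def find_key(node, key):
    """Recursively find the first value for `key` in a nested YAML dict."""
    if isinstance(node, dict):
        if key in node:
            return node[key]
        for child in node.values():
            found = find_key(child, key)
            if found is not None:
                return found
    return None

def check_setup(config_path, splits=("train-clean-100", "dev-clean")):
    """Check the teacher checkpoint path, LibriSpeech root, and split contents."""
    with open(config_path) as f:
        cfg = yaml.safe_load(f)

    teacher = find_key(cfg, "teacher_model")
    libri_root = find_key(cfg, "libri_root")
    print("teacher_model:", teacher, "->", "ok" if teacher and Path(teacher).exists() else "MISSING")

    for split in splits:
        split_dir = Path(libri_root or ".") / split
        n_flac = len(list(split_dir.rglob("*.flac"))) if split_dir.is_dir() else 0
        print(f"{split}: {n_flac} .flac files at {split_dir}")

# Example:
# check_setup("./conf/armhubert/armhubert-960.yaml")
```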
- Clone and install the S3PRL toolkit with `pip install -e ".[all]"` (dev mode).
- Copy the entire `./models/[model_name]` folder into `<s3prl root>/s3prl/upstream/`.
- Add the upstream import line in `<s3prl root>/s3prl/hub.py` (a quick check that the upstream loads is sketched after these steps):

  ```
  from s3prl.upstream.[model_name].hubconf import *
  ```

  For ARMHuBERT:

  ```
  from s3prl.upstream.armhubert.hubconf import *
  ```
- Change each config file of the S3PRL downstream tasks as follows:
  - Uncomment the learning rate scheduler.
  - Scale the learning rate to 10x for the speaker identification (SID) task.
- Run the following command to fine-tune the ARMHuBERT model. For automatic speech recognition (ASR) as an example:

  ```
  python run_downstream.py \
      -m train \
      -n ARMHuBERT-ASR \ # You can set your exp name whatever you want
      -u armhubert \
      -d asr \
      -k <path to .ckpt file in <git root>/results/pretrain/> \
      -g <path to .yaml file in <git root>/results/pretrain/>
  ```
Note: Refer to the SUPERB docs for more information on usage details and data preparation.
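After registering the upstream, you can also quickly check that it loads through the S3PRL hub before running a full downstream experiment (see the `hub.py` step above). The snippet below follows S3PRL's standard upstream interface; the entry name `armhubert` and the `ckpt` argument are assumptions based on the steps above, so adapt them to whatever the actual `hubconf.py` exposes.

```python
import torch
import s3prl.hub as hub

# Entry name as registered by the hubconf imported into hub.py (assumed: armhubert)
model = getattr(hub, "armhubert")(ckpt="<path to .ckpt file in <git root>/results/pretrain/>")
model.eval()

# S3PRL upstreams take a list of 1-D waveforms (16 kHz) and return hidden states
wavs = [torch.randn(16000) for _ in range(2)]
with torch.no_grad():
    hidden_states = model(wavs)["hidden_states"]
print(len(hidden_states), hidden_states[0].shape)
```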
We evaluate our student models on the SUPERB benchmark.
MaskHuBERT substantially improves performance on content- and semantics-related tasks; see PR, ASR, SF, and IC.
ARMHuBERT shows promising improvements over MaskHuBERT on the SF and SID tasks, while exhibiting a similar level of performance on the other tasks.
ARMHuBERT achieves a better overall score of 78.1 with fewer parameters than MaskHuBERT. This is state-of-the-art performance among end-to-end distillation approaches such as Deep-versus-wide 12-L and FitHuBERT.
You can also check that our approach works on another Transformer backbone model, wavLM.
We have only evaluated HuBERT- and wavLM-based models, but this strategy can be applied identically to any speech model with a Transformer backbone, e.g., AST (Audio Spectrogram Transformer).
If you find this repo useful for your research, please consider citing our paper:
@inproceedings{jang2023recycleanddistill,
title={Recycle-and-Distill: Universal Compression Strategy for Transformer-based Speech SSL Models with Attention Map Reusing and Masking Distillation},
author={Kangwook Jang and Sungnyun Kim and Se-Young Yun and Hoirin Kim},
booktitle={Proc. INTERSPEECH 2023},
pages={316--320},
year={2023}
}
🎉 Update (Apr 12, 2024): Our new paper, STaR, has been selected as Best Student Paper in ICASSP 2024!
🎉 Check out our model's performance in SUPERB Leaderboard!
STaR: Distilling Speech Temporal Relation for Lightweight Speech Self-Supervised Learning Models, ICASSP 2024.
Kangwook Jang,
Sungnyun Kim,
Hoirin Kim
- Speech Temporal Relation (STaR): Distill knowledge by focusing on the pairwise temporal relation between two speech frames.
- Temporal Gram Matrix (TGM): Propose the Temporal Gram Matrix, which aggregates channel information at two time steps (a rough sketch follows this list).
- Layer-wise TGM: Distill the TGM for every Transformer layer.
- Intra-layer TGM: Modify the TGM to compute the temporal relation between the input and output of a single Transformer layer.
- Incorporating the two TGMs together as distillation objectives, our student model STaRHuBERT (22M & 26M) shows SOTA performance on the SUPERB benchmark in terms of overall and generalizability scores.
- Under further compression (9.39M & 14.1M), our approach is more robust against performance degradation compared to other works.
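As a rough illustration of the TGM objectives above, here is a sketch under the assumption that a TGM is the frame-by-frame inner product aggregated over the channel dimension; the paper's exact formulation, normalization, and distance metric may differ.

```python
import torch
import torch.nn.functional as F

def temporal_gram_matrix(x, y=None):
    """Pairwise temporal relation between frames, aggregated over channels.
    x, y: (batch, time, dim). y=None gives the layer-wise TGM of x; passing a
    layer's input as x and its output as y gives the intra-layer TGM."""
    y = x if y is None else y
    return torch.bmm(x, y.transpose(1, 2)) / x.size(-1)   # (batch, time, time)

def layerwise_tgm_loss(student_layers, teacher_layers):
    """Match layer-wise TGMs of student and teacher, layer by layer."""
    losses = [F.l1_loss(temporal_gram_matrix(s), temporal_gram_matrix(t))
              for s, t in zip(student_layers, teacher_layers)]
    return sum(losses) / len(losses)

# Illustrative usage: dummy features where student and teacher widths differ
student = [torch.randn(2, 100, 432) for _ in range(2)]
teacher = [torch.randn(2, 100, 768) for _ in range(2)]
loss = layerwise_tgm_loss(student, teacher)   # TGMs are (time x time) regardless of dim
```

Note that in this sketch the TGMs are (time x time) matrices independent of the feature width, so student and teacher layers of different dimensions can be compared directly.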
For our model's checkpoints, please check the following links. All models are distilled from HuBERT base.
- STaRHuBERT-L (26.6M): ckpt, yaml
- STaRHuBERT (22.3M): ckpt, yaml
- STaRHuBERT-S (14.1M): ckpt, yaml
- STaRHuBERT-XS (9.39M): ckpt, yaml
We do not offer an official implementation of the distillation code. Nevertheless, since STaRHuBERT is built upon the backbone of ARMHuBERT, you can easily re-implement our approach with this ARMHuBERT repository.
You can reproduce our results with the given checkpoints. Please follow the steps below (almost the same as in the ARMHuBERT case).
- Clone and install the S3PRL toolkit with `pip install -e ".[all]"` (dev mode).
- Copy the entire `./models/starhubert` folder into `<s3prl root>/s3prl/upstream/`.
- Add the upstream import line in `<s3prl root>/s3prl/hub.py`:

  ```
  from s3prl.upstream.starhubert.hubconf import *
  ```
- Change each config file of the S3PRL downstream tasks as follows:
  - Uncomment the learning rate scheduler.
  - Scale the learning rate to 10x for the speaker identification (SID) task.
- Run the following command to fine-tune the STaRHuBERT model. For automatic speech recognition (ASR) as an example:

  ```
  python run_downstream.py \
      -m train \
      -n STaRHuBERT-ASR \ # You can set your exp name whatever you want
      -u starhubert \
      -d asr \
      -k <path to .ckpt file in <git root>/results/pretrain/> \
      -g <path to .yaml file in <git root>/results/pretrain/>
  ```
Note: Refer to the SUPERB docs for more information on usage details and data preparation.
If you find this repo useful for your research, please consider citing our paper:
@inproceedings{jang2024star,
title={STaR: Distilling Speech Temporal Relation for Lightweight Speech Self-Supervised Learning Models},
author={Jang, Kangwook and Kim, Sungnyun and Kim, Hoirin},
booktitle={ICASSP 2024-2024 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={10721--10725},
year={2024},
organization={IEEE}
}
For any details or clarification, please reach out to
- Kangwook Jang: dnrrkdwkd12@kaist.ac.kr
- Sungnyun Kim: ksn4397@kaist.ac.kr