Official implementation of MelHuBERT

Primary language: Python. License: MIT.

MelHuBERT: A simplified HuBERT on Mel spectrogram

This is the official implementation of the paper accepted at ASRU 2023.

Paper link: https://arxiv.org/abs/2211.09944

Paper introduction video: https://www.youtube.com/watch?v=S_t2TROKu6o

MelHuBERT achieves favorable performance on phone recognition, speaker identification, and automatic speech recognition compared with HuBERT, while saving 31.2% of the pre-training time, or equivalently 33.5% of the MACs per second of speech.



Install the required packages:

pip install -r requirement.txt

Data Preparation

First, download the LibriSpeech 360-hour dataset here and unzip it.

Next, run the following command to extract log Mel features and the paired cluster labels (obtained by running K-means on the log Mel features):

bash preprocess.sh [DATASET_DIR] [OUT_DIR]
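The cluster labels are produced by K-means on the log Mel frames. As a rough illustration of that idea only (this is not the code in preprocess.sh), the sketch below runs plain Lloyd's K-means over a list of feature vectors and returns one cluster ID per frame; such IDs play the role of the frame-level pre-training targets. The function names, the deterministic farthest-point initialization, and the toy input are assumptions for illustration; a real run would cluster far more frames with an optimized implementation.

```python
from typing import List, Sequence

Frame = Sequence[float]  # one feature vector, e.g. one log Mel frame (hypothetical toy data below)


def squared_distance(a: Frame, b: Frame) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest(frame: Frame, centroids: List[List[float]]) -> int:
    # Index of the closest centroid (ties go to the lower index).
    return min(range(len(centroids)), key=lambda c: squared_distance(frame, centroids[c]))


def init_centroids(frames: List[Frame], k: int) -> List[List[float]]:
    # Deterministic farthest-point seeding: start from the first frame, then
    # repeatedly add the frame farthest from all centroids chosen so far.
    centroids = [list(frames[0])]
    while len(centroids) < k:
        far = max(frames, key=lambda f: min(squared_distance(f, c) for c in centroids))
        centroids.append(list(far))
    return centroids


def kmeans_labels(frames: List[Frame], k: int, iters: int = 20) -> List[int]:
    """Lloyd's K-means over feature frames; returns one cluster ID per frame."""
    centroids = init_centroids(frames, k)
    for _ in range(iters):
        labels = [nearest(f, centroids) for f in frames]  # assignment step
        for c in range(k):                                 # update step
            members = [f for f, lab in zip(frames, labels) if lab == c]
            if members:  # an empty cluster keeps its previous centroid
                centroids[c] = [sum(dim) / len(members) for dim in zip(*members)]
    return [nearest(f, centroids) for f in frames]


toy = [[0.0, 0.1], [0.1, 0.0], [0.05, 0.05], [10.0, 10.1], [10.1, 10.0], [9.9, 10.0]]
print(kmeans_labels(toy, k=2))  # [0, 0, 0, 1, 1, 1]
```

Each frame ends up paired with one integer label, which is the kind of feature/label pairing the preprocessing step writes out.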

Then, set datarc.sets in ./config/config_runner_20ms.yaml and ./config/config_runner_10ms.yaml to [ OUT_DIR/libri-360-data-cluster-pair.csv ].
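The dotted name datarc.sets refers to the sets key nested under datarc in the runner config. Assuming the YAML file has been loaded into nested dictionaries (loading and saving the YAML is not shown), the change amounts to the following. The helper set_dotted and the placeholder config contents are hypothetical; OUT_DIR stands for the output directory you passed to preprocess.sh.

```python
from typing import Any, Dict


def set_dotted(config: Dict[str, Any], dotted_key: str, value: Any) -> None:
    """Set a nested config entry addressed as 'outer.inner', creating levels as needed."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for name in parents:
        node = node.setdefault(name, {})  # descend, creating missing mappings
    node[leaf] = value


# Hypothetical stand-in for a loaded runner config.
runner_config = {"datarc": {"sets": ["placeholder.csv"]}}
set_dotted(runner_config, "datarc.sets", ["OUT_DIR/libri-360-data-cluster-pair.csv"])
print(runner_config["datarc"]["sets"])  # ['OUT_DIR/libri-360-data-cluster-pair.csv']
```

Note that sets takes a list, so the CSV path goes inside brackets even when there is only one file.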

The mean and standard deviation of the LibriSpeech 360-hour features are saved at OUT_DIR/mean-std.npy. You won't need them during pre-training, but you may need them when fine-tuning on downstream tasks.
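A common way to use such statistics is per-dimension standardization of the input features. The sketch below assumes, without confirmation from this README, that mean-std.npy holds a single array of shape (2, n_mels) with the mean in row 0 and the standard deviation in row 1; inspect the loaded array's shape first and adapt load_mean_std if the file is laid out differently. Both function names are hypothetical.

```python
import numpy as np


def load_mean_std(path: str):
    """Load feature statistics. ASSUMPTION: a (2, n_mels) array, row 0 = mean, row 1 = std."""
    stats = np.load(path)
    return stats[0], stats[1]


def normalize(features: np.ndarray, mean: np.ndarray, std: np.ndarray,
              eps: float = 1e-8) -> np.ndarray:
    """Standardize features of shape (n_frames, n_mels) per Mel bin; eps avoids division by zero."""
    return (features - mean) / (std + eps)
```

Usage would look like `mean, std = load_mean_std("OUT_DIR/mean-std.npy")` followed by `normalize(feats, mean, std)`, where feats is a (n_frames, n_mels) array; broadcasting applies the statistics to every frame.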

Pre-training MelHuBERT from scratch

Run one of the following commands to pre-train MelHuBERT from scratch with the default configuration:

  • 20 ms frame period:
python3 train.py -f 20 -g ./config/config_model_20ms.yaml -c ./config/config_runner_20ms.yaml -n EXP_DIR_PATH 
  • 10 ms frame period:
python3 train.py -f 10 -g ./config/config_model_10ms.yaml -c ./config/config_runner_10ms.yaml -n EXP_DIR_PATH 

-f: Frame period in ms (20 or 10)
-g: Model config
-c: Runner config
-n: Experiment directory; the model checkpoints, log file, and the pre-training config you used are saved here
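The frame period passed with -f must match the model and runner configs passed with -g and -c. If you launch runs from a script, a small helper can keep the three consistent. This is a hypothetical convenience sketch built only from the two commands above; CONFIGS and build_train_cmd are not part of the repository.

```python
import shlex
from typing import List

# Config pairs from the two commands above, keyed by frame period in ms.
CONFIGS = {
    20: ("./config/config_model_20ms.yaml", "./config/config_runner_20ms.yaml"),
    10: ("./config/config_model_10ms.yaml", "./config/config_runner_10ms.yaml"),
}


def build_train_cmd(frame_period: int, exp_dir: str) -> List[str]:
    """Return the train.py argument list with configs that match the frame period."""
    if frame_period not in CONFIGS:
        raise ValueError(f"frame period must be 20 or 10 ms, got {frame_period}")
    model_cfg, runner_cfg = CONFIGS[frame_period]
    return ["python3", "train.py", "-f", str(frame_period),
            "-g", model_cfg, "-c", runner_cfg, "-n", exp_dir]


print(shlex.join(build_train_cmd(20, "EXP_DIR_PATH")))
# python3 train.py -f 20 -g ./config/config_model_20ms.yaml -c ./config/config_runner_20ms.yaml -n EXP_DIR_PATH
```

The returned list can be passed to subprocess.run(cmd, check=True) from the repository root; building a list rather than a shell string avoids quoting problems with unusual directory names.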

Pretrained Models

Warning: Due to computational resource limitations, these MelHuBERT models were trained with a batch size of 32. Therefore, they cannot be fairly compared with fairseq's HuBERT Base, which was trained with a much larger batch size.

Extracting Features

Run the following command to extract features from two example waveforms:

python3 extract_feature.py -c [CHECKPOINT] -f [FRAME_PERIOD]

-c: Model checkpoint path
-f: Frame period, either 20 or 10 (ms)
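The frame period sets the time resolution of the extracted features: one feature vector per frame period. As a back-of-the-envelope check that ignores framing and padding edge effects (the exact count depends on the extractor's window and hop settings, which this sketch does not model), one second of speech gives about 50 frames at 20 ms and about 100 frames at 10 ms. The function num_frames is a hypothetical illustration, not part of extract_feature.py.

```python
def num_frames(duration_s: float, frame_period_ms: int) -> int:
    """Approximate frame count: one feature vector per frame period (edges/padding ignored)."""
    if frame_period_ms not in (10, 20):
        raise ValueError("frame period must be 10 or 20 ms")
    return round(duration_s * 1000 / frame_period_ms)


for period in (20, 10):
    print(f"{period} ms -> {num_frames(1.0, period)} frames per second of speech")
# 20 ms -> 50 frames per second of speech
# 10 ms -> 100 frames per second of speech
```

In other words, a 20 ms model yields roughly half as many feature frames per second of audio as a 10 ms model.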


Our implementation of the pre-training interface is based on the S3PRL toolkit.