The official PyTorch implementation of "Frame-wise streaming end-to-end speaker diarization with non-autoregressive self-attention-based attractors".
This work has been accepted to ICASSP 2024.
Paper 🤩 | Issues 😅 | Lab 🙉 | Contact 😘
This work proposes a frame-wise online/streaming end-to-end neural diarization (FS-EEND) method that operates in a frame-in-frame-out fashion. To detect a flexible number of speakers frame by frame and to extract and update their corresponding attractors, we use a causal speaker embedding encoder and an online non-autoregressive self-attention-based attractor decoder. A look-ahead mechanism lets the model use a few future frames to detect new speakers effectively in real time and to update speaker attractors adaptively.
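As a generic illustration of what "causal with a bounded look-ahead" means, the sketch below builds a visibility mask in which frame `t` may use frames up to `t + lookahead`. This is an assumption-laden toy, not the repository's implementation: the function name `lookahead_mask` and its parameters are hypothetical, and the actual model may realize look-ahead differently.

```python
def lookahead_mask(num_frames, lookahead=0):
    """Return mask[t][s] == True if frame t is allowed to use frame s.

    lookahead=0 gives a strictly causal mask (only past and current frames);
    lookahead=k additionally exposes the next k future frames.
    """
    if lookahead < 0:
        raise ValueError("lookahead must be non-negative")
    return [
        [s <= t + lookahead for s in range(num_frames)]
        for t in range(num_frames)
    ]


# Hypothetical example: 4 frames, 1 frame of look-ahead.
mask = lookahead_mask(4, lookahead=1)
for row in mask:
    print("".join("x" if visible else "." for visible in row))
# xx..
# xxx.
# xxxx
# xxxx
```

A larger look-ahead exposes more future context at the cost of higher latency, which is the trade-off any streaming system has to balance.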
- Clone the FS-EEND code:
git clone https://github.com/Audio-WestlakeU/FS-EEND.git
- Prepare Kaldi-style data by referring to here. Modify conf/xxx.yaml according to your own paths.
- Start training on simulated data:
python train_dia.py --configs conf/spk_onl_tfm_enc_dec_nonautoreg.yaml --gpus YOUR_DEVICE_ID
- Set the path to your pretrained model in conf/spk_onl_tfm_enc_dec_nonautoreg_callhome.yaml.
- Fine-tune on CALLHOME data:
python train_dia_fintn_ch.py --configs conf/spk_onl_tfm_enc_dec_nonautoreg_callhome.yaml --gpus YOUR_DEVICE_ID
- Run inference (first set your own path for saving predictions in test_step in train/oln_tfm_enc_decxxx.py):
python train_diaxxx.py --configs conf/xxx_infer.yaml --gpus YOUR_DEVICE_ID --test_from_folder YOUR_CKPT_SAVE_DIR
- Evaluate:
# generate speech activity probabilities (diarization results)
cd visualize
python gen_h5_output.py
# calculate DERs
python metrics.py --configs conf/xxx_infer.yaml
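For readers unfamiliar with the metric, here is a minimal sketch of a frame-level diarization error rate (DER) under the best speaker permutation. It is not the logic of metrics.py: the function name `frame_der` is hypothetical, the inputs are assumed to be already-binarized activity matrices with the same number of frames and speaker columns, and details such as collars, thresholds, or smoothing are deliberately left out.

```python
from itertools import permutations


def frame_der(ref, hyp):
    """Frame-level DER for binary activity matrices shaped [frames][speakers].

    Per frame: misses + false alarms + speaker confusions, summed over frames
    and divided by the total reference speech. The hypothesis speaker columns
    are permuted to find the assignment with the fewest errors.
    """
    num_spk = len(ref[0])
    total_ref = sum(sum(frame) for frame in ref)
    if total_ref == 0:
        raise ValueError("reference contains no speech")

    best_errors = None
    for perm in permutations(range(num_spk)):
        errors = 0
        for r, h in zip(ref, hyp):
            h_perm = [h[p] for p in perm]
            n_ref, n_hyp = sum(r), sum(h_perm)
            n_correct = sum(1 for a, b in zip(r, h_perm) if a == 1 and b == 1)
            miss = max(n_ref - n_hyp, 0)
            false_alarm = max(n_hyp - n_ref, 0)
            confusion = min(n_ref, n_hyp) - n_correct
            errors += miss + false_alarm + confusion
        if best_errors is None or errors < best_errors:
            best_errors = errors
    return best_errors / total_ref


# Hypothetical toy data: 4 frames, 2 speakers. The hypothesis has the speaker
# labels swapped and misses one speaker in the overlapped last frame.
ref = [[1, 0], [1, 0], [0, 1], [1, 1]]
hyp = [[0, 1], [0, 1], [1, 0], [0, 1]]
der = frame_der(ref, hyp)
print(f"DER = {100 * der:.1f}%")  # DER = 20.0%
```

The exhaustive permutation search is only practical for a handful of speakers; a Hungarian-style assignment would be the usual choice for larger speaker counts.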
Please note that we use Switchboard Cellular (Part 1 and 2) and 2005-2008 NIST Speaker Recognition Evaluation (SRE) to generate simulated data (including 4054 speakers).
Dataset | DER (%) | Checkpoint |
---|---|---|
Simu1spk | 0.6 | simu_avg_41_50epo.ckpt |
Simu2spk | 4.3 | same as above |
Simu3spk | 9.8 | same as above |
Simu4spk | 14.7 | same as above |
CH2spk | 10.0 | ch_avg_91_100epo.ckpt |
CH3spk | 15.3 | same as above |
CH4spk | 21.8 | same as above |
Each checkpoint is the average of the model parameters over the last 10 epochs.
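Parameter averaging of this kind can be sketched as follows. This is an illustrative assumption rather than the script used to produce the released files: `average_state_dicts` is a hypothetical helper, and it works on any values that support `+` and `/` (plain floats here, but the same arithmetic applies to floating-point tensors from loaded state dicts).

```python
def average_state_dicts(state_dicts):
    """Average a list of parameter dictionaries key by key.

    All dictionaries must share the same keys. Values only need to support
    addition and division by an integer.
    """
    if not state_dicts:
        raise ValueError("need at least one state dict")
    keys = state_dicts[0].keys()
    for sd in state_dicts[1:]:
        if sd.keys() != keys:
            raise KeyError("state dicts have mismatched keys")

    count = len(state_dicts)
    averaged = {}
    for key in keys:
        total = state_dicts[0][key]
        for sd in state_dicts[1:]:
            total = total + sd[key]  # builds a new value, inputs stay untouched
        averaged[key] = total / count
    return averaged


# Hypothetical toy "checkpoints" with scalar parameters.
epoch_a = {"w": 1.0, "b": 0.0}
epoch_b = {"w": 3.0, "b": 1.0}
avg = average_state_dicts([epoch_a, epoch_b])
print(avg)  # {'w': 2.0, 'b': 0.5}
```

With real checkpoints, integer buffers (for example step counters) would need separate handling, since dividing them changes their type.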
To evaluate the provided checkpoint on CALLHOME, run:
python train_dia_fintn_ch.py --configs conf/spk_onl_tfm_enc_dec_nonautoreg_callhome_infer.yaml --gpus YOUR_DEVICE_ID, --test_from_folder YOUR_CKPT_SAVE_DIR
Note that this requires modifying the checkpoint-loading code in train_dia_fintn_ch.py. Change
ckpts = [x for x in all_files if (".ckpt" in x) and ("epoch" in x) and int(x.split("=")[1].split("-")[0])>=configs["log"]["start_epoch"] and int(x.split("=")[1].split("-")[0])<=configs["log"]["end_epoch"]]
state_dict = torch.load(test_folder + "/" + c, map_location="cpu")["state_dict"]
to
ckpts = [x for x in all_files if (".ckpt" in x)]
state_dict = torch.load(test_folder + "/" + c, map_location="cpu")
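The difference between the two filters can be seen in a small standalone sketch. The helper names `epoch_of` and `select_checkpoints` and the example file names are hypothetical; the per-epoch names assume an `epoch=N-step=M.ckpt` pattern, which is what the `split("=")[1].split("-")[0]` parsing expects.

```python
def epoch_of(filename):
    """Parse the epoch number from a name such as 'epoch=91-step=910.ckpt'."""
    return int(filename.split("=")[1].split("-")[0])


def select_checkpoints(all_files, start_epoch=None, end_epoch=None):
    """Pick checkpoint files, optionally restricted to an epoch range.

    Without a range, every '.ckpt' file is returned, which is what an
    averaged checkpoint with no 'epoch=' in its name needs.
    """
    ckpts = [x for x in all_files if ".ckpt" in x]
    if start_epoch is None or end_epoch is None:
        return ckpts
    return [
        x for x in ckpts
        if "epoch" in x and start_epoch <= epoch_of(x) <= end_epoch
    ]


# Hypothetical folder listing.
files = [
    "epoch=90-step=900.ckpt",
    "epoch=91-step=910.ckpt",
    "epoch=100-step=1000.ckpt",
    "ch_avg_91_100epo.ckpt",
    "log.txt",
]
print(select_checkpoints(files, 91, 100))
# ['epoch=91-step=910.ckpt', 'epoch=100-step=1000.ckpt']
print(select_checkpoints(files))
# ['epoch=90-step=900.ckpt', 'epoch=91-step=910.ckpt', 'epoch=100-step=1000.ckpt', 'ch_avg_91_100epo.ckpt']
```

The epoch-range filter skips the averaged file because its name does not contain "epoch", which is why the simpler filter is used when testing from it. The second changed line drops the `["state_dict"]` lookup, which suggests the averaged file stores the parameter dictionary directly.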
If you want to cite this paper:
@misc{liang2023framewise,
title={Frame-wise streaming end-to-end speaker diarization with non-autoregressive self-attention-based attractors},
author={Di Liang and Nian Shao and Xiaofei Li},
year={2023},
eprint={2309.13916},
archivePrefix={arXiv},
primaryClass={eess.AS}
}