The official PyTorch implementation of "Frame-wise streaming end-to-end speaker diarization with non-autoregressive self-attention-based attractors". [ICASSP 2024]

FS-EEND

This work has been accepted by ICASSP 2024.

Introduction

This work proposes frame-wise streaming end-to-end neural diarization (FS-EEND), an online method that operates in a frame-in-frame-out fashion. To detect a flexible number of speakers frame by frame and to extract and update their attractors, FS-EEND combines a causal speaker embedding encoder with an online non-autoregressive self-attention-based attractor decoder. A look-ahead mechanism lets the model use a few future frames, which helps it detect new speakers in real time and update speaker attractors adaptively.

The proposed FS-EEND architecture
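
To make the frame-in-frame-out idea with look-ahead concrete, here is a minimal sketch of a visibility mask: output frame t may depend on input frames 0..t plus a fixed number of future frames. This is an illustration only, not the repository's implementation; the function names `visibility_mask` and `render` and the `lookahead` parameter are hypothetical.

```python
def visibility_mask(num_frames, lookahead=0):
    """Return mask[t][s], True when input frame s may contribute to output frame t.

    lookahead=0 gives a strictly causal mask; lookahead=L additionally exposes
    the next L future frames. (Illustrative assumption, not the FS-EEND code.)
    """
    if num_frames < 0 or lookahead < 0:
        raise ValueError("num_frames and lookahead must be non-negative")
    return [[s <= t + lookahead for s in range(num_frames)]
            for t in range(num_frames)]


def render(mask):
    """Render the mask as rows of 0/1 characters for quick inspection."""
    return "\n".join("".join("1" if v else "0" for v in row) for row in mask)


if __name__ == "__main__":
    # Each row is one output frame; each column is one input frame.
    print(render(visibility_mask(5, lookahead=1)))
```

A larger `lookahead` exposes more future context at the cost of a longer output delay, since the output for frame t cannot be produced before frame t + lookahead has arrived.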

Get started

  1. Clone the FS-EEND code:
git clone https://github.com/Audio-WestlakeU/FS-EEND.git
  2. Prepare Kaldi-style data by referring to here. Modify conf/xxx.yaml according to your own paths.
  3. Start training on simulated data:
python train_dia.py --configs conf/spk_onl_tfm_enc_dec_nonautoreg.yaml --gpus YOUR_DEVICE_ID
  4. Set your pretrained model path in conf/spk_onl_tfm_enc_dec_nonautoreg_callhome.yaml.
  5. Fine-tune on CALLHOME data:
python train_dia_fintn_ch.py --configs conf/spk_onl_tfm_enc_dec_nonautoreg_callhome.yaml --gpus YOUR_DEVICE_ID
  6. Run inference (first set your own path for saving predictions in test_step in train/oln_tfm_enc_decxxx.py):
python train_diaxxx.py --configs conf/xxx_infer.yaml --gpus YOUR_DEVICE_ID --test_from_folder YOUR_CKPT_SAVE_DIR
  7. Evaluate:
# generate speech activity probabilities (diarization results)
cd visualize
python gen_h5_output.py

# calculate DERs
python metrics.py --configs conf/xxx_infer.yaml
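
Before launching training, it can help to confirm that a data directory looks like a Kaldi-style one. The sketch below checks for files commonly found in Kaldi-style diarization data directories. The file list is an assumption (the exact set FS-EEND reads is not stated here), and `missing_kaldi_files` is a hypothetical helper, not part of the repository.

```python
import tempfile
from pathlib import Path

# Assumption: files typically present in a Kaldi-style diarization data directory.
# Adjust this tuple to whatever your data preparation actually produces.
EXPECTED_FILES = ("wav.scp", "segments", "utt2spk", "spk2utt", "reco2dur")


def missing_kaldi_files(data_dir, expected=EXPECTED_FILES):
    """Return the names in `expected` that are not regular files in `data_dir`."""
    data_dir = Path(data_dir)
    return [name for name in expected if not (data_dir / name).is_file()]


if __name__ == "__main__":
    # Demo on a throwaway directory containing only two of the expected files.
    with tempfile.TemporaryDirectory() as tmp:
        for name in ("wav.scp", "segments"):
            (Path(tmp) / name).write_text("")
        print("missing:", missing_kaldi_files(tmp))
```

Running this against each data path referenced in conf/xxx.yaml gives an early warning about a wrong path before a training job starts.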

Performance

Note that the simulated data are generated from Switchboard Cellular (Part 1 and 2) and the 2005-2008 NIST Speaker Recognition Evaluation (SRE) corpora (4054 speakers).

| Dataset | DER (%) | ckpt |
| --- | --- | --- |
| Simu1spk | 0.6 | simu_avg_41_50epo.ckpt |
| Simu2spk | 4.3 | same as above |
| Simu3spk | 9.8 | same as above |
| Simu4spk | 14.7 | same as above |
| CH2spk | 10.0 | ch_avg_91_100epo.ckpt |
| CH3spk | 15.3 | same as above |
| CH4spk | 21.8 | same as above |
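
For readers unfamiliar with the metric, the following is a simplified sketch of a frame-level diarization error rate (DER): missed speech, false alarms, and speaker confusion are summed and divided by the total reference speech, under the speaker permutation that gives the fewest errors. It assumes binary per-frame labels with the same number of speaker columns in reference and hypothesis, and it is not the repository's metrics.py, whose scoring details (for example collars or thresholds) may differ. The name `frame_level_der` is hypothetical.

```python
from itertools import permutations


def frame_level_der(ref, hyp):
    """Simplified frame-level DER for binary label matrices.

    ref, hyp: lists of frames, each frame a list of 0/1 flags per speaker.
    The hypothesis speaker order is permuted to minimise the error count.
    """
    if len(ref) != len(hyp):
        raise ValueError("ref and hyp must have the same number of frames")
    total_ref = sum(sum(frame) for frame in ref)
    if total_ref == 0:
        raise ValueError("reference contains no speech")
    num_spk = len(ref[0])
    best_errors = None
    for perm in permutations(range(num_spk)):
        errors = 0
        for r, h in zip(ref, hyp):
            h_perm = [h[p] for p in perm]
            n_ref, n_sys = sum(r), sum(h_perm)
            n_correct = sum(1 for a, b in zip(r, h_perm) if a and b)
            # miss + false alarm + confusion == max(n_ref, n_sys) - n_correct
            errors += max(n_ref, n_sys) - n_correct
        if best_errors is None or errors < best_errors:
            best_errors = errors
    return best_errors / total_ref


if __name__ == "__main__":
    ref = [[1, 0], [1, 0], [0, 1], [0, 1]]
    hyp = [[0, 1], [0, 1], [1, 0], [0, 0]]  # speakers swapped, last frame missed
    print(frame_level_der(ref, hyp))
```

The exhaustive permutation search is only practical for a handful of speakers, which matches the 1 to 4 speaker conditions in the table.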

Each ckpt is the average of the model parameters over the last 10 epochs.
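
Parameter averaging of this kind can be sketched as an element-wise mean over state dicts. The helper below is a minimal illustration, not the script used to produce the released ckpts; `average_state_dicts` is a hypothetical name. It relies only on `+` and `/`, so it works on plain floats as in the demo and should also work on tensor values, although integer-valued entries would become floating point after the division.

```python
def average_state_dicts(state_dicts):
    """Return the element-wise mean of several state dicts with identical keys."""
    if not state_dicts:
        raise ValueError("need at least one state dict")
    keys = state_dicts[0].keys()
    for sd in state_dicts[1:]:
        if sd.keys() != keys:
            raise KeyError("all state dicts must share the same keys")
    n = len(state_dicts)
    # sum() starts from 0, which is compatible with floats and tensor-like values.
    return {k: sum(sd[k] for sd in state_dicts) / n for k in keys}


if __name__ == "__main__":
    # Toy "checkpoints" with scalar parameters.
    epochs = [{"w": 1.0, "b": 2.0}, {"w": 3.0, "b": 4.0}]
    print(average_state_dicts(epochs))
```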

To check the performance of a ckpt on CALLHOME, run:

python train_dia_fintn_ch.py --configs conf/spk_onl_tfm_enc_dec_nonautoreg_callhome_infer.yaml --gpus YOUR_DEVICE_ID, --test_from_folder YOUR_CKPT_SAVE_DIR

Note that this requires changing the following lines in train_dia_fintn_ch.py from

ckpts = [x for x in all_files if (".ckpt" in x) and ("epoch" in x) and int(x.split("=")[1].split("-")[0])>=configs["log"]["start_epoch"] and int(x.split("=")[1].split("-")[0])<=configs["log"]["end_epoch"]]

state_dict = torch.load(test_folder + "/" + c, map_location="cpu")["state_dict"]

to

ckpts = [x for x in all_files if (".ckpt" in x)]

state_dict = torch.load(test_folder + "/" + c, map_location="cpu")
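
The two variants above differ in which files are selected and in whether the loaded object wraps the weights under a "state_dict" key. As a hedged sketch, both cases could be handled by small helpers like the ones below. The names `select_checkpoints` and `extract_state_dict` and the example file names are hypothetical, and the helpers operate on plain Python objects rather than calling `torch.load`.

```python
import re

_EPOCH_RE = re.compile(r"epoch=(\d+)")


def select_checkpoints(filenames, start_epoch=None, end_epoch=None):
    """Pick .ckpt files; with epoch bounds, keep only 'epoch=N' names in range."""
    selected = []
    for name in filenames:
        if not name.endswith(".ckpt"):  # stricter than the original '".ckpt" in x'
            continue
        if start_epoch is None and end_epoch is None:
            selected.append(name)  # no bounds: accept every .ckpt file
            continue
        match = _EPOCH_RE.search(name)
        if match is None:
            continue  # e.g. an averaged ckpt without an epoch tag
        epoch = int(match.group(1))
        if (start_epoch is None or epoch >= start_epoch) and \
           (end_epoch is None or epoch <= end_epoch):
            selected.append(name)
    return selected


def extract_state_dict(checkpoint):
    """Return checkpoint['state_dict'] if present, else the checkpoint itself."""
    return checkpoint.get("state_dict", checkpoint)


if __name__ == "__main__":
    files = ["epoch=40-step=100.ckpt", "epoch=41-step=200.ckpt",
             "ch_avg_91_100epo.ckpt", "log.txt"]
    print(select_checkpoints(files, start_epoch=41, end_epoch=50))
    print(select_checkpoints(files))
```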

Reference code

Citation

If you want to cite this paper:

@misc{liang2023framewise,
      title={Frame-wise streaming end-to-end speaker diarization with non-autoregressive self-attention-based attractors}, 
      author={Di Liang and Nian Shao and Xiaofei Li},
      year={2023},
      eprint={2309.13916},
      archivePrefix={arXiv},
      primaryClass={eess.AS}
}