Python Libraries

```
python==3.6.10
torch==1.4.0
kaldi-io==0.9.1
kaldi-python-io==1.0.4
```
- Install the Python libraries listed in Requirements.
- Install the Kaldi toolkit: https://github.com/kaldi-asr/kaldi/blob/master/INSTALL
- Download this repository. NOTE: The destination need not be inside the Kaldi installation.
- Set the `voxcelebDir` variable inside `pytorch_run.sh`.
Training features are expected in Kaldi nnet3 egs format, and are read using the `nnet3EgsDL` class defined in `train_utils.py`. The voxceleb recipe is provided in `pytorch_run.sh` to prepare them. Features for embedding extraction are expected in Kaldi matrix format, read using the `kaldi_io` library. Extracted embeddings are written in Kaldi vector format, similar to `xvector.ark`.
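As a rough illustration of this I/O, here is a minimal sketch using the `kaldi_io` API (`read_mat_scp`, `write_vec_flt`); the paths and the mean-pooled stand-in for the network forward pass are assumptions, not the repository's actual extraction code:

```python
# Minimal sketch, not the repository's extract.py: read Kaldi-format
# features and write embeddings in Kaldi vector format (as in xvector.ark).
# Paths and the mean-pooled "embedding" are illustrative assumptions.
import numpy as np
import kaldi_io

with open('embeddingDir/xvector.ark', 'wb') as f:
    for utt, feats in kaldi_io.read_mat_scp('featDir/feats.scp'):
        xvec = feats.mean(axis=0)  # placeholder for the model forward pass
        kaldi_io.write_vec_flt(f, xvec.astype(np.float32), key=utt)
```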
```
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 train_xent.py <egsDir>
```
```
usage: train_xent.py [-h] [--local_rank LOCAL_RANK] [-modelType MODELTYPE]
                     [-featDim FEATDIM] [-resumeTraining RESUMETRAINING]
                     [-resumeModelDir RESUMEMODELDIR]
                     [-numArchives NUMARCHIVES] [-numSpkrs NUMSPKRS]
                     [-logStepSize LOGSTEPSIZE] [-batchSize BATCHSIZE]
                     [-numEgsPerArk NUMEGSPERARK]
                     [-preFetchRatio PREFETCHRATIO]
                     [-optimMomentum OPTIMMOMENTUM] [-baseLR BASELR]
                     [-maxLR MAXLR] [-numEpochs NUMEPOCHS]
                     [-noiseEps NOISEEPS] [-pDropMax PDROPMAX]
                     [-stepFrac STEPFRAC]
                     egsDir

positional arguments:
  egsDir                Directory with training archives

optional arguments:
  -h, --help            show this help message and exit
  --local_rank LOCAL_RANK
  -modelType MODELTYPE  Refer train_utils.py
  -featDim FEATDIM      Frame-level feature dimension
  -resumeTraining RESUMETRAINING
                        (1) Resume training, or (0) Train from scratch
  -resumeModelDir RESUMEMODELDIR
                        Path containing training checkpoints
  -numArchives NUMARCHIVES
                        Number of egs.*.ark files
  -numSpkrs NUMSPKRS    Number of output labels
  -logStepSize LOGSTEPSIZE
                        Iterations per log
  -batchSize BATCHSIZE  Batch size
  -numEgsPerArk NUMEGSPERARK
                        Number of training examples per egs file
  -preFetchRatio PREFETCHRATIO
                        xbatchSize to fetch from dataloader
  -optimMomentum OPTIMMOMENTUM
                        Optimizer momentum
  -baseLR BASELR        Initial LR
  -maxLR MAXLR          Maximum LR
  -numEpochs NUMEPOCHS  Number of training epochs
  -noiseEps NOISEEPS    Noise strength before pooling
  -pDropMax PDROPMAX    Maximum dropout probability
  -stepFrac STEPFRAC    Training iteration when dropout = pDropMax
```

`egsDir` contains the nnet3 egs files.
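For reference, here is a hedged sketch of feeding one egs archive to PyTorch through the `nnet3EgsDL` class mentioned above; the constructor argument, archive path, and batch structure are assumptions, so check `train_utils.py` for the actual definition:

```python
# Hedged sketch: iterating a single nnet3 egs archive via nnet3EgsDL.
# The constructor argument and batch layout are assumptions; the real
# loader is defined in train_utils.py.
from torch.utils.data import DataLoader
from train_utils import nnet3EgsDL

dataset = nnet3EgsDL('egs_dir/egs.1.ark')  # one of the numArchives files
loader = DataLoader(dataset, batch_size=32)
for batch in loader:
    pass  # each batch packs feature chunks and their speaker labels
```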
```
usage: extract.py [-h] [-modelType MODELTYPE] [-numSpkrs NUMSPKRS]
                  modelDirectory featDir embeddingDir

positional arguments:
  modelDirectory        Directory containing the model checkpoints
  featDir               Directory containing features ready for extraction
  embeddingDir          Output directory

optional arguments:
  -h, --help            show this help message and exit
  -modelType MODELTYPE  Refer train_utils.py
  -numSpkrs NUMSPKRS    Number of output labels for model
```
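Once extraction completes, the embeddings can be sanity-checked with `kaldi_io`; the ark filename below is an assumption about what `embeddingDir` contains:

```python
# Quick sanity check of extracted embeddings (ark path is illustrative).
import kaldi_io

for utt, xvec in kaldi_io.read_vec_flt_ark('embeddingDir/xvector.ark'):
    print(utt, xvec.shape)  # each xvec is a 1-D numpy array
```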
The script `pytorch_run.sh` can be used to train embeddings end-to-end on the voxceleb recipe.

To reproduce the voxceleb EER results with the pretrained model, follow the steps below. NOTE: The voxceleb features must be prepared using `prepare_feats_for_egs.sh` prior to evaluation.
- Download the model.
- Extract `models/` and `xvectors/` into the installation directory.
- Set the following variables in `pytorch_run.sh`:

  ```
  modelDir=models/xvec_preTrained
  trainFeatDir=data/train_combined_no_sil
  trainXvecDir=xvectors/xvec_preTrained/train
  testFeatDir=data/voxceleb1_test_no_sil
  testXvecDir=xvectors/xvec_preTrained/test
  ```
- Extract embeddings and compute EER and minDCF: set `stage=7` in `pytorch_run.sh` and execute `bash pytorch_run.sh`.
- Alternatively, a pretrained PLDA model is available inside the `xvectors/train` directory: set `stage=9` in `pytorch_run.sh` and execute `bash pytorch_run.sh`.
Place the audio files to diarize and their corresponding RTTM files in the `demo_wav/` and `demo_rttm/` directories. Execute:

```
bash diarize.sh
```
| EER (%) | Kaldi | pytorch_xvectors |
|---|---|---|
| Vox1-test | 3.13 | 2.82 |
| VOICES-dev | 10.30 | 8.59 |
NOTE: Clustering using https://github.com/tango4j/Auto-Tuning-Spectral-Clustering
| DER (%) | Kaldi | pytorch_xvectors |
|---|---|---|
| DIHARD2 dev (no collar, oracle #spk) | 26.97 | 27.50 |
| DIHARD2 dev (no collar, est #spk) | 24.49 | 24.66 |
| AMI dev+test (26 meetings, collar, oracle #spk) | 6.39 | 6.30 |
| AMI dev+test (26 meetings, collar, est #spk) | 7.29 | 10.14 |