This repository contains the official PyTorch implementations for our paper:
- Yuchen Hu, Chen Chen, Ruizhe Li, Qiushi Zhu, Eng Siong Chng. "Noise-aware Speech Enhancement using Diffusion Probabilistic Model".
Our code is based on prior work SGMSE+.
- Create a new virtual environment with Python 3.8 (we have not tested other Python versions, but they may work).
- Install the package dependencies via `pip install -r requirements.txt`.
- If using W&B logging (default):
- Set up a wandb.ai account
- Log in via `wandb login` before running our code.
- If not using W&B logging:
- Pass the option `--no_wandb` to `train.py`.
- Your logs will then be stored as local TensorBoard logs. Run `tensorboard --logdir logs/` to see them.
- We release a pretrained checkpoint for the model trained on VoiceBank-DEMAND, as in the paper.
- We also provide testing samples before and after NASE processing for comparison.
Usage:
- For resuming training, you can use the `--resume_from_checkpoint` option of `train.py`.
- For evaluating these checkpoints, use the `--ckpt` option of `enhancement.py` (see the Evaluation section below).
Training is done by executing `train.py`. A minimal running example with default settings can be run with:

```bash
python train.py --base_dir <your_base_dir> --inject_type <inject_type> --pretrain_class_model <pretrained_beats>
```
where `your_base_dir` should be a path to a folder containing subdirectories `train/` and `valid/` (optionally `test/` as well). Each subdirectory must itself have two subdirectories `clean/` and `noisy/`, with the same filenames present in both. We currently only support training with `.wav` files.
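For reference, the required layout can be created with a few `mkdir` calls. The base-dir name `data` and the filename `utt_0001.wav` below are hypothetical examples; only the subdirectory names come from the text:

```shell
# Create the expected --base_dir layout; "data" is a hypothetical example name.
for split in train valid test; do          # test/ is optional
  mkdir -p "data/$split/clean" "data/$split/noisy"
done
# Paired files must share the same filename in clean/ and noisy/, e.g.:
#   data/train/clean/utt_0001.wav
#   data/train/noisy/utt_0001.wav
```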
`inject_type` should be chosen from `["addition", "concat", "cross-attention"]`.
`pretrained_beats` should be the path to the pre-trained BEATs checkpoint.
The full command is also included in `train.sh`.

To see all available training options, run `python train.py --help`.
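The `--inject_type` choices suggest three ways of fusing the noise embedding (e.g. from BEATs) with the enhancement model's features. As a rough illustrative sketch only, not the paper's actual implementation, the three fusion styles could look like this on toy NumPy tensors:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
T, D = 16, 8                        # time steps, feature dim (toy sizes)
feats = rng.normal(size=(T, D))     # speech features
noise_emb = rng.normal(size=(D,))   # noise-conditioning embedding (e.g. from BEATs)

# "addition": broadcast-add the embedding onto every frame
added = feats + noise_emb                                              # (T, D)

# "concat": append the embedding to each frame's feature vector
concat = np.concatenate([feats, np.tile(noise_emb, (T, 1))], axis=-1)  # (T, 2*D)

# "cross-attention": frames (queries) attend to the embedding token (key/value)
q, k, v = feats, noise_emb[None, :], noise_emb[None, :]  # a single key/value token
attn = softmax(q @ k.T / np.sqrt(D), axis=-1)            # (T, 1) attention weights
cross = attn @ v                                         # (T, D) attended output

print(added.shape, concat.shape, cross.shape)
```

With real models the fusion would happen inside the score network, typically with learned projections; this sketch only shows the tensor shapes each option implies.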
To evaluate on a test set, run

```bash
python enhancement.py --test_dir <your_test_dir> --enhanced_dir <your_enhanced_dir> --ckpt <path_to_model_checkpoint> --pretrain_class_model <pretrained_beats>
```

to generate the enhanced `.wav` files, and subsequently run

```bash
python calc_metrics.py --test_dir <your_test_dir> --enhanced_dir <your_enhanced_dir>
```

to calculate and output the instrumental metrics.
Both scripts should receive the same `--test_dir` and `--enhanced_dir` parameters.
The `--ckpt` parameter of `enhancement.py` should be the path to a trained model checkpoint, as stored by the logger in `logs/`.
The `--pretrain_class_model` parameter should be the path to the pre-trained BEATs checkpoint.
You may refer to our full commands included in `enhancement.sh` and `calc_metrics.sh`.
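As a side note on instrumental metrics: one widely used metric for speech enhancement is SI-SDR, computed from a reference and an estimate as below. Whether `calc_metrics.py` uses this exact definition is an assumption; this is only an illustration of what such a metric measures:

```python
import numpy as np

def si_sdr(ref, est):
    """Scale-invariant SDR in dB: project the estimate onto the reference,
    then compare the projected (target) energy to the residual energy."""
    ref = ref - ref.mean()
    est = est - est.mean()
    s_target = (est @ ref) / (ref @ ref) * ref  # scaled projection onto ref
    e_noise = est - s_target                     # everything not explained by ref
    return 10 * np.log10((s_target @ s_target) / (e_noise @ e_noise))

rng = np.random.default_rng(0)
clean = rng.normal(size=16000)                 # 1 s of toy "clean" speech at 16 kHz
noisy = clean + 0.1 * rng.normal(size=16000)   # add noise at roughly 20 dB
score = si_sdr(clean, noisy)
print(round(score, 1))                          # unprocessed baseline score in dB
```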
Please cite our paper in your publications if you use our research or code:
```bibtex
@inproceedings{hu2024noise,
  title={Noise-aware Speech Enhancement using Diffusion Probabilistic Model},
  author={Hu, Yuchen and Chen, Chen and Li, Ruizhe and Zhu, Qiushi and Chng, Eng Siong},
  booktitle={INTERSPEECH},
  year={2024}
}
```