NASE


Noise-aware Speech Enhancement using Diffusion Probabilistic Model

This repository contains the official PyTorch implementation of our paper.

Our code is based on the prior work SGMSE+.

Installation

  • Create a new virtual environment with Python 3.8 (we have not tested other Python versions, but they may work).
  • Install the package dependencies via pip install -r requirements.txt.
  • If using W&B logging (default):
    • Set up a wandb.ai account
    • Log in via wandb login before running our code.
  • If not using W&B logging:
    • Pass the option --no_wandb to train.py.
    • Your logs will be stored as local TensorBoard logs. Run tensorboard --logdir logs/ to see them.

Pretrained checkpoints

Usage:

  • For resuming training, you can use the --resume_from_checkpoint option of train.py.
  • For evaluating these checkpoints, use the --ckpt option of enhancement.py (see section Evaluation below).

Training

Training is done by executing train.py. A minimal running example with default settings can be run with:

python train.py --base_dir <your_base_dir> --inject_type <inject_type> --pretrain_class_model <pretrained_beats>

where your_base_dir should be a path to a folder containing the subdirectories train/ and valid/ (and optionally test/). Each subdirectory must itself contain two subdirectories, clean/ and noisy/, with the same filenames present in both. We currently only support training with .wav files. inject_type should be one of ["addition", "concat", "cross-attention"], and pretrained_beats should be the path to a pre-trained BEATs checkpoint.

The full command is also included in train.sh. To see all available training options, run python train.py --help.
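The three inject_type choices correspond to three standard ways of fusing a noise embedding into a network's features. The NumPy sketch below illustrates the general idea only; the tensor shapes, the mean pooling, and the random stand-in for a learned projection are our assumptions, not the repo's actual implementation:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def inject(features, noise_emb, inject_type, seed=0):
    """Fuse noise embeddings into feature frames.

    features:  (T, D) feature map from the enhancement network
    noise_emb: (S, D) frame-wise noise embeddings, e.g. from BEATs
    Returns a fused (T, D) feature map.
    """
    rng = np.random.default_rng(seed)
    T, D = features.shape
    if inject_type == "addition":
        # Pool the embeddings and broadcast-add them to every frame.
        return features + noise_emb.mean(axis=0)
    if inject_type == "concat":
        # Tile the pooled embedding, concatenate along channels, then
        # project back to D dims (random stand-in for a learned matrix).
        W = rng.standard_normal((2 * D, D)) / np.sqrt(2 * D)
        pooled = np.tile(noise_emb.mean(axis=0), (T, 1))
        return np.concatenate([features, pooled], axis=1) @ W
    if inject_type == "cross-attention":
        # Feature frames (queries) attend over the noise embeddings
        # (keys/values); a residual keeps the original features.
        attn = softmax(features @ noise_emb.T / np.sqrt(D), axis=-1)  # (T, S)
        return features + attn @ noise_emb
    raise ValueError(f"unknown inject_type: {inject_type}")
```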

Evaluation

To evaluate on a test set, run

python enhancement.py --test_dir <your_test_dir> --enhanced_dir <your_enhanced_dir> --ckpt <path_to_model_checkpoint> --pretrain_class_model <pretrained_beats>

to generate the enhanced .wav files, and subsequently run

python calc_metrics.py --test_dir <your_test_dir> --enhanced_dir <your_enhanced_dir>

to calculate and output the instrumental metrics.

Both scripts should receive the same --test_dir and --enhanced_dir parameters. The --ckpt parameter of enhancement.py should be the path to a trained model checkpoint, as stored by the logger in logs/. The --pretrain_class_model parameter should be the path to pre-trained BEATs.

You may refer to our full commands included in enhancement.sh and calc_metrics.sh.
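As an illustration of an instrumental metric of the kind calc_metrics.py reports, scale-invariant SDR (SI-SDR), commonly used in speech enhancement, can be computed for a clean/enhanced pair as follows. This is a generic sketch; we do not claim it matches the exact metric set or implementation in calc_metrics.py:

```python
import numpy as np

def si_sdr(reference, estimate, eps=1e-8):
    """Scale-invariant signal-to-distortion ratio in dB.

    reference: clean target waveform (1-D array)
    estimate:  enhanced waveform (same length)
    """
    reference = np.asarray(reference, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    # Project the estimate onto the reference to get the scaled target.
    alpha = np.dot(estimate, reference) / (np.dot(reference, reference) + eps)
    target = alpha * reference
    noise = estimate - target
    return 10 * np.log10((np.dot(target, target) + eps)
                         / (np.dot(noise, noise) + eps))
```

Because the target is rescaled by the projection, the metric is invariant to the overall gain of the estimate.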

Citations

If you use our research or code in your publication, please kindly cite our paper:

@inproceedings{hu2024noise,
  title={Noise-aware Speech Enhancement using Diffusion Probabilistic Model}, 
  author={Hu, Yuchen and Chen, Chen and Li, Ruizhe and Zhu, Qiushi and Chng, Eng Siong},
  booktitle={INTERSPEECH},
  year={2024}
}