Implementation of Speech Enhancement GAN (SEGAN) with NNabla
README Japanese Ver. (the Japanese version is available here) -> Link
Original Paper
SEGAN: Speech Enhancement Generative Adversarial Network
https://arxiv.org/abs/1703.09452
Requirements
- Python 3.6
- CUDA 10.0 & cuDNN 7.6
- Please choose the appropriate CUDA and cuDNN versions to match your NNabla version
Please install the following packages with pip. (If necessary, install the latest pip first.)
- nnabla (over v1.0.19)
- nnabla-ext-cuda (over v1.0.19)
- scipy
- numba
- joblib
- PyQt5
- pyqtgraph (after installing PyQt5)
- pypesq (see "install with pip" on the official site)
This repository consists of the following files:
- `segan.py` : the main source code. Run this.
- `data.py` : creates the batch data. Before running, please download the wav dataset as described below.
- `settings.py` : contains the setting parameters.
- `display.py` : contains functions for displaying the results.
- Download `segan.py`, `settings.py`, `data.py`, `display.py` and save them into the same directory.
- In that directory, make three folders: `data`, `pkl`, and `params` (a minimal folder-creation sketch is shown after this list).
  - `data` folder : stores the wav data.
  - `pkl` folder : stores the pickled database ("~.pkl").
  - `params` folder : stores the parameters, including the network models.
- Download the following 4 datasets and unzip them.
- Move those 4 unzipped folders into the `data` folder.
- Convert the sampling frequency of all the wav data to 16 kHz. For example, this site is useful. After converting, you can delete the original wav data. (A scipy-based resampling sketch is also shown after this list.)
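As mentioned above, here is a minimal sketch for creating the three working folders next to `segan.py`; the folder names come from this README, and creating them by hand works just as well.

```python
import os

# Create the three working folders used by segan.py:
# wav data, pickled datasets, and saved network parameters.
for folder in ("data", "pkl", "params"):
    os.makedirs(folder, exist_ok=True)
```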
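For the 16 kHz conversion, the following is a rough alternative sketch using scipy (already listed in the required packages); it assumes 16-bit PCM wav files and is not the converter site the instructions above refer to.

```python
from scipy.io import wavfile
from scipy.signal import resample_poly

TARGET_SR = 16000  # SEGAN works on 16 kHz speech

def convert_to_16k(src_path, dst_path):
    sr, x = wavfile.read(src_path)           # original sampling rate and samples
    if sr != TARGET_SR:
        x = resample_poly(x, TARGET_SR, sr)   # rational resampling to 16 kHz
        x = x.astype('int16')                 # keep 16-bit PCM after filtering
    wavfile.write(dst_path, TARGET_SR, x)
```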
`settings.py` is a parameter list that includes the setting parameters for training & predicting.
Refer to the following if you want to know how to use the special parameters.
- `self.epoch_from` : Number of the starting epoch when retraining. If `self.epoch_from` > 0, training restarts after loading the pre-trained models "discriminator_param_xxxx.h5" and "generator_param_xxxx.h5"; the value of `self.epoch_from` should correspond to "xxxx". If `self.epoch_from` = 0, no pre-trained models are loaded.
- `self.model_save_cycle` : Cycle of epochs for saving the network models. If "1", the network models are saved every epoch.
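For illustration, here is a hypothetical excerpt of how these two attributes might look inside `settings.py`; only the two attribute names come from this README, the surrounding class is an assumption.

```python
# Hypothetical excerpt of settings.py (class name and structure assumed)
class Settings:
    def __init__(self):
        # 0 = train from scratch; >0 = resume from the models saved at that epoch
        self.epoch_from = 0
        # Save generator/discriminator models every N epochs
        self.model_save_cycle = 1
```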
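And a rough sketch of the resume step that `self.epoch_from` implies. The actual loading code lives in `segan.py`; the zero-padding of "xxxx" and the `params/` path used here are assumptions.

```python
import nnabla as nn

epoch_from = 100  # must match the "xxxx" part of the saved file names
if epoch_from > 0:
    # Load the pre-trained generator and discriminator before resuming training
    nn.load_parameters('params/generator_param_%04d.h5' % epoch_from)
    nn.load_parameters('params/discriminator_param_%04d.h5' % epoch_from)
```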
If you are facing a GPU out-of-memory error, please try the half-precision floating-point mode, which lowers the calculation precision and thus reduces the memory usage. To use it, run the following commands before defining the network.
from nnabla.ext_utils import get_extension_context
import nnabla as nn
ctx = get_extension_context('cudnn', device_id=args.device_id, type_config='half')
nn.set_default_context(ctx)
In `segan.py`, this mode is enabled by default.
Refer to "nnabla-ext-cuda" for more information.
- For training, set `Train=1` in the main function of `segan.py`. For prediction, set `Train=0`.
Train = 0
if Train:
    # Training
    nn.set_default_context(ctx)
    train(args)
else:
    # Test
    #nn.set_default_context(ctx)
    test(args)
    pesq_score('clean.wav', 'output_segan.wav')
- Run `segan.py`.
  If you run the `train(args)` function, the training dataset (xxxx.pkl) is created in the `pkl` folder at the beginning (only the first time), and the network models (xxxx.h5) are saved in the `params` folder at every cycle set by `self.model_save_cycle`.
  If you run the `test(args)` function, the test dataset (xxxx.pkl) is created in the `pkl` folder at the beginning (only the first time), and the following wav files are generated as the results. The PESQ score is also displayed (a pypesq sketch is shown after this list).
- clean.wav : clean speech wav file
- noisy.wav : noisy speech wav file
- output.wav : reconstructed speech wav file
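For reference, here is a rough sketch of computing the PESQ score directly with pypesq, in case you want to check it outside of `segan.py`; the repository's own `pesq_score()` helper may be implemented differently, and the `pesq(ref, deg, fs)` signature assumed here is the one from the pypesq package listed above.

```python
from scipy.io import wavfile
from pypesq import pesq

fs, ref = wavfile.read('clean.wav')        # clean reference speech
_, deg = wavfile.read('output_segan.wav')  # enhanced speech produced by SEGAN
print('PESQ:', pesq(ref, deg, fs))         # fs should be 16000 Hz here
```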