This GitHub repository contains code related to the paper "Maximum Discrepancy Generative Regularization and Non-Negative Matrix Factorization for Single Channel Source Separation" written by Martin Ludvigsen and Markus Grasmair. The paper proposes a new approach to training NMF for single channel source separation (SCSS), using ideas from recent work on adversarial regularization functions. The code in this repository implements the proposed approach and can be used to reproduce the results presented in the paper, as well as to explore variations of the method and apply it to other datasets.
The main goal, and the main novelty, is to find a non-negative basis that represents the true data well while representing adversarial data poorly.
In other words, we want a non-negative $m \times d$ basis $W$ such that

$$U \approx WH \quad \text{and} \quad \hat{U} \not\approx W\hat{H},$$

where
- $U$ is an $m \times N$ matrix containing the true data stored column-wise.
- $\hat{U}$ is an $m \times \hat{N}$ matrix containing the adversarial data stored column-wise.
- $H$ is a $d \times N$ matrix containing the true weights.
- $\hat{H}$ is a $d \times \hat{N}$ matrix containing the adversarial weights.
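MDNMF is fitted by solving

$$\min_{W \geq 0} \; \frac{1}{N}\|U - WH\|_F^2 - \frac{\tau_A}{\hat{N}}\|\hat{U} - W\hat{H}\|_F^2,$$

where the weights are the best non-negative representations of the data in the basis $W$:

$$H = \operatorname{arg\,min}_{H' \geq 0} \|U - WH'\|_F^2, \qquad \hat{H} = \operatorname{arg\,min}_{\hat{H}' \geq 0} \|\hat{U} - W\hat{H}'\|_F^2.$$

The parameter $\tau_A \geq 0$ controls the weight of the adversarial term; with $\tau_A = 0$ the problem reduces to standard NMF.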
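To make the training objective concrete, the following sketch evaluates it for a fixed basis `W`, taking it as the mean squared fit error on the true data minus `tau_A` times the mean squared fit error on the adversarial data. It assumes the inner weights are exact non-negative least-squares solutions computed column by column with `scipy.optimize.nnls` (SciPy is not among the repository's listed dependencies). The helpers `best_weights` and `mdnmf_objective` are hypothetical and not part of the repository.

```python
import numpy as np
from scipy.optimize import nnls


def best_weights(W, U):
    """Non-negative weights H minimizing ||U - W H||_F, solved column by column."""
    # nnls(A, b) returns (x, residual_norm); only x is kept.
    return np.column_stack([nnls(W, u)[0] for u in U.T])


def mdnmf_objective(W, U, U_hat, tau_A):
    """Hypothetical helper: mean squared fit error on the true data
    minus tau_A times the mean squared fit error on the adversarial data."""
    H = best_weights(W, U)
    H_hat = best_weights(W, U_hat)
    true_err = np.linalg.norm(U - W @ H) ** 2 / U.shape[1]
    adv_err = np.linalg.norm(U_hat - W @ H_hat) ** 2 / U_hat.shape[1]
    return true_err - tau_A * adv_err, true_err, adv_err


rng = np.random.default_rng(0)
m, d, N, N_hat = 20, 4, 50, 50
W = rng.random((m, d))
U = W @ rng.random((d, N))        # true data lies exactly in the cone spanned by W
U_hat = rng.random((m, N_hat))    # adversarial data: generic non-negative noise
obj, true_err, adv_err = mdnmf_objective(W, U, U_hat, tau_A=0.1)
print(f"true error {true_err:.2e}, adversarial error {adv_err:.2e}, objective {obj:.3f}")
```

A basis that fits the true data closely and the adversarial data badly yields a small (here negative) objective, which is exactly what the minimization over `W` rewards.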
The method is primarily intended for single channel source separation, but it can be applied to any inverse problem where the true signals are reasonably well represented by non-negative bases.
The code is implemented in Python, and the dependencies are the packages NumPy, Librosa and Pandas.
The datasets used in numerical experiments can be obtained as follows:
- The MNIST dataset is loaded through the Python package TensorFlow/Keras.
- The LibriSpeech dataset can be obtained here: https://www.openslr.org/12. We only use the dev-clean part of the dataset.
- The Musan dataset can be obtained here: https://www.openslr.org/17/. We only use the noise part of the dataset.
The interface and usage are similar to the scikit-learn implementation of NMF and related methods.
There are two main classes:
- `NMF`, which handles fitting of NMF bases for a single source.
- `NMF_separation`, which handles fitting and separation for a specific source separation problem using NMF.
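A common way to use NMF for separation is to fit one basis per source, represent the mixture in the concatenated basis `[W1, W2]`, and read each source estimate off its own block of weights, often refined with a Wiener-style mask. The NumPy sketch below illustrates that generic idea only; the function names `nonneg_weights` and `separate`, the multiplicative-update inner solver and the masking step are assumptions, not necessarily what `NMF_separation` does.

```python
import numpy as np


def nonneg_weights(W, V, n_iter=500, eps=1e-12):
    """Approximate argmin_{H >= 0} ||V - W H||_F^2 with multiplicative updates."""
    H = np.ones((W.shape[1], V.shape[1]))
    WtV, WtW = W.T @ V, W.T @ W
    for _ in range(n_iter):
        # Lee-Seung update: every factor is non-negative, so H stays non-negative.
        H *= WtV / (WtW @ H + eps)
    return H


def separate(V, W1, W2, eps=1e-12):
    """Hypothetical helper: split a non-negative mixture V into two source estimates."""
    W = np.hstack([W1, W2])
    H = nonneg_weights(W, V)
    d1 = W1.shape[1]
    S1 = W1 @ H[:d1]   # part of the mixture explained by source 1's basis
    S2 = W2 @ H[d1:]   # part of the mixture explained by source 2's basis
    # Wiener-style mask: share each entry of V in proportion to the two estimates.
    mask1 = S1 / (S1 + S2 + eps)
    return mask1 * V, (1.0 - mask1) * V


rng = np.random.default_rng(0)
W1, W2 = rng.random((30, 3)), rng.random((30, 3))
V = W1 @ rng.random((3, 40)) + W2 @ rng.random((3, 40))   # synthetic mixture
S1_est, S2_est = separate(V, W1, W2)
```

Because the two masks sum to one, the estimates always add up to the mixture; the quality of the split depends on how well each basis represents its own source and how poorly it represents the other, which is what adversarial training targets.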
For example, to fit a standard NMF to data stored column-wise in `U`:
```python
d = 32  # Number of basis vectors
# 100 epochs with batch size 500
nmf = NMF(d=d, batch_size=500, epochs=100)
# Standard fitting
nmf.fit_std(U)
```
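For readers new to standard NMF fitting, the following full-batch sketch fits `W` and `H` with the classical Lee-Seung multiplicative updates for the Frobenius loss. It is only an illustration: the repository's `fit_std` works with mini-batches and epochs (`batch_size`, `epochs`), and its exact update rules may differ. The function name `fit_nmf` is hypothetical.

```python
import numpy as np


def fit_nmf(U, d, n_iter=300, eps=1e-12, seed=0):
    """Fit U ~ W H with W, H >= 0 using Lee-Seung multiplicative updates."""
    rng = np.random.default_rng(seed)
    m, N = U.shape
    W = rng.random((m, d))
    H = rng.random((d, N))
    for _ in range(n_iter):
        # Each update does not increase ||U - W H||_F^2 and preserves non-negativity.
        H *= (W.T @ U) / (W.T @ W @ H + eps)
        W *= (U @ H.T) / (W @ H @ H.T + eps)
    return W, H


rng = np.random.default_rng(1)
U = rng.random((64, 8)) @ rng.random((8, 200))   # synthetic non-negative rank-8 data
W, H = fit_nmf(U, d=8)
rel_err = np.linalg.norm(U - W @ H) / np.linalg.norm(U)
print(f"relative reconstruction error: {rel_err:.3f}")
```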
We can then extract the basis/dictionary `W` and the weights `H`, and reconstruct the data:
```python
import numpy as np

W = nmf.W  # basis/dictionary, shape (m, d)
H = nmf.H  # weights, shape (d, N)
U_reconstructed = np.dot(W, H)
```
Alternatively, we can recompute the weights `H` for given data using the fitted basis:
```python
H = nmf.transform(U)
```
For adversarial fitting, with the adversarial data stored column-wise in `U_hat`, we can do:
```python
nmf = NMF(d=d, batch_size=500, epochs=100, prob="adv", tau_A=0.1)
nmf.fit_adv(U, U_hat)
```
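Conceptually, adversarial fitting can be pictured as alternating between (i) computing the best non-negative weights for the true and adversarial data under the current basis and (ii) a projected gradient step on `W` for the loss (1/N)·‖U − WH‖²_F − (τ_A/N̂)·‖Û − WĤ‖²_F. The sketch below implements one such step in NumPy under several assumptions of our own: an approximate multiplicative-update inner solver, a fixed step size `lr`, and column normalization of `W` to remove the scaling ambiguity between the basis and the weights. The name `adversarial_step` is hypothetical; this is not the repository's `fit_adv`.

```python
import numpy as np


def nonneg_weights(W, V, n_iter=200, eps=1e-12):
    """Approximate argmin_{H >= 0} ||V - W H||_F^2 with multiplicative updates."""
    H = np.ones((W.shape[1], V.shape[1]))
    WtV, WtW = W.T @ V, W.T @ W
    for _ in range(n_iter):
        H *= WtV / (WtW @ H + eps)
    return H


def adversarial_step(W, U, U_hat, tau_A, lr=1e-3):
    """One projected gradient step on
    (1/N)||U - W H||_F^2 - (tau_A / N_hat)||U_hat - W H_hat||_F^2,
    treating H and H_hat as fixed at their current best non-negative fits."""
    H = nonneg_weights(W, U)
    H_hat = nonneg_weights(W, U_hat)
    N, N_hat = U.shape[1], U_hat.shape[1]
    grad = (-2.0 / N) * (U - W @ H) @ H.T \
        + (2.0 * tau_A / N_hat) * (U_hat - W @ H_hat) @ H_hat.T
    W = np.maximum(W - lr * grad, 0.0)                 # project onto W >= 0
    norms = np.linalg.norm(W, axis=0, keepdims=True)   # assumed: unit-norm columns
    return W / np.maximum(norms, 1e-12)


rng = np.random.default_rng(0)
m, d = 20, 5
W = rng.random((m, d))
W /= np.linalg.norm(W, axis=0, keepdims=True)
U = rng.random((m, 5)) @ rng.random((5, 100))   # synthetic "true" data
U_hat = rng.random((m, 100))                    # synthetic "adversarial" data
for _ in range(50):
    W = adversarial_step(W, U, U_hat, tau_A=0.1)
```

The plus sign on the adversarial gradient term pushes `W` away from directions that explain `U_hat`, while the first term pulls it toward directions that explain `U`.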
Coming soon.
Coming soon.