SDR loss sensitive to nan
sevagh opened this issue · 4 comments
Hello,
I'm trying to use SI-SDR and/or SD-SDR loss for a model for music source separation.
I'm working with the well-known open-unmix model (https://github.com/sigsep/open-unmix-pytorch).
The original model is as follows:
- Xmag = magnitude spectrogram of input waveform (mixed song)
- Ymag_hat = network prediction of magnitude spectrogram of the source being separated (one of drums, bass, vocals, other)
- Ymag = magnitude spectrogram of ground truth of the source being separated
- Loss = MSE(Ymag_hat, Ymag)
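To make the setup concrete, here is a rough sketch of the original training objective (hypothetical shapes and STFT parameters, not the actual open-unmix code; `Ymag_hat` is just a stand-in for the network output):

```python
import torch
import torch.nn.functional as F

n_fft, hop = 4096, 1024          # assumed STFT parameters
window = torch.hann_window(n_fft)
x = torch.randn(2, 44100)        # mixture waveform (stereo channels stand in for a batch)
y = torch.randn(2, 44100)        # ground-truth source waveform

# Magnitude spectrograms of mixture and ground-truth source
Xmag = torch.stft(x, n_fft, hop_length=hop, window=window, return_complex=True).abs()
Ymag = torch.stft(y, n_fft, hop_length=hop, window=window, return_complex=True).abs()

Ymag_hat = Xmag                  # placeholder for the network prediction
loss = F.mse_loss(Ymag_hat, Ymag)
```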
I wanted to see whether there would be any difference from using SDR (which is the actual evaluation metric of the full source separation task) inside the training loop:
- X = complex spectrogram of input waveform (mixed song)
- Xmag = magnitude(X)
- Xphase = phase(X)
- Ymag_hat = network prediction of magnitude spectrogram of source being separated
- Ycomplex_hat = Ymag_hat * Xphase (combine source magnitude + mix phase for source complex spectrogram)
- y_hat = istft(Ycomplex_hat)
- Loss = auraloss.SISDR(y_hat, y), loss on SDR of waveforms
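The modified pipeline above can be sketched end to end like this (again with hypothetical shapes and parameters, and a hand-rolled negative SI-SDR in place of auraloss, just to show the data flow):

```python
import torch

n_fft, hop = 4096, 1024
window = torch.hann_window(n_fft)
x = torch.randn(4, 44100)                       # mixture waveforms
y = torch.randn(4, 44100)                       # ground-truth source waveforms

X = torch.stft(x, n_fft, hop_length=hop, window=window, return_complex=True)
Xmag, Xphase = X.abs(), torch.angle(X)

Ymag_hat = Xmag                                  # placeholder for the network prediction
Ycomplex_hat = torch.polar(Ymag_hat, Xphase)     # source magnitude + mix phase
y_hat = torch.istft(Ycomplex_hat, n_fft, hop_length=hop,
                    window=window, length=x.shape[-1])

def neg_si_sdr(est, ref, eps=1e-8):
    # scale-invariant projection of the estimate onto the reference
    alpha = (est * ref).sum(-1, keepdim=True) / (ref.pow(2).sum(-1, keepdim=True) + eps)
    target = alpha * ref
    noise = est - target
    ratio = target.pow(2).sum(-1) / (noise.pow(2).sum(-1) + eps)
    return -10 * torch.log10(ratio + eps)        # negated so lower is better

loss = neg_si_sdr(y_hat, y).mean()
```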
In the first iteration, the network's prediction is bad enough that the SDR value is NaN in a few places. After this, the gradients become corrupted and the network's next prediction is entirely NaN.
Here are some print statements I inserted into the body of the SI-SDR code to help pinpoint the issue. What gets printed is:
- Input tensor (waveform)
- Output tensor (waveform from the neural network's predicted spectrogram)
- SI-SDR loss values (each intermediate step before the final value)
The output below covers the first two training iterations, showing how the SI-SDR loss goes from 2 NaNs to all NaNs.
(umx-gpu) sevagh:umx-mr $ ./train.sh
Using GPU: True
Training Epoch:   0%|          | 0/1000 [00:00<?, ?it/s]
y_hat, y shape: torch.Size([16, 2, 264600]), torch.Size([16, 2, 264600])   | 0/344 [00:00<?, ?it/s]
tensor([[[-2.1781e-03, -6.5221e-04, 1.1178e-03, ..., 1.3619e-03,
2.1249e-03, 1.0944e-02],
[-1.2133e-02, -8.7155e-03, -6.3351e-03, ..., 1.4447e-02,
2.1405e-02, 2.3420e-02]],
[[ 2.0096e-01, 1.8979e-01, 2.0478e-01, ..., -1.1957e-02,
-4.6632e-03, -1.6420e-03],
[ 1.4364e-01, 1.3714e-01, 1.5252e-01, ..., -2.3354e-02,
-7.9426e-03, 4.1423e-03]],
[[-1.7183e-02, -9.4009e-03, 9.0017e-05, ..., -4.4771e-02,
-4.8189e-02, -5.0478e-02],
[-1.9377e-02, -1.3426e-02, 1.5882e-03, ..., -2.9326e-02,
-2.8929e-02, -2.5298e-02]],
...,
[[-3.7857e-03, -4.0909e-03, -3.7857e-03, ..., -9.3050e-05,
3.6471e-04, 7.9196e-04],
[-1.6300e-03, -1.5385e-03, -1.9657e-03, ..., 1.0948e-04,
-5.0087e-04, -1.6518e-04]],
[[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]],
[[ 1.0614e-01, 9.7930e-02, 8.7920e-02, ..., -1.0565e-01,
-4.4679e-02, 3.5308e-02],
[ 1.0367e-01, 9.5126e-02, 8.4048e-02, ..., -1.2213e-01,
-6.5000e-02, 1.4834e-02]]], device='cuda:0')
tensor([[[ 0.1386, 0.1427, 0.1455, ..., 0.1328, 0.0626, 0.0005],
[ 0.0152, 0.0212, 0.0245, ..., 0.1702, 0.1278, 0.0885]],
[[ 0.3065, 0.2870, 0.2915, ..., -0.1038, -0.0997, -0.0963],
[ 0.1356, 0.1494, 0.1805, ..., -0.1743, -0.1608, -0.1511]],
[[ 0.0508, 0.0478, 0.0527, ..., 0.2344, 0.2701, 0.2726],
[-0.1404, -0.1422, -0.1259, ..., 0.3193, 0.3360, 0.3355]],
...,
[[-0.0528, -0.0437, -0.0349, ..., -0.1331, -0.1366, -0.1403],
[-0.2515, -0.2280, -0.2245, ..., 0.0535, 0.0505, 0.0503]],
[[-0.0694, -0.0605, -0.0499, ..., -0.4447, -0.4446, -0.4433],
[ 0.0713, 0.0818, 0.0828, ..., -0.0331, -0.0137, 0.0073]],
[[ 0.3649, 0.3720, 0.3841, ..., -0.1016, -0.0735, -0.0239],
[ 0.3345, 0.3408, 0.3497, ..., -0.1535, -0.0997, -0.0270]]],
device='cuda:0', grad_fn=<SubBackward0>)
losses: tensor([[-13.2343, -12.9311],
[ 1.0036, -0.8949],
[ -6.3418, -7.2471],
[-53.9541, -59.8263],
[ 3.6331, 1.7483],
[-21.9817, -20.3990],
[ nan, nan],
[ -6.9908, -7.9694],
[ -2.5666, -4.8039],
[-35.9928, -35.8143],
[-18.4182, -16.2987],
[ -4.5252, -9.3269],
[ -4.0607, -5.2343],
[ -3.2994, -1.3853],
[ nan, nan],
[ -2.2624, 0.1392]], device='cuda:0', grad_fn=<MulBackward0>)
losses: -10.913591384887695
loss: 10.913591384887695
y_hat, y shape: torch.Size([16, 2, 264600]), torch.Size([16, 2, 264600]) | 1/344 [00:02<14:28, 2.53s/it]
tensor([[[ 0.0071, -0.0640, -0.0851, ..., -0.0198, -0.0260, -0.0203],
[ 0.0515, -0.0161, -0.0546, ..., 0.0040, 0.0071, 0.0099]],
[[ 0.0042, 0.0053, 0.0014, ..., 0.0006, 0.0010, 0.0017],
[-0.0004, -0.0029, -0.0021, ..., -0.0024, -0.0026, -0.0041]],
[[-0.0011, -0.0011, -0.0011, ..., 0.0019, 0.0021, 0.0017],
[-0.0004, -0.0005, -0.0005, ..., 0.0010, 0.0012, 0.0012]],
...,
[[ 0.1997, 0.1637, 0.1673, ..., -0.0564, -0.0582, -0.0581],
[ 0.2007, 0.1652, 0.1685, ..., -0.0525, -0.0540, -0.0547]],
[[ 0.0412, 0.0584, 0.0503, ..., -0.0031, 0.0207, 0.0357],
[ 0.0418, 0.0590, 0.0510, ..., 0.0010, 0.0250, 0.0392]],
[[ 0.0338, 0.0072, -0.0184, ..., -0.0112, -0.0148, -0.0131],
[ 0.0114, 0.0253, 0.0207, ..., -0.0008, -0.0031, -0.0061]]],
device='cuda:0')
tensor([[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
...,
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]],
[[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]]], device='cuda:0',
grad_fn=<SubBackward0>)
losses: tensor([[nan, nan],
[nan, nan],
[nan, nan],
[nan, nan],
[nan, nan],
[nan, nan],
[nan, nan],
[nan, nan],
[nan, nan],
[nan, nan],
[nan, nan],
[nan, nan],
[nan, nan],
[nan, nan],
[nan, nan],
[nan, nan]], device='cuda:0', grad_fn=<MulBackward0>)
losses: nan
loss: nan
Do you have any suggestions? I tried "torch.nan_to_num" (without any arguments, i.e. using the default substitutions: https://pytorch.org/docs/stable/generated/torch.nan_to_num.html), and the behavior is essentially identical (except the loss is 0 instead of NaN).
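For what it's worth, I think I see why nan_to_num on the loss alone doesn't help: the forward value gets replaced, but the gradients flowing back through the operation that produced the NaN can still be NaN. A tiny sketch of my hypothesis (not auraloss code):

```python
import torch

x = torch.tensor(0.0, requires_grad=True)
z = x / x                  # 0/0 -> NaN in the forward pass
w = torch.nan_to_num(z)    # forward value is now 0.0 ...
w.backward()
print(x.grad)              # ... but the gradient is still NaN
```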
Hey, thanks for checking out auraloss!
I have seen this issue before and traced it to a numerical instability in the computation of SI-SDR. My guess is that it is triggered when the target contains all zeros. Here is a minimal code example to demonstrate that:

```python
import torch
import auraloss

x = torch.rand(4, 2, 100)
y = torch.zeros(4, 2, 100)

sisdr_loss = auraloss.time.SISDRLoss()
print(sisdr_loss(x, y))  # nan
print(sisdr_loss(y, x))  # 80 dB
```
When the target y is a vector of zeros, we get a NaN for the loss. We would expect the loss to be symmetric, so setting either x or y as the target should produce the same 80 dB error (the error is bounded at 80 dB due to the choice of eps=1e-8).
Ideally you would never have a vector of all zeros during training, but the code ought to handle this situation gracefully. I have made a quick fix for this. You can try it by installing the latest auraloss directly from the GitHub repo:
pip install git+https://github.com/csteinmetz1/auraloss.git
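In case the context is useful: the usual guard is to put the eps inside the projection and both energy terms, so an all-zero target degrades to a large negative (but finite) SI-SDR instead of NaN. A standalone sketch of that idea (not the actual auraloss patch):

```python
import torch

def si_sdr_db(est, ref, eps=1e-8):
    # eps in the projection and in both energy terms keeps an all-zero ref finite
    alpha = (est * ref).sum(-1, keepdim=True) / (ref.pow(2).sum(-1, keepdim=True) + eps)
    target = alpha * ref
    noise = est - target
    return 10 * torch.log10(
        (target.pow(2).sum(-1) + eps) / (noise.pow(2).sum(-1) + eps)
    )

x = torch.rand(4, 2, 100)
y = torch.zeros(4, 2, 100)
print(si_sdr_db(x, y))  # large negative but finite, no NaN
```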
Let me know if this works for you. I will push a new release shortly afterwards.
Thanks for the reply - I'll be able to test this within a few days (waiting for a replacement GPU, currently).
I'm definitely having a better time with it, @csteinmetz1: I installed the latest from GitHub as you suggested.
Great! Thanks for testing it out.