Gives NaN results when using the PyTorch version for some inputs
quancs opened this issue · 4 comments
Hello, I found that fast_bss_eval (version 0.1.0) sometimes gives NaN results.
The test code:
import numpy as np
import torch
from mir_eval.separation import bss_eval_sources
import fast_bss_eval
x = np.load('debug.npz')
preds = torch.tensor(x['preds'])
target = torch.tensor(x['target'])
print(preds.shape, target.shape)
sdr, sir, sar, perm = fast_bss_eval.bss_eval_sources(target, preds)
print(sdr)
sdr, sir, sar, perm = fast_bss_eval.bss_eval_sources(target.numpy(), preds.numpy())
print(sdr)
sdr,_,_,_ = bss_eval_sources(target.numpy(), preds.numpy(), False)
print(sdr)
The results:
torch.Size([2, 64000]) torch.Size([2, 64000])
tensor([-2.6815, nan])
[-2.6815615 44.575493 ]
[-2.68156071 44.58523729]
The data and the debug code are zipped in debug.zip.
One thing that could explain it is that mir_eval uses float64 while torch uses float32 by default. Your SIR is 44 dB, which is very high, so the denominator may become zero with float32.
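To illustrate how a denominator can vanish in float32 but not in float64, here is a toy sketch. It is not the formula used inside fast_bss_eval; the function name `ratio_db` and the input value are assumptions chosen only to show the rounding effect.

```python
import numpy as np

def ratio_db(coherence_sq, dtype):
    """Toy ratio 10*log10(c / (1 - c)) evaluated in the given precision.

    Hypothetical helper for illustration only, not fast_bss_eval code.
    """
    c = dtype(coherence_sq)      # rounded to the target precision
    den = dtype(1.0) - c         # cancels to exactly 0 if c rounded to 1.0
    with np.errstate(divide="ignore"):
        return 10.0 * np.log10(c / den)

# 1 - 1e-8 is not representable in float32 and rounds to 1.0,
# so the denominator is exactly zero and the ratio is infinite.
print(ratio_db(1.0 - 1e-8, np.float32))
# float64 keeps the small difference and gives a finite value.
print(ratio_db(1.0 - 1e-8, np.float64))
```

The same input gives an infinite value in float32 and roughly 80 dB in float64, purely because of rounding.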
I have checked, and this seems to be an issue with the torch (CPU) float32 sub-routines, probably solve, I would say. Interestingly, the result is correct when running on GPU. NumPy with float32 also gives the correct value.
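Since the float64 runs in this thread give finite values, one possible workaround is to upcast the signals before evaluating. The wrapper below is a minimal sketch; `call_in_float64` is a hypothetical helper, not part of fast_bss_eval, and it assumes the metric function accepts NumPy arrays.

```python
import numpy as np

def call_in_float64(metric_fn, reference, estimate, **kwargs):
    """Upcast both signals to float64, then call metric_fn on them.

    Hypothetical convenience wrapper; metric_fn is any callable taking
    (reference, estimate) arrays, e.g. fast_bss_eval.bss_eval_sources.
    """
    ref64 = np.asarray(reference, dtype=np.float64)
    est64 = np.asarray(estimate, dtype=np.float64)
    return metric_fn(ref64, est64, **kwargs)

# Intended usage with the arrays from the report (not executed here):
# sdr, sir, sar, perm = call_in_float64(
#     fast_bss_eval.bss_eval_sources, target.numpy(), preds.numpy())
```

This trades some speed and memory for precision, which may or may not be acceptable depending on the workload.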
Nevertheless, it is not desirable to get nan values in the output, so I have modified the code so that it returns inf instead in this case.
inf values are produced in case of extremely low or high sdr/sir/sar. To avoid getting such values, one can provide the clamp_db argument to bss_eval_sources, which will saturate the output values to [-clamp_db, clamp_db].
Due to limited numerical precision, values under -40 dB or above 40 dB may not be relevant. In that case, set clamp_db=40.
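The behavior described above (NaN replaced by inf, then optional saturation to [-clamp_db, clamp_db]) can be sketched as a post-processing step. This is an illustration under assumptions, not the code from the fix; `sanitize_db` is a hypothetical name, and mapping NaN to +inf is a simplification.

```python
import numpy as np

def sanitize_db(values_db, clamp_db=None):
    """Replace NaN by +inf, then optionally clamp to [-clamp_db, clamp_db].

    Hypothetical helper that mimics the described output behavior.
    """
    out = np.asarray(values_db, dtype=np.float64).copy()
    out[np.isnan(out)] = np.inf          # assumption: NaN maps to +inf
    if clamp_db is not None:
        out = np.clip(out, -clamp_db, clamp_db)  # saturates +/-inf as well
    return out

print(sanitize_db([-2.6815, float("nan")]))                 # NaN becomes inf
print(sanitize_db([-2.6815, float("nan")], clamp_db=40.0))  # inf saturates at 40
```

With clamping enabled, downstream averages stay finite, at the cost of hiding how extreme the original value was.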
I have added a test file tests/test_issue_5.py based on your data and fixed this in PR #6. I will release a new version as soon as CI is done.
$ python ./tests/test_issue_5.py
torch.Size([2, 64000]) torch.Size([2, 64000])
torch torch.float32
tensor([-2.6815, inf]) torch.float32
torch torch.float64
tensor([-2.6816, 44.5852], dtype=torch.float64) torch.float64
torch/cuda torch.float32
tensor([-2.6816, 44.6203], device='cuda:0') torch.float32
torch/cuda torch.float64
tensor([-2.6816, 44.5852], device='cuda:0', dtype=torch.float64) torch.float64
numpy float32
[-2.6815615 44.575493 ] float32
numpy float64
[-2.68156071 44.58523729] float64
mir_eval float32
[-2.68156071 44.58523729] float64
mir_eval float64
[-2.68156071 44.58523729] float64