sparse method `slogdet` returns NaN for null matrix
aboucaud opened this issue · 6 comments
This issue makes the current tests fail on CI. It is triggered by the following test in test_sparse.py:
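import numpy as np
from numpy.testing import assert_array_equal
from jax_cosmo.sparse import det  # assumed import path for the sparse module
X = np.array(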
[
[[1, 2, 3], [4, 5, 6], [-1, 7, -2]],
[[1, 2, 3], [-4, -5, -6], [2, -3, 9]],
[[7, 8, 9], [5, -4, 6], [-3, -2, -1]],
]
)
assert_array_equal(det(0.0 * X), 0.0)
I traced it back to the second call of `_block_det()` in the for-loop of `slogdet()`, and it can therefore be reproduced with:
from jax_cosmo.sparse import _block_det  # assumed import path; X is defined as above
sparse = 0.0 * X
i = 1
N = 3
P = 3
print(_block_det(sparse, i, N, P))
Ping @dkirkby
Yes, indeed, and it's preventing the tests from running. It used to work, though, so I suspect a new version of JAX broke something :-|
I'm going to open an issue to upgrade the code to JAX 0.2; hopefully that will also take care of this.
OK, I've investigated this a bit further. What happens is that at some point in `_block_det` an `inf` gets multiplied by 0, which creates a NaN. This is in principle the correct floating-point behaviour (0 * inf is NaN under IEEE 754), but it leads to det(0 * X) = nan.
I think for now we could go ahead and skip this test for zero matrices. Let me know what you think, but I don't see a particular use case where we need logdet(0) to be -inf rather than nan.
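The mechanism can be shown with plain `jax.numpy` scalars (a simplification: this does not call the sparse module at all). It assumes, as is usual for slogdet-style functions, that the determinant is recovered from the (sign, log|det|) pair as sign * exp(logdet). The pair (0, -inf) for a singular matrix then correctly gives 0, but once an inf has been multiplied by 0 upstream, the NaN carries through:

```python
import jax.numpy as jnp

# IEEE 754 floating point: 0 * inf has no defined value and evaluates to NaN.
zero_times_inf = jnp.float32(0.0) * jnp.float32(jnp.inf)

# A slogdet-style pair for a singular matrix: sign 0 and log|det| = -inf.
sign, logdet = jnp.float32(0.0), jnp.float32(-jnp.inf)

# Assumed recombination: exp(-inf) == 0, so the determinant is 0 * 0 == 0.
det_from_pair = sign * jnp.exp(logdet)

# If a NaN has leaked into the sign (e.g. from an inf * 0 upstream),
# the recovered determinant is NaN as well.
det_from_nan_pair = zero_times_inf * jnp.exp(logdet)

print(zero_times_inf, det_from_pair, det_from_nan_pair)
```

No exception is raised at any step, which is why the problem only surfaces at the final assertion.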
I cannot think of a particular use case either. But having to handle NaNs in your calculations is quite painful, and you may end up with zeros in a matrix by chance, in which case you don't want your code to break, especially if it did not break before.
After testing with jax==0.2.0 and 0.2.3, I still get
_block_det(0.0 * X, 0, 3, 3)
# (DeviceArray(0., dtype=float32), DeviceArray(-inf, dtype=float32))
and
_block_det(0.0 * X, 1, 3, 3)
# (DeviceArray(nan, dtype=float32), DeviceArray(nan, dtype=float32))
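Why one NaN call is enough to spoil the full result can be sketched by accumulating the two pairs reported above. This assumes, hypothetically, that the for-loop in `slogdet()` multiplies the per-call signs and sums the per-call log-determinants; the actual loop body is not quoted in this thread.

```python
import jax.numpy as jnp

# (sign, logdet) pairs as reported above for the zero matrix 0.0 * X.
block_results = [
    (jnp.float32(0.0), jnp.float32(-jnp.inf)),     # _block_det(0.0 * X, 0, 3, 3)
    (jnp.float32(jnp.nan), jnp.float32(jnp.nan)),  # _block_det(0.0 * X, 1, 3, 3)
]

# Hypothetical accumulation: multiply signs, add log-determinants.
sign, logdet = jnp.float32(1.0), jnp.float32(0.0)
for s, l in block_results:
    sign = sign * s        # 1 * 0 = 0, then 0 * nan = nan
    logdet = logdet + l    # 0 + (-inf) = -inf, then -inf + nan = nan

print(sign, logdet)
```

Had the accumulation stopped after the first call, the pair (0, -inf) would already have encoded a zero determinant.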
In `_block_det`, the NaNs come from the call to `inv(S)` in the computation of `Sinv_v`, which returns infinite values.
In the first call we have
S = np.array([[[-0., -0., -0.], [ 0., -0., 0.]], [[ 0., -0., 0.], [-0., -0., -0.]]])
inv(S) # = DeviceArray([[[0., 0., 0.], [0., 0., 0.]], [[0., 0., 0.], [1., 1., 1.]]], dtype=float32)
In the second call we have
S = np.array([[[-0., -0., -0.]]])
inv(S) # = DeviceArray([[[-inf, -inf, -inf]]], dtype=float32)
I am actually not sure which result we should expect here.
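The infinite values in the second call are what IEEE division by a signed zero produces. Assuming, as the printed shapes suggest, that each block is stored as its diagonal (so inverting a 1×1 block reduces to an elementwise reciprocal), the sketch below reproduces the `-inf` entries. The operand `v` is a hypothetical stand-in for whatever `Sinv` is multiplied with in `Sinv_v`, taken to be zero as it would be for 0.0 * X.

```python
import jax.numpy as jnp

# Second call: a single 1x1 block whose three diagonal entries are all -0.0.
S = jnp.array([[[-0.0, -0.0, -0.0]]], dtype=jnp.float32)

# Assumed: inverting a 1x1 block of diagonals is an elementwise reciprocal.
Sinv = 1.0 / S  # 1 / -0.0 = -inf

# Hypothetical zero operand standing in for the vector used in Sinv_v.
v = jnp.zeros_like(S)
Sinv_v = Sinv * v  # -inf * 0 = nan

print(Sinv)
print(Sinv_v)
```

Mathematically the inverse of a zero block does not exist, so any finite or infinite value returned here is an artifact of the floating-point arithmetic rather than a meaningful answer.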
I have noticed that JAX is pretty lax ;-) about propagating inf and nan. We could add some special cases to handle the inverse of a zero matrix, etc., but I don't think it's worth the complexity or the extra branch at this point. Note that `sparse.inv` already has this disclaimer in its docstring:
We currently assume that the matrix is invertible and you should not
trust the answer unless you know this is true (because jax.numpy.linalg.inv
has this behavior).
Let's just comment out this unit test for now?
I'm fine with commenting out that test for now :-)
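For reference, the kind of special case that was judged not worth the extra branch could look like the hypothetical helper below (it is not part of the sparse module). It again assumes the per-call (sign, logdet) pairs are combined by multiplying signs and summing logs, and uses `jnp.where` so that once a block has produced a zero sign, later NaN results cannot overwrite the (0, -inf) pair.

```python
import jax.numpy as jnp

def accumulate_slogdet(block_results):
    """Hypothetical helper: combine per-block (sign, logdet) pairs,
    freezing the result at (0, -inf) once any block is singular."""
    sign, logdet = jnp.float32(1.0), jnp.float32(0.0)
    for s, l in block_results:
        already_singular = sign == 0  # an earlier block had zero determinant
        sign = jnp.where(already_singular, 0.0, sign * s)
        logdet = jnp.where(already_singular, -jnp.inf, logdet + l)
    return sign, logdet

# Same pairs as reported for 0.0 * X: the NaN from the second call is ignored.
pairs = [
    (jnp.float32(0.0), jnp.float32(-jnp.inf)),
    (jnp.float32(jnp.nan), jnp.float32(jnp.nan)),
]
sign, logdet = accumulate_slogdet(pairs)
det = sign * jnp.exp(logdet)  # 0 * exp(-inf) = 0
print(sign, logdet, det)
```

`jnp.where` still evaluates both branches, so gradients flowing through the discarded NaN branch may still come out as NaN; that is part of the extra complexity such a special case would bring.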