DifferentiableUniverseInitiative/jax_cosmo

sparse method `slogdet` returns NaN for null matrix

aboucaud opened this issue · 6 comments

This issue makes the current tests fail on CI.

It corresponds to this specific test in test_sparse.py:

import numpy as np
from numpy.testing import assert_array_equal
from jax_cosmo.sparse import det

X = np.array(
    [
        [[1, 2, 3], [4, 5, 6], [-1, 7, -2]],
        [[1, 2, 3], [-4, -5, -6], [2, -3, 9]],
        [[7, 8, 9], [5, -4, 6], [-3, -2, -1]],
    ]
)
assert_array_equal(det(0.0 * X), 0.0)
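For reference, the expected value can be checked against the dense equivalent of the matrix. The sketch below assumes that the sparse format stores an N×N grid of P×P diagonal blocks as an array of shape (N, N, P); `to_dense_sketch` is a hypothetical helper written for illustration, not the library's own conversion function. The dense form of 0 * X is the zero matrix, so its determinant is exactly 0. For a non-degenerate input, the dense matrix is a simultaneous row/column permutation of a block-diagonal matrix, so its determinant equals the product of the determinants of the P matrices X[:, :, p].

```python
import numpy as np


def to_dense_sketch(sparse):
    """Hypothetical helper: expand an (N, N, P) array of block diagonals
    into the equivalent dense (N*P, N*P) matrix."""
    N, M, P = sparse.shape
    assert N == M, "expected a square grid of blocks"
    dense = np.zeros((N * P, N * P), dtype=float)
    for i in range(N):
        for j in range(N):
            # Block (i, j) is the P x P diagonal matrix diag(sparse[i, j]).
            dense[i * P:(i + 1) * P, j * P:(j + 1) * P] = np.diag(sparse[i, j])
    return dense


X = np.array(
    [
        [[1, 2, 3], [4, 5, 6], [-1, 7, -2]],
        [[1, 2, 3], [-4, -5, -6], [2, -3, 9]],
        [[7, 8, 9], [5, -4, 6], [-3, -2, -1]],
    ]
)

# The dense equivalent of 0 * X is the 9 x 9 zero matrix, whose determinant is 0.
det_zero = np.linalg.det(to_dense_sketch(0.0 * X))

# For a non-degenerate input, the dense determinant factorises over the P diagonals.
det_dense = np.linalg.det(to_dense_sketch(X))
det_per_diag = np.prod([np.linalg.det(X[:, :, p]) for p in range(X.shape[2])])
print(det_zero, det_dense, det_per_diag)
```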

I traced it back to the second call of _block_det() in the for-loop of slogdet(), so it can be reproduced with:

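from jax_cosmo.sparse import _block_det

sparse = 0.0 * X  # X as defined in the test above
i = 1  # second iteration of the for-loop in slogdet()
N = 3  # X has shape (N, N, P) = (3, 3, 3)
P = 3
print(_block_det(sparse, i, N, P))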

Ping @dkirkby

EiffL commented

Yes, indeed, and it's preventing the tests from running. It used to work, though, so I suspect the problem comes from a new version of JAX breaking everything :-|
I'm going to open an issue to upgrade the code to JAX 0.2; hopefully that will also take care of this.

EiffL commented

OK, I've investigated this a little further. What happens is that at some point in _block_det an inf gets multiplied by 0, which creates a NaN. This is in principle the correct behaviour, but it leads to det(0*X) = nan.
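The mechanism is easy to reproduce in isolation. Under IEEE 754 floating-point arithmetic, which jax.numpy follows, multiplying an infinity by zero yields NaN, and that NaN then survives the `log(abs(...))` step a log-determinant takes, whereas the log of an exact zero is the -inf one would expect for logdet(0). This is a standalone illustration, not code taken from _block_det.

```python
import jax.numpy as jnp

inf_times_zero = jnp.array(jnp.inf) * 0.0  # inf * 0 is NaN, not 0
neg_inf_times_zero = jnp.array(-jnp.inf) * 0.0  # the sign of the infinity does not matter

# Once a NaN appears, it survives a log|.| step such as a log-determinant would take.
log_abs_nan = jnp.log(jnp.abs(inf_times_zero))

# By contrast, log|0| is -inf, the "expected" log-determinant of a zero matrix.
log_abs_zero = jnp.log(jnp.abs(jnp.array(0.0)))
print(inf_times_zero, neg_inf_times_zero, log_abs_nan, log_abs_zero)
```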

I think for now, we could go ahead and skip this test for zero matrices. Let me know what you think, but I don't see a particular use case where we need logdet(0) to be -inf rather than nan.

I cannot think of a particular use case either. But having to handle NaNs in your calculations is quite painful, and you may end up with zeros in a matrix by chance, in which case you don't want your code to break, especially if it did not break before.

After testing with jax==0.2.0 and 0.2.3, I still get

_block_det(0.0 * X, 0, 3, 3)
# (DeviceArray(0., dtype=float32), DeviceArray(-inf, dtype=float32))

and

_block_det(0.0 * X, 1, 3, 3)
# (DeviceArray(nan, dtype=float32), DeviceArray(nan, dtype=float32))

In _block_det, the NaNs come from the call to inv(S) in the computation of Sinv_v, which returns infinite values.
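Judging only from the names S and Sinv_v, _block_det appears to evaluate a Schur-complement factorisation; this is an inference, not something confirmed in this thread. Assuming so, write the trailing part of the matrix with A the current diagonal block, U and V the off-diagonal strips and S the trailing submatrix. The identity being used is then:

```latex
\det\begin{pmatrix} A & U \\ V & S \end{pmatrix}
  = \det(S)\,\det\!\left(A - U S^{-1} V\right),
\qquad S \text{ invertible}
```

For 0 * X, S is singular, so S^{-1} does not exist mathematically. When the computed inverse nevertheless contains infinite entries, the products forming U S^{-1} V combine them with the zero entries of U and V, and each 0 · ∞ term evaluates to NaN rather than 0.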
In the first call, we have:

S = np.array([[[-0., -0., -0.], [ 0., -0.,  0.]], [[ 0., -0.,  0.], [-0., -0., -0.]]])
inv(S) # = DeviceArray([[[0., 0., 0.], [0., 0., 0.]], [[0., 0., 0.], [1., 1., 1.]]], dtype=float32)

and in the second:

S = np.array([[[-0., -0., -0.]]])
inv(S) # = DeviceArray([[[-inf, -inf, -inf]]], dtype=float32)

I am actually not sure which result we should expect here.
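The -inf in the second call comes from the signed zeros in S. Because S is a single 1×1 block there, its inverse is simply the elementwise reciprocal of the stored diagonal (the inverse of diag(d) is diag(1/d); that sparse.inv reduces to this here is an assumption), and IEEE 754 maps 1/(-0.0) to -inf and 1/(+0.0) to +inf. Multiplying either result by the zero entries of v, in a product analogous to Sinv_v, then gives NaN. The variable names below are illustrative only.

```python
import jax.numpy as jnp

S = jnp.array([[[-0.0, -0.0, -0.0]]])  # 1 x 1 grid of blocks, diagonal of length 3

# The inverse of diag(d) is diag(1 / d), so for a 1 x 1 block grid the inverse
# is the elementwise reciprocal of the stored diagonal.
S_inv = 1.0 / S  # -0.0 -> -inf
positive_zero_inv = 1.0 / jnp.zeros(3)  # +0.0 -> +inf

# Illustrative analogue of Sinv_v: the infinities meet the zero entries of v.
v = jnp.zeros((1, 1, 3))
product = S_inv * v  # -inf * 0 -> NaN
print(S_inv, positive_zero_inv, product)
```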

EiffL commented

I have noticed that jax is pretty lax ;-) about propagating inf and nan. We could add some special cases to handle the inverse of a zero matrix, etc., but I don't think it is worth the complexity or the extra branches at this point. Note that sparse.inv already has this disclaimer in its docstring:

    We currently assume that the matrix is invertible and you should not
    trust the answer unless you know this is true (because jax.numpy.linalg.inv
    has this behavior).
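For the record, a special case of the kind mentioned above might look like the sketch below. `safe_reciprocal` and `naive_reciprocal` are hypothetical and not part of jax_cosmo. The safe version uses the common "double jnp.where" pattern so that neither the value nor the gradient picks up inf or NaN at zero entries. The naive single-`where` version is included for contrast: its forward value is fine, but its gradient is NaN wherever d == 0, because reverse-mode differentiation still multiplies a zero cotangent by the infinite derivative of 1/d.

```python
import jax
import jax.numpy as jnp


def naive_reciprocal(d):
    # Forward value is fine, but 1 / d is still evaluated (and differentiated) at d == 0.
    return jnp.where(d == 0, 0.0, 1.0 / d)


def safe_reciprocal(d):
    """Hypothetical special case: return 0 instead of inf where d == 0."""
    is_zero = d == 0
    # Replace zeros by 1 before dividing so that no inf is ever created,
    # neither in the value nor in the derivative of the division.
    safe_d = jnp.where(is_zero, 1.0, d)
    return jnp.where(is_zero, 0.0, 1.0 / safe_d)


d = jnp.array([2.0, 0.0, -0.0, 4.0])

naive_grad = jax.grad(lambda x: jnp.sum(naive_reciprocal(x)))(d)
safe_value = safe_reciprocal(d)
safe_grad = jax.grad(lambda x: jnp.sum(safe_reciprocal(x)))(d)
print(naive_grad, safe_value, safe_grad)
```

The extra branch is exactly the added complexity being weighed here: it also changes the semantics of the inverse for singular input (returning 0 instead of inf), which may or may not be the desired behaviour.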

Let's just comment out this unit test for now?

EiffL commented

I'm fine with commenting out that test for now :-)