[Bug] pinv for Diagonal operators returns infs that should be zero
f0uriest opened this issue
🐛 Bug
Basically, `cola.linalg.pinv(cola.ops.Diagonal)` seems to just be doing an (unmasked) elementwise inverse, which gives the wrong result for rank-deficient operators.
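For reference, the Moore–Penrose pseudoinverse of a diagonal matrix inverts only its nonzero diagonal entries and leaves the zero ones at zero:

```latex
D = \operatorname{diag}(d_1, \dots, d_n), \qquad
D^{+} = \operatorname{diag}(d_1^{+}, \dots, d_n^{+}), \qquad
d_i^{+} =
\begin{cases}
  1/d_i, & d_i \neq 0, \\
  0,     & d_i = 0.
\end{cases}
```

An unmasked elementwise inverse instead sets every d_i^+ = 1/d_i, so each zero entry becomes inf and the Penrose identity D D^+ D = D no longer holds (in floating point, 0 · inf evaluates to nan).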
To reproduce
```python
import numpy as np
import cola

np.diag(cola.linalg.pinv(cola.ops.Diagonal(np.arange(5))).to_dense())
# array([ inf, 1. , 0.5 , 0.33333333, 0.25 ])

# If we explicitly make it a dense matrix, it calls np.linalg.lstsq, which gives the expected result
np.diag(cola.linalg.pinv(cola.ops.Dense(np.diag(np.arange(5)))).to_dense())
# array([0. , 1. , 0.5 , 0.33333333, 0.25 ])
```
Expected Behavior
I think the "standard" behavior for this case would be to return 0 in place of inf (or anywhere abs(A[i,i]) < eps*max(A.shape)), which is what happens if you actually call np.linalg.pinv.
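A minimal sketch of the masked behavior described above, in plain NumPy so it runs without CoLA. `diag_pinv` is a hypothetical helper, not part of CoLA's API, and it assumes a real-valued diagonal. It reads the threshold above as relative to the largest |A[i,i]| (tol = eps · n · max|d_i|), which is the default tolerance of `numpy.linalg.matrix_rank`; `numpy.linalg.pinv` itself zeroes singular values at or below `rcond` (default 1e-15) times the largest one, so the two cutoffs differ slightly, but both map exact zeros to 0.

```python
import numpy as np


def diag_pinv(d, rtol=None):
    """Pseudoinverse of diag(d), returned as the vector of its diagonal.

    Hypothetical helper (not CoLA API). Entries with |d_i| <= rtol * max|d|
    are treated as zero and mapped to 0 instead of being inverted.
    """
    d = np.asarray(d, dtype=float)  # assumes a real-valued diagonal
    if d.size == 0:
        return d.copy()
    if rtol is None:
        # Assumed default: eps * n, as in numpy.linalg.matrix_rank's tolerance
        rtol = np.finfo(d.dtype).eps * d.size
    tol = rtol * np.max(np.abs(d))
    keep = np.abs(d) > tol
    # Invert only the kept entries; masked entries keep the 0 from `out`
    return np.divide(1.0, d, out=np.zeros_like(d), where=keep)


d = np.arange(5)
with np.errstate(divide="ignore"):
    naive = 1.0 / d  # unmasked reciprocal: d[0] == 0 becomes inf
masked = diag_pinv(d)  # d[0] maps to 0.0; the other entries equal 1/d
print(naive)
print(masked)
```

Passing `where=` to `np.divide` means 1/0 is never evaluated for the masked entries, so no divide-by-zero warning is raised, and an all-zero diagonal correctly yields an all-zero pseudoinverse.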
System information
- CoLA Version '0.0.6'
- JAX Version '0.4.31'