wilson-labs/cola

[Bug] pinv for Diagonal operators returns infs that should be zero

f0uriest opened this issue · 0 comments

๐Ÿ› Bug

cola.linalg.pinv applied to a cola.ops.Diagonal operator seems to just take an unmasked elementwise reciprocal of the diagonal. For rank-deficient operators this gives the wrong result: zero diagonal entries become inf instead of 0.
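For reference, the Moore-Penrose pseudo-inverse of a diagonal matrix is itself diagonal and inverts only the nonzero entries. In floating point the "nonzero" test is usually replaced by a relative cutoff; the tolerance τ below is an assumed parameter (for example eps·max(A.shape)), not documented CoLA behavior.

$$
D = \operatorname{diag}(d_1, \dots, d_n), \qquad
D^{+} = \operatorname{diag}(d_1^{+}, \dots, d_n^{+}), \qquad
d_i^{+} =
\begin{cases}
1/d_i & \text{if } |d_i| > \tau \max_j |d_j| \\
0 & \text{otherwise}
\end{cases}
$$

For d = (0, 1, 2, 3, 4) this gives D⁺ = diag(0, 1, 1/2, 1/3, 1/4), which matches the dense result in the reproduction.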

To reproduce

import numpy as np
import cola

np.diag(cola.linalg.pinv(cola.ops.Diagonal(np.arange(5))).to_dense())
# array([       inf, 1.        , 0.5       , 0.33333333, 0.25      ])

# If we explicitly wrap it as a dense matrix, pinv calls np.linalg.lstsq, which gives the expected result
np.diag(cola.linalg.pinv(cola.ops.Dense(np.diag(np.arange(5)))).to_dense())
# array([0.        , 1.        , 0.5       , 0.33333333, 0.25      ])

Expected Behavior

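I think the "standard" behavior here would be to return 0 in place of inf, or more generally to return 0 wherever abs(A[i,i]) falls below a small cutoff (e.g. eps*max(A.shape) times the largest abs(A[i,i])). That is what happens if you actually call np.linalg.pinv, which zeroes out small singular values.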
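A minimal NumPy sketch of the masked reciprocal described above. `diag_pinv` is a hypothetical helper, not part of CoLA; it assumes a real-valued diagonal and uses eps*max(A.shape) relative to the largest entry as the default cutoff, which is the tolerance suggested in this issue rather than NumPy's own default.

```python
import numpy as np


def diag_pinv(d, rtol=None):
    """Return the diagonal of pinv(diag(d)).

    Hypothetical helper (not a CoLA API). Entries with magnitude at or
    below rtol * max(abs(d)) are treated as zero, mirroring how
    np.linalg.pinv truncates small singular values.
    """
    d = np.asarray(d, dtype=float)  # assumes a real-valued diagonal
    if rtol is None:
        # Assumed default: eps * max(A.shape); for an n x n diagonal
        # operator, max(A.shape) == n.
        rtol = np.finfo(d.dtype).eps * d.shape[-1]
    abs_d = np.abs(d)
    cutoff = rtol * abs_d.max(initial=0.0)  # initial=0.0 handles empty input
    keep = abs_d > cutoff
    # Replace dropped entries with 1 before dividing so no inf/nan is ever
    # produced, then zero those positions out.
    safe_d = np.where(keep, d, 1.0)
    return np.where(keep, 1.0 / safe_d, 0.0)


d = np.arange(5)
print(diag_pinv(d))
print(np.allclose(np.diag(diag_pinv(d)), np.linalg.pinv(np.diag(d))))
```

Dividing by a "safe" copy instead of dividing first and masking afterwards avoids ever materializing inf, which also keeps the same pattern usable with jax.numpy under autodiff, where an inf in the untaken branch of a `where` can still poison gradients.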

System information


  • CoLA Version '0.0.6'
  • JAX Version '0.4.31'