wilson-labs/cola

[Bug] Jax `Sparse` Matrix Example in Documentation Throws Error

mabilton opened this issue · 4 comments

๐Ÿ› Bug

One of the Jax sparse matrix examples in the documentation (namely https://cola.readthedocs.io/en/latest/package/cola.ops.html#cola.ops.Sparse) throws an error.

To reproduce

**Code snippet to reproduce**

import jax.numpy as jnp
import cola
data = jnp.array([1, 2, 3, 4, 5, 6])
indices = jnp.array([0, 2, 1, 0, 2, 1])
indptr = jnp.array([0, 2, 4, 6])
shape = (3, 3)
op = cola.ops.Sparse(data, indices, indptr, shape)

**Stack trace/error message**

AttributeError: module 'cola.backends.jax_fns' has no attribute 'sparse_csr'

Expected Behavior

That a Sparse matrix is returned.
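For reference, the CSR arrays in the snippet above encode a concrete 3x3 matrix. A quick way to verify what the returned operator should represent (a minimal sketch using scipy rather than cola, since the cola constructor is what fails here) is:

```python
import numpy as np
from scipy.sparse import csr_matrix

# Same CSR data as in the reproduction snippet above.
data = np.array([1, 2, 3, 4, 5, 6])
indices = np.array([0, 2, 1, 0, 2, 1])
indptr = np.array([0, 2, 4, 6])

# Standard CSR semantics: row i holds data[indptr[i]:indptr[i+1]]
# at columns indices[indptr[i]:indptr[i+1]].
A = csr_matrix((data, indices, indptr), shape=(3, 3))
print(A.toarray())
# [[1 0 2]
#  [4 3 0]
#  [0 6 5]]
```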

System information

  • cola version: 0.0.4
  • jax version: 0.4.14
  • OS: Pop!_OS 22.04 LTS

Additional context

First I'd like to say that I think the idea behind the library is really cool and that I can definitely see myself utilising it across a lot of my projects :).

The fix itself should be as simple as defining sparse_csr inside of cola.backends.jax_fns (plus adding a unit test; unit tests should probably also be added for any other documentation examples that currently lack them), which I'm happy to do over the next couple of days.

As an aside, does cola intend on supporting sparsity formats other than CSR? I know that both jax and pytorch support COO, CSR, CSC, BSR, and BSC formats (see https://jax.readthedocs.io/en/latest/jax.experimental.sparse.html#other-sparse-data-structures and https://pytorch.org/docs/stable/sparse.html), so I imagine it would make sense to allow users to explicitly specify which sparsity representation they want. Any thoughts on this?
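For context, jax already exposes some of these formats through jax.experimental.sparse; a minimal BCOO sketch (independent of cola, just to show the upstream API that a wrapper could build on) looks like:

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Build a sparse (BCOO) representation from a dense array.
M = jnp.array([[1., 0., 2.],
               [0., 0., 3.]])
M_sp = sparse.BCOO.fromdense(M)

# Matrix-vector products work directly on the sparse representation.
v = jnp.array([1., 1., 1.])
result = M_sp @ v
print(result)  # [3. 3.]
```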

Thanks for any help.

Cheers,
Matt.

mfinzi commented

Hi @mabilton
Thanks for opening the issue!
You're right: `Sparse` is mostly just a placeholder right now, as it only works on torch (and is not necessarily the general interface that we want).
Indeed, I think it would be a good idea to support the varied sparse types that jax and pytorch have to offer. Possibly the way to go here is to wrap the sparse types much in the way that `Dense` wraps the dense jax and pytorch arrays. It should probably also be possible to support sparse vectors. After tackling the algorithm redesign in #42, we will circle back to this. Or, if you wanted to make a start on this, we'd be happy to review a pull request!
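To illustrate the wrapping idea, here is a hypothetical sketch (not cola's actual API; the name `SparseWrapper` and its methods are invented for illustration, and scipy stands in for a backend sparse type) of a thin linear-operator-style wrapper, analogous to how a dense wrapper holds an array and forwards matmuls:

```python
import numpy as np
from scipy.sparse import csr_matrix

class SparseWrapper:
    """Hypothetical operator that wraps an existing sparse matrix type,
    exposing only shape/dtype and matrix-vector products."""

    def __init__(self, A):
        self.A = A            # the backend's native sparse object
        self.shape = A.shape
        self.dtype = A.dtype

    def __matmul__(self, v):
        # Delegate the matvec to the wrapped sparse representation.
        return self.A @ v

# CSR data from the issue's reproduction snippet.
A = csr_matrix((np.array([1., 2., 3., 4., 5., 6.]),
                np.array([0, 2, 1, 0, 2, 1]),
                np.array([0, 2, 4, 6])), shape=(3, 3))
op = SparseWrapper(A)
print(op @ np.ones(3))  # [ 3.  7. 11.]
```

The design point is that the wrapper never materialises a dense array: it only forwards products, so the same interface could sit over COO, CSR, CSC, BSR, or BSC objects from either backend.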

Hey @mfinzi.

Completely agree that it makes sense to push through the refactorisations in #42 before tackling this - let me know if there's any way I can help out with #42 :) .

In the meantime, do you think it would make sense to remove the Sparse example I linked from the documentation until support for Sparse matrices has been more completely implemented? In a similar vein, would it also be prudent to remove the 'tick' under the 'Sparse' column of the 'Features implemented' table in the README?

mfinzi commented

Yes, good point. I will remove that tick. We just finished up #42, so we should be able to return to this now.

Hi all, I'm here doing some housekeeping. Thanks for pointing out this bug. I've just made PR #98, which fixes the JAX Sparse issues.