[Bug] Jax `Sparse` Matrix Example in Documentation Throws Error
mabilton opened this issue · 4 comments
🐛 Bug
One of the JAX sparse matrix examples in the documentation (namely https://cola.readthedocs.io/en/latest/package/cola.ops.html#cola.ops.Sparse) raises an `AttributeError`.
To reproduce
**Code snippet to reproduce**
```python
import jax.numpy as jnp
import cola
data = jnp.array([1, 2, 3, 4, 5, 6])
indices = jnp.array([0, 2, 1, 0, 2, 1])
indptr = jnp.array([0, 2, 4, 6])
shape = (3, 3)
op = cola.ops.Sparse(data, indices, indptr, shape)
```
**Stack trace/error message**
```
AttributeError: module 'cola.backends.jax_fns' has no attribute 'sparse_csr'
```
Expected Behavior
That a `Sparse` matrix is returned.
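For reference, the `data`, `indices` and `indptr` arrays in the snippet follow the standard CSR layout: row `i` owns the slice `indptr[i]:indptr[i + 1]` of `indices` (column positions) and `data` (values). A minimal pure-Python sketch (the helper `csr_to_dense` is hypothetical and not part of cola) that expands the triplet shows which matrix the example is expected to represent:

```python
def csr_to_dense(data, indices, indptr, shape):
    """Hypothetical helper: expand a CSR triplet into a dense list of lists."""
    n_rows, n_cols = shape
    dense = [[0] * n_cols for _ in range(n_rows)]
    for row in range(n_rows):
        # The entries of this row live in data/indices[indptr[row]:indptr[row + 1]].
        for k in range(indptr[row], indptr[row + 1]):
            # Duplicate (row, col) entries, if any, are summed.
            dense[row][indices[k]] += data[k]
    return dense


data = [1, 2, 3, 4, 5, 6]
indices = [0, 2, 1, 0, 2, 1]
indptr = [0, 2, 4, 6]
M = csr_to_dense(data, indices, indptr, (3, 3))
print(M)  # [[1, 0, 2], [4, 3, 0], [0, 6, 5]]
```

Note that column indices within a row need not be sorted (row 1 stores column 1 before column 0), so whatever backend object `Sparse` builds should not assume sorted indices.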
System information
- cola version: 0.0.4
- jax version: 0.4.14
- OS: Pop!_OS 22.04 LTS
Additional context
First I'd like to say that I think the idea behind the library is really cool and that I can definitely see myself utilising it across a lot of my projects :).
The fix itself should be as simple as defining `sparse_csr` inside `cola.backends.jax_fns` (plus adding a unit test, which should probably also be done for any other documentation examples that currently lack unit tests), which I'm happy to do over the next couple of days.
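A minimal sketch of what such a function could look like, presumably mirroring a counterpart in the torch backend: it takes the CSR triplet plus a shape and returns a JAX-native sparse matrix. The `(indptr, indices, data, shape)` argument order is an assumption, not cola's confirmed signature. It uses `jax.experimental.sparse.BCSR`, which takes its buffers as a `(data, indices, indptr)` tuple:

```python
import jax.numpy as jnp
from jax.experimental import sparse


def sparse_csr(indptr, indices, data, shape):
    """Hypothetical jax_fns.sparse_csr: build a JAX BCSR matrix from a CSR triplet.

    The (indptr, indices, data, shape) argument order is an assumption.
    """
    # BCSR expects its buffers in (data, indices, indptr) order.
    # indices_sorted stays at its default (False) because the documentation
    # example stores column indices out of order within a row.
    return sparse.BCSR((data, indices, indptr), shape=shape)


A = sparse_csr(
    indptr=jnp.array([0, 2, 4, 6]),
    indices=jnp.array([0, 2, 1, 0, 2, 1]),
    data=jnp.array([1, 2, 3, 4, 5, 6]),
    shape=(3, 3),
)
print(A.todense().tolist())  # [[1, 0, 2], [4, 3, 0], [0, 6, 5]]
```

A unit test for the documentation example could then simply compare `todense()` against the expected dense matrix.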
As an aside, does `cola` intend to support sparsity formats other than CSR? `pytorch` supports the COO, CSR, CSC, BSR, and BSC formats, and `jax` supports COO, CSR, and CSC (plus its batched BCOO and BCSR variants) (see https://jax.readthedocs.io/en/latest/jax.experimental.sparse.html#other-sparse-data-structures and https://pytorch.org/docs/stable/sparse.html), so I imagine it would make sense to let users explicitly specify which sparsity representation they want. Any thoughts on this?
Thanks for any help.
Cheers,
Matt.
Hi @mabilton
Thanks for opening the issue!
You're right, sparse is mostly just a placeholder right now as it only works on torch (and is not necessarily the general interface that we want).
Indeed, I think it would be a good idea to support the various sparse types that jax and pytorch have to offer. Probably the way to go here would be to wrap the sparse types, much as `Dense` wraps dense jax and pytorch arrays. It should probably also be possible to support sparse vectors. After tackling the algorithm redesign in #42, we will circle back to this. Or, if you want to make a start on this, we'd be happy to review a pull request!
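A rough sketch of the "wrap it like `Dense`" idea, under stated assumptions: the class `SparseWrapper` and its `matmat` method are hypothetical and do not reflect cola's actual `LinearOperator` interface. The wrapper holds any backend sparse matrix that exposes `.shape`/`.dtype` and supports `A @ X` (here jax's `BCOO` and `CSR`) and delegates products to it, so the storage format becomes the caller's choice:

```python
import jax.numpy as jnp
from jax.experimental import sparse


class SparseWrapper:
    """Hypothetical operator that delegates products to a backend sparse matrix.

    Any object exposing .shape, .dtype and supporting `A @ X` can be wrapped,
    so the caller chooses the storage format (BCOO, CSR, ...).
    """

    def __init__(self, A):
        self.A = A
        self.shape = A.shape
        self.dtype = A.dtype

    def matmat(self, X):
        # Sparse-dense product computed by the backend's own kernel.
        return self.A @ X


M = jnp.array([[1.0, 0.0, 2.0], [4.0, 3.0, 0.0], [0.0, 6.0, 5.0]])
X = jnp.ones((3, 2))
results = {}
for fmt in (sparse.BCOO, sparse.CSR):
    op = SparseWrapper(fmt.fromdense(M))
    results[fmt.__name__] = op.matmat(X)
    # Both formats print [[3.0, 3.0], [7.0, 7.0], [11.0, 11.0]] (the row sums of M).
    print(fmt.__name__, results[fmt.__name__].tolist())
```

Because the wrapper only relies on `@`, supporting another format (or a pytorch sparse tensor) would mostly mean handing it a different backend object rather than writing new kernels.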
Hey @mfinzi.
Completely agree that it makes sense to push through the refactorisations in #42 before tackling this. Let me know if there's any way I can help out with #42 :).
In the meantime, do you think it would make sense to remove the `Sparse` example that I linked from the documentation until support for `Sparse` matrices is more complete? In a similar vein, would it also be prudent to remove the 'tick' under the 'Sparse' column of the 'Features implemented' table in the README?