patrick-kidger/jaxtyping

Feature request: support for scipy sparse matrix shape type checking

jaanli opened this issue · 3 comments

This library is amazing and has saved me so much work! Thank you!

It would be awesome to support shape checking for scipy sparse matrices as well if possible :)

I think this should already be supported!

import numpy as np
import scipy.sparse as sparse
from jaxtyping import Float

# New sparse arrays
x = sparse.coo_array(np.arange(12.0).reshape(3, 4))
assert isinstance(x, Float[sparse.sparray, "3 4"])

# Old sparse matrices
x = sparse.coo_matrix(np.arange(12.0).reshape(3, 4))
assert isinstance(x, Float[sparse.spmatrix, "3 4"])

Oh amazing, thank you so much @patrick-kidger !!! Guess I just didn't see it in the docs then. Renaming the library would also be great at some point (people are now asking me: "wait, are you using Jax now?" when I'm on PyTorch haha)

Haha! I don't think support for sparse arrays is documented, but really jaxtyping should work with anything that matches the NumPy array API.

In terms of names, yup, I've contemplated this. (This'd probably mean factoring things out into some kind of arraytyping library, and then just re-exporting those types here.) I might do this at some point!