Static typing issues with ndarray
srush opened this issue · 5 comments
Trying to get jaxtyping to play nicely with static typing.
Float[np.ndarray] -> fails static type checks (variable is not allowed in type expression)
Float[npt.NDArray] -> fails static checks on operators, since the dtype is unknown
Float[npt.NDArray[np.float32]] -> fails static checks, since it does not take a dtype arg.
I looked a bit at the Array base class used for JAX, and it seems to work perfectly with static typing, but I'm not totally sure why. Any ideas on how to get NumPy to work here?
You want Float[np.ndarray, "...shape..."]. If any shape is acceptable, then use "...", e.g. Float[np.ndarray, "..."].
Try running any of the options you have above and jaxtyping will throw an error at runtime.
Sorry, just to be clear: I am doing Float[np.ndarray, "...shape..."], but I get a mypy static error: Missing type parameters for generic type
I don't get this with the JAX version, Float[jaxtyping.Array, "...shape..."].
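One plausible reading of that mypy message, stated as an assumption about the poster's setup: mypy reports "Missing type parameters for generic type" when a generic class is used without type arguments and disallow_any_generics is enabled (it is part of --strict). In NumPy's type stubs np.ndarray is generic over a shape type and a dtype type, whereas jaxtyping.Array presumably is not, which would explain why only the NumPy version is flagged. The sketch below inspects those two parameters at runtime; the names ExplicitFloat32, AliasFloat32 and type_params are hypothetical, and it assumes a NumPy recent enough that np.ndarray and np.dtype can be subscripted at runtime.

```python
import typing
from typing import Any

import numpy as np
import numpy.typing as npt

# np.ndarray is generic over (shape type, dtype type); spelling out both
# parameters is what a bare `np.ndarray` leaves unspecified.
ExplicitFloat32 = np.ndarray[Any, np.dtype[np.float32]]

# numpy.typing.NDArray[...] is an alias that fills in the shape parameter
# and takes the scalar type for the dtype parameter.
AliasFloat32 = npt.NDArray[np.float32]


def type_params(alias: Any) -> tuple[Any, ...]:
    """Return the type arguments of a parameterised generic alias."""
    return typing.get_args(alias)
```

Both aliases should have np.ndarray as their origin, differing at most in how the shape parameter is written.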
Oh this might be a type-aliasing or versioning issue on my end. Seems like ndarray should work here. I'll try a simpler case.
How odd!
FWIW, I'd recommend trying pyright instead. This is my own strong preference, as I've run into all kinds of issues with mypy before.