patrick-kidger/jaxtyping

Static typing issues with ndarray

srush opened this issue · 5 comments

Trying to get jaxtyping to play nicely with static typing.

Float[np.ndarray] -> fails static type checks (variable is not allowed in type expression)
Float[npt.NDArray] -> fails static checks on operators, since the dtype is unknown
Float[npt.NDArray[np.float32]] -> fails static checks, since it does not take a dtype argument

I looked a bit at the Array base class used for JAX, and it seems to work perfectly with static typing, but I'm not totally sure why. Any ideas on how to get NumPy to work here?

You want Float[np.ndarray, "...shape..."]. If any shape is acceptable, use an ellipsis, e.g. Float[np.ndarray, "..."].

If you try running any of the options you listed above, jaxtyping will raise an error at runtime.

Sorry, just to be clear: I am doing Float[np.ndarray, "...shape..."], but I get a mypy static error: Missing type parameters for generic type.

I don't get this with the JAX version, Float[jaxtyping.Array, "...shape..."].

Oh this might be a type-aliasing or versioning issue on my end. Seems like ndarray should work here. I'll try a simpler case.

Huh, so the issue seems to be some sort of type-aliasing issue with mypy. This is not a jaxtyping issue, but it is quite mysterious to me that it has to be lexically np.ndarray to work as a non-generic type.


How odd!
FWIW, I'd recommend trying pyright instead; this is my own strong preference, as I've run into all kinds of issues with mypy before.