How to type a "DTypeLike" argument and runtime check it?
pchanial opened this issue
pchanial commented
When I run the following code through beartype:
```python
# Array and Shaped are needed for the return annotation below.
from jaxtyping import Array, DTypeLike, Shaped
import jax.numpy as jnp


class Foo:
    n = 10

    def ones(self, dtype: DTypeLike | None = None) -> Shaped[Array, '...']:
        return jnp.ones(self.n, dtype=dtype)
```
I get the following error:
```
E beartype.roar.BeartypeDecorHintPep3119Exception: Method ...check_return() parameter "dtype" type hint <class 'jax._src.typing.SupportsDType'> uncheckable at runtime (i.e., not passable as second parameter to isinstance(), due to raising "TypeError: Instance and class checks can only be used with @runtime_checkable protocols" from metaclass __instancecheck__() method).
```
I understand that jaxtyping re-exports DTypeLike from jax.typing, but is there a way to make the above code compliant with runtime checkers using jaxtyping?
patrick-kidger commented
Thanks for the report!
Looks like the `SupportsDType` protocol underlying `jax.typing.DTypeLike` needs the `typing.runtime_checkable` decorator. I'd suggest opening an issue (or a one-line PR) on the main JAX repo.
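For illustration, here is the same hypothetical stand-in protocol with `typing.runtime_checkable` applied; it sketches the decorator's effect and is not JAX's actual source. Once decorated, `isinstance()` no longer raises and instead checks structurally whether the object exposes a `dtype` attribute, which is the kind of check a runtime type checker such as beartype needs to be able to perform.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class HasDType(Protocol):
    # Hypothetical stand-in for jax's SupportsDType, now decorated.
    @property
    def dtype(self) -> Any: ...


class Box:
    # Exposes a `dtype` attribute, so it satisfies the protocol structurally.
    dtype = "float32"


class NoDType:
    # Has no `dtype` attribute.
    pass


# isinstance() now works instead of raising TypeError.
print(isinstance(Box(), HasDType))      # True
print(isinstance(NoDType(), HasDType))  # False
```

Note that a runtime-checkable protocol only verifies that the attribute is present; it does not inspect the attribute's type.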