patrick-kidger/jaxtyping

How to type a "DTypeLike" argument and runtime-check it?

pchanial opened this issue · 1 comment

When I run the following code through beartype:

```python
# Array and Shaped are needed for the return annotation.
from jaxtyping import Array, DTypeLike, Shaped
import jax.numpy as jnp

class Foo:
    n = 10

    def ones(self, dtype: DTypeLike | None = None) -> Shaped[Array, '...']:
        return jnp.ones(self.n, dtype=dtype)
```

I get the following error:

E   beartype.roar.BeartypeDecorHintPep3119Exception: Method ...check_return() parameter "dtype" type hint <class 'jax._src.typing.SupportsDType'> uncheckable at runtime (i.e., not passable as second parameter to isinstance(), due to raising "TypeError: Instance and class checks can only be used with @runtime_checkable protocols" from metaclass __instancecheck__() method).
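The final clause of that message comes from Python's typing module rather than from beartype itself: isinstance() rejects any Protocol class that lacks the @runtime_checkable decorator, so beartype cannot use the class as a runtime check. The stdlib-only sketch below reproduces this with a hypothetical HasDType protocol standing in for jax._src.typing.SupportsDType. Its single dtype_name() method is an illustrative assumption, not JAX's actual definition.

```python
from typing import Protocol


class HasDType(Protocol):
    # Hypothetical stand-in for jax._src.typing.SupportsDType:
    # a plain Protocol, *not* decorated with @runtime_checkable.
    def dtype_name(self) -> str: ...


class FakeArray:
    # Structurally satisfies HasDType by providing dtype_name().
    def dtype_name(self) -> str:
        return "float32"


def try_isinstance(obj: object, proto: type) -> str:
    """Attempt isinstance() and report the outcome instead of raising."""
    try:
        return f"isinstance -> {isinstance(obj, proto)}"
    except TypeError as err:
        return f"TypeError: {err}"


# Fails with the same TypeError text that beartype quotes above,
# even though FakeArray matches the protocol structurally.
print(try_isinstance(FakeArray(), HasDType))
```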

I understand that jaxtyping re-exports DTypeLike from jax.typing, but is there a way to make the above code compliant with runtime checkers using jaxtyping?

Thanks for the report!

Looks like the SupportsDType protocol underlying jax.typing.DTypeLike needs the typing.runtime_checkable decorator. I'd suggest opening an issue (or a one-line PR) on the main JAX repo.
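Until an upstream fix lands, one possible stopgap is a user-side alias built on a protocol that is decorated with @runtime_checkable. This is a minimal sketch, not part of jaxtyping's or JAX's API: CheckableSupportsDType, CheckableDTypeLike, is_dtype_like and FakeArray are hypothetical names, and the sketch assumes the protocol only needs to require a dtype attribute.

```python
from typing import Any, Protocol, Union, runtime_checkable


@runtime_checkable
class CheckableSupportsDType(Protocol):
    # Hypothetical user-side copy of a SupportsDType-style protocol,
    # assuming it only requires a `dtype` attribute. @runtime_checkable
    # is what allows isinstance() (and hence beartype) to check it.
    @property
    def dtype(self) -> Any: ...


# Hypothetical runtime-checkable replacement for DTypeLike. Assumption:
# it accepts dtype names (str), scalar types (type) and objects exposing
# `dtype`. A fuller alias would also list numpy.dtype, left out here to
# keep the sketch dependency-free.
CheckableDTypeLike = Union[str, type, CheckableSupportsDType]


class FakeArray:
    # Illustrative array-like object exposing a `dtype` attribute.
    dtype = "float32"


def is_dtype_like(value: object) -> bool:
    """Check `value` against the members of CheckableDTypeLike."""
    return isinstance(value, (str, type, CheckableSupportsDType))


print(is_dtype_like("float32"), is_dtype_like(float), is_dtype_like(FakeArray()))
print(is_dtype_like(3.5))
```

Annotating the parameter as dtype: CheckableDTypeLike | None instead of DTypeLike | None should let beartype check it, since every member of the union now supports isinstance(). The trade-off is that a hand-written alias can drift from the upstream definition, so it is best dropped once jax.typing.DTypeLike is runtime-checkable itself.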