[BUG] Some annotated functions fail when passed directly to `qjit`
josh146 commented
For example,
```pycon
>>> qjit(jax.scipy.linalg.expm)(x)
TypeError: Argument 'ArrayLike' of type <class 'str'> is not a valid JAX type
```
The same JAX function works when it is called inside a user-defined function that is itself compiled with `qjit`.
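A minimal sketch of why the wrapper helps, assuming (as the error message suggests) that `qjit` reads the annotations of the function it is applied to directly. `annotated_expm`, `wrapper`, and `visible_annotations` are hypothetical stand-ins, and only the standard library is used, so this illustrates the signature inspection rather than Catalyst itself:

```python
from __future__ import annotations  # store annotations as strings (PEP 563)

import inspect


def annotated_expm(A: ArrayLike, *, max_squarings: int = 16) -> Array:
    # Hypothetical stand-in annotated like jax.scipy.linalg.expm.
    # ArrayLike/Array are never evaluated here, so they need not be defined.
    return A


def wrapper(x):
    # Unannotated wrapper: a decorator applied to it only sees `x`, with no hint.
    return annotated_expm(x)


def visible_annotations(fn):
    # Collect the annotations a decorator would see on fn's own parameters.
    return {
        name: p.annotation
        for name, p in inspect.signature(fn).parameters.items()
        if p.annotation is not inspect.Parameter.empty
    }


print(visible_annotations(annotated_expm))  # {'A': 'ArrayLike', 'max_squarings': 'int'}
print(visible_annotations(wrapper))         # {}
```

The annotated function exposes `'ArrayLike'` as a plain string, matching the `<class 'str'>` in the traceback, while the wrapper exposes nothing for `qjit` to misinterpret.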
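This appears to be caused by the type annotations on `jax.scipy.linalg.expm`: `qjit` seems to treat the annotation `ArrayLike`, which arrives as the string `'ArrayLike'`, as an argument type. The signature is: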
```python
def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 16) -> Array:
```
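One possible guard, sketched under the assumption that `qjit` uses parameter annotations as ahead-of-time argument types: only do so when every annotation is a real (non-string) object, and otherwise fall back to ordinary just-in-time tracing. `uses_type_hints_for_aot` is a hypothetical helper, not Catalyst's actual logic:

```python
import inspect
from collections.abc import Callable


def uses_type_hints_for_aot(fn: Callable) -> bool:
    """Hypothetical check: treat annotations as AOT argument types only
    when every parameter is annotated with a non-string object."""
    annotations = [p.annotation for p in inspect.signature(fn).parameters.values()]
    if not annotations:
        return False
    return all(
        a is not inspect.Parameter.empty and not isinstance(a, str)
        for a in annotations
    )


def expm_like(A: "ArrayLike", *, upper_triangular: bool = False) -> "Array":
    # Hypothetical stand-in: string annotation, like the one in the traceback.
    return A


print(uses_type_hints_for_aot(expm_like))  # False, so fall back to JIT tracing
```

Rejecting string annotations outright is a conservative choice; resolving them with `inspect.signature(fn, eval_str=True)` would be an alternative, but it can raise if names such as `ArrayLike` are not importable at decoration time.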
paul0403 commented
Is this still open?