PennyLaneAI/catalyst

[BUG] Some annotated functions fail to be used inline with `@qjit`


For example, passing `jax.scipy.linalg.expm` directly to `qjit` fails:

```
>>> qjit(jax.scipy.linalg.expm)(x)
TypeError: Argument 'ArrayLike' of type <class 'str'> is not a valid JAX type
```

The same JAX function works when it is called inside a user-defined function that is decorated with `@qjit`.
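Here is a minimal standard-library sketch of why the wrapper avoids the error. `expm_stub` is a hypothetical stand-in that copies the signature of `jax.scipy.linalg.expm`. Its annotations are written as strings on the assumption, suggested by the error message, that the real ones reach `qjit` as strings. The body is a placeholder, not a matrix exponential. In Catalyst terms, the wrapper corresponds to decorating something like `def expm(A): return jax.scipy.linalg.expm(A)` with `@qjit`. The code below only inspects signatures and does not run Catalyst.

```python
import inspect


# Hypothetical stand-in for jax.scipy.linalg.expm: same signature, with the
# array annotations stored as strings. Placeholder body, not a matrix exponential.
def expm_stub(A: "ArrayLike", *, upper_triangular: bool = False,
              max_squarings: int = 16) -> "Array":
    return A


# The workaround from the issue: call the library function from inside a
# user-defined function. The wrapper's own parameter carries no annotation.
def expm_wrapper(A):
    return expm_stub(A)


def annotations_of(fn):
    """Map each parameter name to its annotation (inspect.Parameter.empty if none)."""
    return {name: p.annotation for name, p in inspect.signature(fn).parameters.items()}


print(annotations_of(expm_stub))     # 'A' maps to the string 'ArrayLike'
print(annotations_of(expm_wrapper))  # 'A' maps to inspect.Parameter.empty
```

The wrapper's only parameter reports `inspect.Parameter.empty`, so a code path that interprets annotations has nothing to misread.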

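The failure comes from the type annotations in the function's signature. Judging by the error message, `qjit` sees the annotation on `A` as the string `'ArrayLike'` and tries to treat it as a JAX type:

```
def expm(A: ArrayLike, *, upper_triangular: bool = False, max_squarings: int = 16) -> Array:
```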
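To make the failure mode concrete, here is a hypothetical checker that, like an annotation-driven compilation path, requires every parameter annotation to be a type. It is an illustration only and not Catalyst's implementation. `expm_stub` again stands in for `jax.scipy.linalg.expm`. Its quoted annotations mimic what `from __future__ import annotations` would produce, and the checker's error text is copied from the issue to show the correspondence.

```python
import inspect


# Hypothetical stand-in with the same signature as jax.scipy.linalg.expm.
# Quoted annotations are never evaluated, so ArrayLike and Array need not
# exist here. The body is a placeholder, not a matrix exponential.
def expm_stub(A: "ArrayLike", *, upper_triangular: bool = False,
              max_squarings: int = 16) -> "Array":
    return A


def annotation_types(fn):
    """Hypothetical checker: require every parameter annotation to be a type.

    Imitates the failure mode from the issue; not Catalyst's actual code.
    """
    types = {}
    for name, param in inspect.signature(fn).parameters.items():
        ann = param.annotation
        if ann is inspect.Parameter.empty:
            continue  # unannotated parameters impose no constraint
        if not isinstance(ann, type):
            # A string annotation such as 'ArrayLike' is not a type object.
            raise TypeError(
                f"Argument {ann!r} of type {type(ann)} is not a valid JAX type"
            )
        types[name] = ann
    return types


try:
    annotation_types(expm_stub)
except TypeError as err:
    print(err)  # prints: Argument 'ArrayLike' of type <class 'str'> is not a valid JAX type
```

Real type objects such as `bool` and `int` would pass this check. The string `'ArrayLike'` fails first because `A` is the first parameter.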

Is this still open?