patrick-kidger/jaxtyping

Misleading exception when runtime type checker is used directly

padix-key opened this issue · 4 comments

When I run

from jaxtyping import Integer, jaxtyped
import numpy as np
from beartype import beartype


@beartype
#@jaxtyped(typechecker=beartype)  # Works fine with this line
def append_one(array: Integer[np.ndarray, "dim"]) -> Integer[np.ndarray, "dim+1"]:
    return np.append(array, 1)


append_one(np.array([1, 2]))

I get

jaxtyping.AnnotationError: Cannot process symbolic axis 'dim+1' as some axis names have not been processed. In practice you should usually only use symbolic axes in annotations for return types, referring only to axes annotated for arguments.

When I replace @beartype with @jaxtyped(typechecker=beartype), the code runs fine. However, the error seems misleading to me: instead of pointing me to the correct decorator, it suggests there is an error in the shape annotation.
This was especially misleading because the beartype documentation emphasizes jaxtyping support, which gives the impression that jaxtyping annotations work out of the box (and they do, except for symbolic expressions).

Ah, it sounds like we should include a message like "have you included a jaxtyped decorator?"

Happy to take a PR tweaking this!
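A minimal sketch of what the tweaked message could look like. Everything here is hypothetical: AnnotationError is a local stand-in for jaxtyping's exception, and evaluate_symbolic_axis is an illustrative helper, not jaxtyping's actual code path. Only the first part of the message is copied from the real error above; the hint sentence is the suggested addition.

```python
from typing import Dict


class AnnotationError(Exception):
    """Local stand-in for jaxtyping.AnnotationError (illustration only)."""


def evaluate_symbolic_axis(expr: str, single_memo: Dict[str, int]) -> int:
    # Hypothetical helper: evaluate an expression such as "dim+1" against the
    # axis sizes recorded so far for the function's arguments.
    try:
        return eval(expr, {"__builtins__": {}}, dict(single_memo))
    except NameError as e:
        # Keep the existing explanation, then append the suggested hint.
        raise AnnotationError(
            f"Cannot process symbolic axis '{expr}' as some axis names have "
            "not been processed. In practice you should usually only use "
            "symbolic axes in annotations for return types, referring only to "
            "axes annotated for arguments. Have you included a jaxtyped "
            "decorator, e.g. @jaxtyped(typechecker=beartype)? Without it, "
            "axis sizes from the arguments are not remembered when the "
            "return value is checked."
        ) from e
```

With {"dim": 2} the helper returns 3; with an empty memo it raises AnnotationError whose message now points at the missing decorator.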

I'm not sure where the right place for this would be. It seems to me that in get_shape_memo()

def get_shape_memo():
    if _has_shape_memo():
        single_memo, variadic_memo, pytree_memo, arguments = _shape_storage.memo_stack[
            -1
        ]
    else:
        # `isinstance` happening outside any @jaxtyped decorators, e.g. at the
        # global scope. In this case just create a temporary memo, since we're not
        # going to be comparing against any stored values anyway.
        single_memo = {}
        variadic_memo = {}
        pytree_memo = {}
        arguments = {}
    return single_memo, variadic_memo, pytree_memo, arguments

this assumption is not correct:

In this case just create a temporary memo, since we're not going to be comparing against any stored values anyway.

In the failing case the shape memo is empty, but because beartype is used directly, the annotation is still checked against it. Is this correct?

To be clear, I'm only suggesting tweaking the error message you ran into, which is raised on this line:

raise AnnotationError(

OK, I created a PR for this.