Misleading exception when runtime type checker is used directly
padix-key opened this issue · 4 comments
When I run:
from jaxtyping import Integer, jaxtyped
import numpy as np
from beartype import beartype

@beartype
# @jaxtyped(typechecker=beartype)  # Works fine with this line in place of @beartype
def append_one(array: Integer[np.ndarray, "dim"]) -> Integer[np.ndarray, "dim+1"]:
    return np.append(array, 1)

append_one(np.array([1, 2]))
I get:
jaxtyping.AnnotationError: Cannot process symbolic axis 'dim+1' as some axis names have not been processed. In practice you should usually only use symbolic axes in annotations for return types, referring only to axes annotated for arguments.
When I replace @beartype with @jaxtyped(typechecker=beartype), the code runs fine. However, the error seems misleading to me: instead of pointing me to the correct decorator, it reports an error in the shape annotation.
This error was especially misleading to me because the beartype documentation emphasizes jaxtyping support, giving the impression that jaxtyping annotations work out of the box (which they do, except for symbolic expressions).
Ah, it sounds like we should include a message like "have you included a jaxtyped decorator?"
Happy to take a PR tweaking this!
I am not sure where the correct location would be. To me it seems that the following assumption in get_shape_memo() (jaxtyping/jaxtyping/_storage.py, lines 33 to 46 in c2f19db) is not correct:
In this case just create a temporary memo, since we're not going to be comparing against any stored values anyway.
In the erroneous case I have an empty shape_memo, but values are still compared against it, because beartype is used on its own. Is this correct?
So I'm just suggesting tweaking the error message you bumped into, which is raised on this line: jaxtyping/jaxtyping/_array_types.py, line 144 in c2f19db.
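A minimal sketch of the message tweak suggested above, assuming the hint is simply appended to the existing text. AnnotationError here is a local stand-in class and symbolic_axis_error is a hypothetical helper; the actual code at _array_types.py line 144, and the wording that ends up in the PR, may differ.

```python
class AnnotationError(Exception):
    """Local stand-in for jaxtyping.AnnotationError (not imported here)."""


def symbolic_axis_error(axis: str) -> AnnotationError:
    """Hypothetical helper: build the error for an unresolvable symbolic axis."""
    return AnnotationError(
        f"Cannot process symbolic axis {axis!r} as some axis names have not "
        "been processed. In practice you should usually only use symbolic axes "
        "in annotations for return types, referring only to axes annotated for "
        "arguments. Have you included a jaxtyped decorator, e.g. "
        "@jaxtyped(typechecker=beartype)?"
    )


try:
    raise symbolic_axis_error("dim+1")
except AnnotationError as err:
    print(err)
```

Keeping the original sentence and only appending the hint preserves the existing explanation for users who did apply @jaxtyped but genuinely misused a symbolic axis.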
OK, I created a PR for this.