BUG: Exception if an array argument is reassigned to a variable with the same name inside the function
JuanFMontesinos opened this issue · 4 comments
JuanFMontesinos commented
The code is worth a thousand words.
```python
from typeguard import typechecked
from jaxtyping import Integer, Array, jaxtyped
import jax.numpy as jnp

@jaxtyped
@typechecked
def fn_1(x: Integer[Array, "N"]) -> None:
    _x = x[:25]
    return

@jaxtyped
@typechecked
def fn_2(x: Integer[Array, "M"]) -> None:
    x = x[:25]
    return

x = jnp.arange(50)
y = jnp.arange(50)

fn_1(x)
fn_2(y)
```
Console output
```
Traceback (most recent call last):
  File "l/jaxt_bug.py", line 21, in <module>
    fn_2(y)
  File "/env/lru/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 129, in wrapped_fn
    return fn(*args, **kwargs)
  File "l/jaxt_bug.py", line 14, in fn_2
    x = x[:25]
  File "/env/lru/lib/python3.10/site-packages/typeguard/_functions.py", line 254, in check_variable_assignment
    check_type_internal(value, annotation, memo)
  File "/env/lru/lib/python3.10/site-packages/typeguard/_checkers.py", line 680, in check_type_internal
    raise TypeCheckError(f"is not an instance of {qualified_name(origin_type)}")
typeguard.TypeCheckError: value assigned to x (jaxlib.xla_extension.ArrayImpl) is not an instance of jaxtyping.Integer[Array, 'M']
```
patrick-kidger commented
What's your version of typeguard?
JuanFMontesinos commented
typeguard 3.0.2 and jaxtyping 0.2.22
patrick-kidger commented
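Right. So I think this is a quirk of typeguard v3: it tracks the argument's type annotation into the body of the function, so the reassignment `x = x[:25]` in `fn_2` is checked against `Integer[Array, "M"]`, and it rightly complains that `M` has changed (the argument bound it to 50, but the new value has length 25). `fn_1` doesn't hit this because `_x` has no annotation to check.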
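To make that mechanism concrete, here is a toy model in plain Python. It is not typeguard's or jaxtyping's actual code: `check_shape`, `ToyTypeCheckError`, `call_fn_1` and `call_fn_2` are hypothetical names, and the single shared `memo` dict is a simplifying assumption standing in for the named-dimension bindings that stay alive for the whole decorated call.

```python
# Toy model (hypothetical, not real typeguard/jaxtyping code) of why fn_2 fails.
# Simplifying assumption: one memo of named dimensions is shared by the argument
# check and by any later check of the same annotation inside the function body.


class ToyTypeCheckError(Exception):
    """Hypothetical stand-in for typeguard.TypeCheckError."""


def check_shape(shape: tuple, dims: tuple, memo: dict) -> None:
    """Bind each named dimension on first sight; require the same size afterwards."""
    if len(shape) != len(dims):
        raise ToyTypeCheckError(f"expected {len(dims)} dims, got {len(shape)}")
    for size, name in zip(shape, dims):
        bound = memo.setdefault(name, size)  # first check binds the name
        if bound != size:
            raise ToyTypeCheckError(f"dimension {name!r} bound to {bound}, got {size}")


def call_fn_2(arg_len: int) -> None:
    """Mimic fn_2: the reassignment `x = x[:25]` is re-checked against "M"."""
    memo = {}
    check_shape((arg_len,), ("M",), memo)  # argument check binds M = arg_len
    new_len = min(arg_len, 25)             # length of x[:25]
    check_shape((new_len,), ("M",), memo)  # assignment check against x's annotation


def call_fn_1(arg_len: int) -> None:
    """Mimic fn_1: `_x` is unannotated, so its assignment is never checked."""
    memo = {}
    check_shape((arg_len,), ("N",), memo)  # argument check binds N = arg_len
```

With a length-50 input, `call_fn_2` raises because `M` was bound to 50 and the slice has length 25, while `call_fn_1` does not; with a length-20 input the slice keeps length 20, so `M` stays consistent and no error is raised.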
FWIW, I believe typeguard v3 (and v4) still have some bugs that make them poorly compatible with jaxtyping, so we recommend either typeguard v2 or beartype.
JuanFMontesinos commented
Alright, thank you for your feedback! I may try beartype.