patrick-kidger/jaxtyping

BUG: Exception if an array argument is reassigned inside the function under the same name

JuanFMontesinos opened this issue · 4 comments

The code is a thousand words.

from typeguard import typechecked
from jaxtyping import Integer, Array, jaxtyped
import jax.numpy as jnp

@jaxtyped
@typechecked
def fn_1(x: Integer[Array, "N"]) -> None:
    _x = x[:25]
    return

@jaxtyped
@typechecked
def fn_2(x: Integer[Array, "M"]) -> None:
    x = x[:25]
    return

x = jnp.arange(50)
y = jnp.arange(50)

fn_1(x)
fn_2(y)

Console output

Traceback (most recent call last):
  File "l/jaxt_bug.py", line 21, in <module>
    fn_2(y)
  File "/env/lru/lib/python3.10/site-packages/jaxtyping/_decorator.py", line 129, in wrapped_fn
    return fn(*args, **kwargs)
  File "l/jaxt_bug.py", line 14, in fn_2
    x = x[:25]
  File "/env/lru/lib/python3.10/site-packages/typeguard/_functions.py", line 254, in check_variable_assignment
    check_type_internal(value, annotation, memo)
  File "/env/lru/lib/python3.10/site-packages/typeguard/_checkers.py", line 680, in check_type_internal
    raise TypeCheckError(f"is not an instance of {qualified_name(origin_type)}")
typeguard.TypeCheckError: value assigned to x (jaxlib.xla_extension.ArrayImpl) is not an instance of jaxtyping.Integer[Array, 'M']

What's your version of typeguard?

typeguard 3.0.2 and jaxtyping 0.2.22

Right. So I think this is a quirk of typeguard v3: it carries the argument's type annotation into the body of the function, checks the reassignment x = x[:25] against Integer[Array, "M"], and rightly complains that M has changed (from 50 to 25).
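To make that explanation concrete, here is a toy model in plain Python. It is not typeguard's or jaxtyping's actual code. It assumes a single per-call dictionary (memo) that binds each dimension name the first time it is seen, and that typeguard v3 re-checks any assignment to an annotated name, including a function argument. The names check_dims, call_fn_1, call_fn_2 and ShapeMismatch are hypothetical and only mirror fn_1 and fn_2 from the report.

```python
# Toy model (NOT the real typeguard/jaxtyping implementation) of why fn_2 fails.
# Assumption: all checks made during one call share one memo of dimension
# bindings, so a name like "M" is fixed by the first check that sees it.


class ShapeMismatch(TypeError):
    """Hypothetical stand-in for typeguard's TypeCheckError."""


def check_dims(shape, dims, memo):
    """Check `shape` against symbolic dimension names, binding new names in `memo`."""
    if len(shape) != len(dims):
        raise ShapeMismatch(f"expected {len(dims)} dims, got {len(shape)}")
    for size, name in zip(shape, dims):
        bound = memo.setdefault(name, size)  # first sighting binds the name
        if bound != size:
            raise ShapeMismatch(f"dimension {name!r} was {bound}, now {size}")


def call_fn_1(shape):
    memo = {}                           # one memo per call
    check_dims(shape, ("N",), memo)     # argument check binds N
    _x_shape = (min(shape[0], 25),)     # `_x = x[:25]`: new, unannotated name, not checked
    return memo


def call_fn_2(shape):
    memo = {}
    check_dims(shape, ("M",), memo)     # argument check binds M
    x_shape = (min(shape[0], 25),)      # `x = x[:25]` reuses the annotated name...
    check_dims(x_shape, ("M",), memo)   # ...so it is re-checked against the same M
    return memo
```

With a length-50 input, call_fn_1 succeeds while call_fn_2 raises, matching the traceback above. In this model, binding the slice to a new name (as fn_1 does) is enough to avoid the check.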

FWIW, I believe typeguard v3 (and v4) still have some bugs that make them poorly compatible with jaxtyping, so we recommend either typeguard v2 or beartype.

Alright, thank you for your feedback! I may try beartype.