patrick-kidger/jaxtyping

Add type annotations for jaxtyped

nahaharo opened this issue · 1 comment

Hello.

I'm currently using jaxtyping with mypy.

When I run mypy on the test code below, it reports jaxtyped as untyped.

Here are my test code and my mypy.ini:

I think the solution is to add type annotations to jaxtyped.

from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Float, Array, jaxtyped
from typeguard import typechecked as typechecker
from datetime import datetime


@jaxtyped(typechecker=typechecker)
@partial(jax.jit, static_argnames=["scale"])
def test_linalg(
    a: Float[Array, "m n"], b: Float[Array, "n k"], scale: float
) -> Float[Array, "m k-1"]:
    return scale * (a @ b)[:, :-1]


if __name__ == "__main__":
    a = jnp.array(np.random.randn(5, 5))
    b = jnp.array(np.random.randn(5, 6))

    start_time = datetime.now()
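    # block_until_ready() waits for JAX's asynchronous dispatch to finish, so the
    # timing covers the actual computation (the first call also includes compilation).
    c = test_linalg(a, b, 2).block_until_ready()
    end_time = datetime.now()
    print('Duration: {}'.format(end_time - start_time))

    start_time = datetime.now()
    d = test_linalg(a, b, 2).block_until_ready()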
    end_time = datetime.now()
    print('Duration: {}'.format(end_time - start_time))

mypy.ini:

[mypy]
python_version = 3.10
plugins = numpy.typing.mypy_plugin

cache_dir = .mypy_cache/strict
allow_redefinition = True
strict_optional = True
show_error_codes = True
show_column_numbers = True
warn_no_return = True
disallow_any_unimported = True

strict = True
implicit_reexport = False

warn_unused_ignores = False
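Until the decorator itself is annotated, one possible local workaround is to relax, for the affected module only, the checks that strict = True turns on for untyped decorators and calls. This sketch assumes the test code lives in a module named test_linalg (a hypothetical name); disallow_untyped_calls may or may not be needed depending on how mypy reports the decorator call.

```ini
; Hypothetical module name: replace test_linalg with the real module path.
[mypy-test_linalg]
; Allow untyped decorators such as jaxtyped in this module only.
disallow_untyped_decorators = False
; Possibly also needed, since @jaxtyped(...) is a call to an untyped function.
disallow_untyped_calls = False
```

Note that the decorated function is still typed as Any by mypy, so its argument types are not checked at call sites.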

After reviewing my code, I realized this is not something that needs to be handled in this library.