Add type annotations for `jaxtyped`
nahaharo opened this issue · 1 comment
nahaharo commented
Hello.
I'm currently using jaxtyping with mypy.
When I use the test code below, mypy treats `jaxtyped` as untyped.
Here are my test code and my mypy.ini.
I think the solution to this issue is to add type annotations to `jaxtyped`.
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
from jaxtyping import Float, Array, jaxtyped
from typeguard import typechecked as typechecker
from datetime import datetime
@partial(jax.jit, static_argnames=["scale"])
@jaxtyped(typechecker=typechecker)
def test_linalg(
    a: Float[Array, "m n"], b: Float[Array, "n k"], scale: float
) -> Float[Array, "m k-1"]:
    """Return ``scale * (a @ b)`` with the last column of the product dropped.

    ``jax.jit`` is applied outermost so that the jaxtyping/typeguard shape and
    dtype checks run while JAX traces the function, not on every call.
    Calls that hit the compilation cache skip the Python-level check.
    Any new input shape, dtype or static ``scale`` value triggers a retrace,
    which re-runs the check.

    Args:
        a: Float array of shape ``(m, n)``.
        b: Float array of shape ``(n, k)``.
        scale: Multiplier applied to the product. It is marked static for
            ``jax.jit``, so each distinct value is compiled separately.

    Returns:
        Float array of shape ``(m, k - 1)``.
    """
    # Drop the final column of the (m, k) product to obtain (m, k - 1).
    return scale * (a @ b)[:, :-1]
if __name__ == "__main__":
    # Local imports: only the benchmark below needs them.
    import time
    from datetime import timedelta

    a = jnp.array(np.random.randn(5, 5))
    b = jnp.array(np.random.randn(5, 6))

    # First call: includes tracing and XLA compilation of the jitted function.
    # perf_counter() is monotonic, unlike datetime.now(). JAX dispatches work
    # asynchronously, so block_until_ready() waits for the result; otherwise
    # the timing would only cover dispatch, not the computation itself.
    start_time = time.perf_counter()
    c = test_linalg(a, b, 2).block_until_ready()
    end_time = time.perf_counter()
    # timedelta keeps the original "H:MM:SS.ffffff" output format.
    print('Duration: {}'.format(timedelta(seconds=end_time - start_time)))

    # Second call with the same shapes, dtypes and static `scale` reuses the
    # cached compilation, so this measures steady-state execution.
    start_time = time.perf_counter()
    d = test_linalg(a, b, 2).block_until_ready()
    end_time = time.perf_counter()
    print('Duration: {}'.format(timedelta(seconds=end_time - start_time)))
[mypy]
python_version = 3.10
plugins = numpy.typing.mypy_plugin
cache_dir = .mypy_cache/strict
allow_redefinition = True
strict_optional = True
show_error_codes = True
show_column_numbers = True
warn_no_return = True
disallow_any_unimported = True
strict = True
implicit_reexport = False
warn_unused_ignores = False
nahaharo commented
After reviewing my code, I realized this is not something that needs to be handled in this library.