Mypy error from `next_rng_key` type inconsistency with jax `PRNGKeyArray`
hylkedonker opened this issue
Hi,
It seems that mypy (version 0.942) complains that Haiku's random key, generated by `hk.next_rng_key()`, is not compatible with JAX's `PRNGKeyArray` type. The latter appears in the type annotation of the `key` argument in various `jax.random` samplers.
Example
```python
import jax
import haiku as hk

def sample_phi(alpha: float):
    phi = jax.random.gamma(hk.next_rng_key(), a=alpha)
    return phi
```
Error
```
example.py:5: error: Argument 1 to "gamma" has incompatible type "ndarray"; expected "Union[Array, PRNGKeyArray]"
```
Apart from explicitly silencing these errors in mypy, is there another way to fix them?
Thanks in advance,
Hylke
Environment
```
dm-haiku==0.0.9
jax==0.3.25
jaxlib==0.3.25
mypy==0.942
```