google-deepmind/dm-haiku

Mypy error from `next_rng_key` type inconsistency with jax `PRNGKeyArray`

hylkedonker opened this issue · 0 comments

Hi,

It seems that mypy (version 0.942) complains that the random key returned by hk.next_rng_key() is not compatible with JAX's PRNGKeyArray type, which is part of the annotated type of the key argument in various jax.random samplers.

Example

import jax
import haiku as hk

def sample_phi(alpha: float):
    phi = jax.random.gamma(hk.next_rng_key(), a=alpha)
    return phi

Error

example.py:5: error: Argument 1 to "gamma" has incompatible type "ndarray"; expected "Union[Array, PRNGKeyArray]"

Apart from explicitly silencing these errors in mypy, is there another way to fix them?

Thanks in advance,

Hylke

Environment

dm-haiku==0.0.9
jax==0.3.25
jaxlib==0.3.25
mypy==0.942