Best Practice for using hk.next_rng_key()
wbrenton opened this issue · 2 comments
wbrenton commented
What is the best practice when using hk.next_rng_key() to ensure reproducibility? Is there any sort of functionality like hk.set_rng_seed(42)?
tomhennigan commented
Hi @wbrenton, `hk.next_rng_key()` is fully deterministic. Keys are split from the key you pass to `init` or `apply`, so if you pass the same key you will reproduce the same result.
Here is an example:
```python
import haiku as hk
import jax
import jax.numpy as jnp


def f(x):
    # Draw a fresh key (split from the key passed to init/apply) and use it
    # to sample a dropout mask with rate 0.5.
    return hk.dropout(hk.next_rng_key(), 0.5, x)


f = hk.transform(f)

init_key = jax.random.PRNGKey(42)
x = jnp.ones([4])
params = f.init(init_key, x)  # NOTE: Params are empty in this example

# Passing the same key gives the same dropout mask.
key = jax.random.PRNGKey(42)
out1 = f.apply(params, key, x)
out2 = f.apply(params, key, x)
assert (out1 == out2).all()

# Passing a different key gives a different dropout mask.
# (With only 4 elements, two keys can in principle draw the same mask,
# with probability 1/16; a larger input makes this check more robust.)
different_key = jax.random.PRNGKey(123)
out3 = f.apply(params, different_key, x)
assert not (out1 == out3).all()
```
wbrenton commented
This makes perfect sense! Thank you