google-deepmind/dm-haiku

Best Practice for using hk.next_rng_key()

wbrenton opened this issue · 2 comments

What is the best practice when using hk.next_rng_key() to ensure reproducibility? Is there any functionality like hk.set_rng_seed(42)?

Hi @wbrenton, hk.next_rng_key() is fully deterministic. Every key it returns is split from the key you pass to init or apply, so passing the same key reproduces the same result.

Here is an example:

import haiku as hk
import jax
import jax.numpy as jnp

def f(x):
  return hk.dropout(hk.next_rng_key(), 0.5, x)

f = hk.transform(f)

init_key = jax.random.PRNGKey(42)
x = jnp.ones([4])
params = f.init(init_key, x)  # NOTE: Params are empty in this example

# Passing the same key gives the same dropout mask.
key = jax.random.PRNGKey(42)
out1 = f.apply(params, key, x)
out2 = f.apply(params, key, x)
assert (out1 == out2).all()

# Passing a different key gives a different dropout mask.
different_key = jax.random.PRNGKey(123)
out3 = f.apply(params, different_key, x)
assert not (out1 == out3).all()

This makes perfect sense! Thank you