google/flax

Add NNX support for legacy `jax.random.PRNGKey()`


Currently, there doesn't seem to be a straightforward way to checkpoint an NNX module that contains random keys (for example, one that uses dropout) with Orbax; see google/orbax#1105 (comment). This appears to be caused by the new typed JAX random keys (https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html), which NNX creates here: `key = jax.random.key(value)`. Although Orbax has added support for the new key type on its own (see google/orbax#620), saving `nnx.state(model)` when it contains arrays with `dtype=key<fry>` doesn't seem to be possible.
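To make the difference concrete: a new-style typed key is a scalar array with an opaque key dtype, while a legacy `jax.random.PRNGKey` is a plain `uint32` array of shape `(2,)` that checkpointers already handle. A minimal sketch, assuming the default threefry implementation (`key<fry>`), showing that `jax.random.key_data` and `jax.random.wrap_key_data` convert between the two forms:

```python
import jax

typed = jax.random.key(0)        # new-style typed key: dtype key<fry>, shape ()
legacy = jax.random.PRNGKey(0)   # legacy raw key: dtype uint32, shape (2,)

raw = jax.random.key_data(typed)           # typed key -> raw uint32 key data
rewrapped = jax.random.wrap_key_data(raw)  # raw key data -> typed key (default impl)

print(typed.dtype, typed.shape)
print(legacy.dtype, legacy.shape)
```

Under the default configuration, the raw data extracted from the typed key is the same `uint32` array that the legacy API returns for the same seed.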

Hey @cisprague, maybe you can convert the keys to the legacy raw format before serializing?
You could use something like:

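```python
import jax
def get_key_data(x):
  # Replace typed keys (dtype=key<fry>) with raw uint32 key data; keep other leaves.
  if hasattr(x, "dtype") and jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key):
    return jax.random.key_data(x)
  return x

serializable_state = jax.tree.map(get_key_data, state)  # state = nnx.state(model)
```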
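After loading a checkpoint saved this way, the raw key data has to be turned back into typed keys. A minimal sketch of the full round trip, assuming the default threefry implementation and that a template pytree with the same structure as the saved state is available; `is_typed_key`, `to_serializable` and `from_serializable` are hypothetical helper names, and a plain dict stands in for `nnx.state(model)`:

```python
import jax
import jax.numpy as jnp

def is_typed_key(x):
    # True for new-style typed key arrays (dtype key<...>), False for everything else.
    return hasattr(x, "dtype") and jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key)

def to_serializable(state):
    # Replace each typed key with its raw uint32 key data; other leaves pass through.
    return jax.tree.map(lambda x: jax.random.key_data(x) if is_typed_key(x) else x, state)

def from_serializable(raw_state, template):
    # Re-wrap raw key data wherever the template holds a typed key.
    # Assumes the default PRNG implementation (threefry, dtype key<fry>).
    return jax.tree.map(
        lambda raw, t: jax.random.wrap_key_data(raw) if is_typed_key(t) else raw,
        raw_state,
        template,
    )

# A plain dict standing in for nnx.state(model).
state = {"w": jnp.ones(3), "rng": jax.random.key(0)}
raw_state = to_serializable(state)              # only plain arrays, safe to checkpoint
restored = from_serializable(raw_state, state)  # typed keys again
```

Because `is_typed_key` only inspects `dtype`, the template could presumably also be an abstract pytree of `jax.ShapeDtypeStruct`s (for example one produced by `jax.eval_shape`) rather than a concrete state.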

See PRNGKeys