Add NNX support for legacy `jax.random.PRNGKey()`
cisprague commented
Currently, it doesn't seem possible to straightforwardly checkpoint (with Orbax) an NNX module that includes random keys (as with dropout); see google/orbax#1105 (comment). This seems to be due to the new JAX typed random keys (https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html), which are used here (line 186 in fc19c5d). Checkpointing an nnx.state(model) that includes leaves with dtype=key<fry> doesn't seem to be possible.
cgarciae commented
Hey @cisprague, maybe you can convert to the old format before serializing?
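For reference, here is a minimal sketch contrasting the two key formats using only public jax.random calls. It assumes the default threefry PRNG implementation (hence key<fry>), and the variable names are illustrative. A new-style typed key carries an extended dtype, while jax.random.key_data unwraps it into the same plain uint32 array that the legacy jax.random.PRNGKey returns.

```python
import jax
import jax.numpy as jnp

# New-style typed key: its dtype is an extended PRNG dtype (key<fry> for threefry).
typed_key = jax.random.key(0)

# Legacy key: a plain uint32 array of shape (2,), which ordinary array
# serializers can handle.
legacy_key = jax.random.PRNGKey(0)

# jax.random.key_data unwraps a typed key into its raw uint32 representation.
raw = jax.random.key_data(typed_key)

print(typed_key.dtype)                          # key<fry>
print(raw.dtype, raw.shape)                     # uint32 (2,)
print(bool(jnp.array_equal(raw, legacy_key)))  # True
```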
You could use something like:
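import jax
def get_key_data(x):
    # use jax.random.key_data to turn typed keys (dtype=key<fry>) into raw uint32 arrays
    is_key = hasattr(x, "dtype") and jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key)
    return jax.random.key_data(x) if is_key else x
serializable_state = jax.tree.map(get_key_data, state)  # state = nnx.state(model)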
See PRNGKeys
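Converting to key data only covers saving. Here is a minimal sketch of the reverse step after loading. It assumes you still have a template pytree whose key leaves are typed (for example, the state of a freshly initialized model), and that the keys use the default PRNG implementation, because wrap_key_data is called without an explicit impl. The helpers is_typed_key, to_key_data and from_key_data are hypothetical. A plain dict stands in both for nnx.state(model) and for whatever the checkpointer restores.

```python
import jax
import jax.numpy as jnp


def is_typed_key(x):
    # True for new-style typed PRNG key arrays (e.g. dtype=key<fry>).
    return hasattr(x, "dtype") and jnp.issubdtype(x.dtype, jax.dtypes.prng_key)


def to_key_data(tree):
    # Before saving: replace typed keys with their raw uint32 key data.
    return jax.tree.map(
        lambda x: jax.random.key_data(x) if is_typed_key(x) else x, tree
    )


def from_key_data(template, restored):
    # After loading: re-wrap leaves that were typed keys in the template.
    # Assumes the default PRNG implementation (no impl= passed).
    return jax.tree.map(
        lambda ref, x: jax.random.wrap_key_data(x) if is_typed_key(ref) else x,
        template,
        restored,
    )


# Hypothetical usage with a dict standing in for nnx.state(model).
state = {"dropout_key": jax.random.key(0), "w": jnp.ones((2,))}
saved = to_key_data(state)            # what you would hand to the checkpointer
restored = from_key_data(state, saved)
```

Keeping the typed-key leaves in a template means the saved checkpoint stays a plain uint32 array, so the serializer never sees the extended dtype.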