google/orbax

Support checkpointing new-style `jax.random.key`

hylkedonker opened this issue · 6 comments

According to JEP 9263, jax.random.PRNGKey will be deprecated in favour of jax.random.key. However, it seems that Orbax can currently only checkpoint old-style keys. Trying to checkpoint a jax.random.key raises the exception

TypeError: Cannot interpret 'key' as a data type

Here is a minimal example tested on HEAD:

from pathlib import Path
from tempfile import TemporaryDirectory
import jax
import orbax.checkpoint

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

# Try to checkpoint a new-style JAX pseudo-random number generator key.
key = jax.random.key(42)

# OK:
# Saving the raw key data (the old-style uint32 bits) works as expected.
ckpt_oldstyle = {'key': jax.random.key_data(key)}
with TemporaryDirectory() as tmpdir:
    orbax_checkpointer.save(Path(tmpdir) / 'checkpoint', ckpt_oldstyle)

# Fails:
# TypeError: Cannot interpret 'key<fry>' as a data type
ckpt_newstyle = {'key': key}
with TemporaryDirectory() as tmpdir:
    orbax_checkpointer.save(Path(tmpdir) / 'checkpoint', ckpt_newstyle)

Are there any plans to also support jax.random.key?
Thanks in advance,
Hylke

You are not the first person to request this, so I'll just say that the issue is on our radar, but has not yet risen to a high priority. In the meantime, you'll probably have to convert the key to a jax.Array to save it.

Thanks for reporting though, this will affect our prioritization going forward.
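
Something along these lines should work as a stopgap: save the raw key bits with jax.random.key_data and rebuild a typed key after restore with jax.random.wrap_key_data (available in recent JAX releases).

from pathlib import Path
from tempfile import TemporaryDirectory

import jax
import jax.numpy as jnp
import orbax.checkpoint

checkpointer = orbax.checkpoint.PyTreeCheckpointer()
key = jax.random.key(42)

with TemporaryDirectory() as tmpdir:
    path = Path(tmpdir) / 'checkpoint'
    # Save the uint32 bit representation instead of the typed key.
    checkpointer.save(path, {'key': jax.random.key_data(key)})
    restored = checkpointer.restore(path)
    # Re-create a typed key from the restored bits.
    key = jax.random.wrap_key_data(jnp.asarray(restored['key']))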

Great, thanks for your response!

@hylkedonker I am looking into adding support for storing jax.random.key in Orbax. I have a couple of questions about how these keys are stored:

  1. How often should the random keys be saved? Is it necessary to store them at every training step?
  2. Do all machines share the same keys, or does each machine need to store its own?

Thanks for getting in touch.
I use pseudo-random number generator keys to train variational inference (VI) models. Concretely, at each training step I consume a PRNG key to form a Monte Carlo estimate of the ELBO (evidence lower bound). In practice, I make the key part of Flax's TrainState, which I checkpoint every now and then.
So to get back to your questions:

  1. The key needs to be tracked at every training step, but I don't save the TrainState that often.
  2. I currently don't have a lot of experience sharding the computation across different machines. But I would imagine that one might pmap the Monte Carlo estimate over different machines (so that each machine gets its own key).
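
Schematically (and greatly simplified), the per-step pattern looks something like the sketch below, where elbo_loss is just a stand-in for the real model code and the key is converted to raw bits only when building the state to checkpoint:

import jax
import jax.numpy as jnp

def elbo_loss(params, key):
    # Stand-in for the real Monte Carlo ELBO estimate.
    sample = jax.random.normal(key, (32,))
    return jnp.mean((sample - params) ** 2)

params = jnp.zeros(())
key = jax.random.key(0)

for step in range(1_000):
    # One fresh subkey is consumed per training step.
    key, subkey = jax.random.split(key)
    grads = jax.grad(elbo_loss)(params, subkey)
    params = params - 1e-2 * grads
    if step % 100 == 0:
        # Only every now and then is the state (with the key stored as
        # raw bits for now) handed to the Orbax checkpointer.
        state = {'params': params, 'key': jax.random.key_data(key)}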

I hope this helps. If not, let me know how I can further clarify.

Thanks for sharing. We have a new release, orbax-checkpoint 0.5.1, which includes two new handlers: JaxRandomKeyCheckpointHandler and NumpyRandomKeyCheckpointHandler. We recommend using these outside of the train state PyTree, since the random keys are closer to metadata.

Documentation is here. Usage examples can be found here.
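
Roughly, usage looks like the following (the ocp.args.JaxRandomKeySave / JaxRandomKeyRestore names follow the standard CheckpointArgs pattern; please double-check them against the docs linked above):

from pathlib import Path
from tempfile import TemporaryDirectory

import jax
import orbax.checkpoint as ocp

key = jax.random.key(42)

with TemporaryDirectory() as tmpdir:
    path = Path(tmpdir) / 'random_key'
    ckptr = ocp.Checkpointer(ocp.JaxRandomKeyCheckpointHandler())
    # The CheckpointArgs class names below are the usual Orbax pattern;
    # consult the documentation linked above for the exact API.
    ckptr.save(path, args=ocp.args.JaxRandomKeySave(key))
    restored_key = ckptr.restore(path, args=ocp.args.JaxRandomKeyRestore())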

Great work, thanks!