Support checkpointing new-style `jax.random.key`
hylkedonker opened this issue · 6 comments
According to JEP 9263, jax.random.PRNGKey will be deprecated in favour of jax.random.key. However, it seems that Orbax can currently only checkpoint old-style keys. Trying to checkpoint a jax.random.key raises the exception

TypeError: Cannot interpret 'key<fry>' as a data type

Here is a minimal example tested on HEAD:
```python
from pathlib import Path
from tempfile import TemporaryDirectory

import jax
import orbax.checkpoint

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

# Try to checkpoint a JAX new-style pseudo-random number generator key.
key = jax.random.key(42)

# OK:
# The old-style bit data of the PRNG key array works as expected.
ckpt_oldstyle = {'key': jax.random.key_data(key)}
with TemporaryDirectory() as tmpdir:
    orbax_checkpointer.save(Path(tmpdir) / 'checkpoint', ckpt_oldstyle)

# Fails:
# TypeError: Cannot interpret 'key<fry>' as a data type
ckpt_newstyle = {'key': key}
with TemporaryDirectory() as tmpdir:
    orbax_checkpointer.save(Path(tmpdir) / 'checkpoint', ckpt_newstyle)
```
Are there any plans to also support jax.random.key?
Thanks in advance,
Hylke
You are not the first person to request this, so I'll just say that the issue is on our radar, but has not yet risen to a high priority. In the meantime, you'll probably have to convert the key to a jax.Array to save it.
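As a minimal sketch of that workaround (not an official Orbax recipe): save the key's underlying bits with jax.random.key_data and rebuild a typed key with jax.random.wrap_key_data after restoring.

```python
from pathlib import Path
from tempfile import TemporaryDirectory

import jax
import orbax.checkpoint

orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
key = jax.random.key(42)

with TemporaryDirectory() as tmpdir:
    path = Path(tmpdir) / 'checkpoint'
    # Save the raw uint32 bits instead of the typed key array.
    orbax_checkpointer.save(path, {'key_data': jax.random.key_data(key)})
    # On restore, rebuild a new-style key from the saved bits
    # (this assumes the default threefry implementation).
    restored = orbax_checkpointer.restore(path)
    restored_key = jax.random.wrap_key_data(restored['key_data'])
```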
Thanks for reporting though, this will affect our prioritization going forward.
Great, thanks for your response!
@hylkedonker I am looking into adding support for storing jax.random.key in Orbax. I have a couple of questions about how these states are stored:
- How often should the random keys be saved? Is it necessary to store them at every training step?
- Do all machines share the same keys, or does each machine need to store its own?
Thanks for getting in touch.
I use the pseudo-random number generator keys to train variational inference (VI) models. Concretely, each training step I consume a PRNG key to make a Monte Carlo estimate of the ELBO (evidence lower bound). In practice, I make the key part of Flax's TrainState, which I checkpoint every now and then.
So to get back to your questions:
- The key needs to be tracked every training step, but I don't save the TrainState every training step.
- I currently don't have a lot of experience sharding the computation across different machines, but I would imagine that one might pmap the Monte Carlo estimate over different machines, so that each machine gets its own key (see the sketch below).
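To illustrate the second point, here is a toy sketch (not my actual training code) of giving each device its own new-style key for a pmapped estimate:

```python
import jax

# Toy sketch: split one new-style key into one key per device, then pmap
# the Monte Carlo estimate so each device draws independent samples.
key = jax.random.key(42)
device_keys = jax.random.split(key, jax.device_count())

def mc_estimate(device_key):
    # Stand-in for the real ELBO Monte Carlo estimate.
    return jax.random.normal(device_key, (128,)).mean()

estimates = jax.pmap(mc_estimate)(device_keys)
```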
I hope this helps. If not, let me know how I can further clarify.
Thanks for sharing. We have a new release, orbax-checkpoint 0.5.1, which includes two new handlers: JaxRandomKeyCheckpointHandler and NumpyRandomKeyCheckpointHandler. We recommend using these outside of the train-state PyTree, because the random keys are more like metadata.
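A rough usage sketch, assuming the new handler is wrapped in an ordinary Checkpointer; the exact argument class names and constructor parameters used below are assumptions, so please check the orbax-checkpoint documentation:

```python
from pathlib import Path
from tempfile import TemporaryDirectory

import jax
import orbax.checkpoint as ocp

key = jax.random.key(42)

# Wrap the new handler in a regular Checkpointer.
# NOTE: the ocp.args.* class names below are assumptions; consult the docs.
checkpointer = ocp.Checkpointer(ocp.JaxRandomKeyCheckpointHandler())
with TemporaryDirectory() as tmpdir:
    path = Path(tmpdir) / 'random_key'
    checkpointer.save(path, args=ocp.args.JaxRandomKeySave(key))
    restored_key = checkpointer.restore(path, args=ocp.args.JaxRandomKeyRestore())
```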
Great work, thanks!