MichaelTMatthews/Craftax

How to load a saved policy

I've been testing out training your PPO-RNN baseline and was able to reproduce the results, so that's great!

I saved the model with your --save_policy argument, but I couldn't find any way to load the saved policy back in.

For example, suppose I trained the model for 1B steps and saved the policy. Later, I want to train for another 1B steps starting from that checkpoint, rather than retraining all 2B steps from scratch.

Could you show me how I can do this?

Thanks!

Hi aszala, right now there's no functionality for this; the --save_policy argument was intended for visualising learned policies.
However, it would be quite simple to add using the Orbax checkpointer. Something like this:

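import os

from flax.training.train_state import TrainState
from orbax.checkpoint import (
    PyTreeCheckpointer,
    CheckpointManagerOptions,
    CheckpointManager,
)

# path: the run directory that contains the saved "policies" folder
# train_state: an already-initialised TrainState with the same structure
orbax_checkpointer = PyTreeCheckpointer()
options = CheckpointManagerOptions(max_to_keep=1, create=True)
checkpoint_manager = CheckpointManager(
    os.path.join(path, "policies"), orbax_checkpointer, options
)
# Restore the most recently saved step into the initialised train state
train_state = checkpoint_manager.restore(
    checkpoint_manager.latest_step(), items=train_state
)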

Note that you'll have to do this after the train state has been initialised, because restore uses it as the target structure for the saved arrays.