Checkpoint Manager using different directory paths for save and restore
svarunid opened this issue · 2 comments
I was trying to save and restore model and opt_state using CheckpointManager. I noticed two issues. While saving, the checkpoint manager creates a temporary directory and writes the checkpoint to that location.
tmp_step_dir = self._create_tmp_directory(save_directory)
This temporary directory appends an extra timestamp to the path passed in when the checkpoint manager is initialized.
However, while restoring, the manager resolves the path without that suffix, which causes a directory-not-found error.
directory = self.directory
path = self._get_save_directory(step, directory, item_name)
self._checkpointers[item_name].restore(
path, item=item, **kwargs
)
Here's my code:
import orbax.checkpoint as ocp
from etils import epath

path = epath.Path('/nmt-attention-checkpoints/')
mngr_options = ocp.CheckpointManagerOptions(
max_to_keep=3,
save_interval_steps=25
)
mngr = ocp.CheckpointManager(
path,
{
"model": ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler()),
"opt_state": ocp.AsyncCheckpointer(ocp.PyTreeCheckpointHandler())
},
mngr_options
)
mngr.save(
0,
{
"model": model,
"opt_state": opt_state
}
)
mngr.restore(step=mngr.latest_step())
Path where the checkpoints are saved: nmt-attention-checkpoints\0.orbax-checkpoint-tmp-1703334121930256
I get the following error while restoring my checkpoints: Checkpoint at \nmt-attention-checkpoints\0\model not found.
You're checkpointing asynchronously and restoring without waiting for the background save operation to complete. Add a wait_until_finished call before restoring. The orbax-checkpoint-tmp... suffix indicates that the checkpoint is not complete and cannot yet be restored (or, if there was a failure before finalization, the checkpoint is likely garbage).
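To make the failure mode concrete, here is a minimal standard-library sketch of the pattern described above: a background save writes into a temporary directory and only renames it to the final step directory once the write completes. This is a toy model, not Orbax's implementation. The class name ToyAsyncCheckpointer, the ".toy-tmp-" suffix, the data.txt file name, and the artificial delay are all assumptions made for illustration.

```python
import os
import tempfile
import threading
import time


class ToyAsyncCheckpointer:
    """Toy async saver: write into a temp dir, rename it when the write is done."""

    def __init__(self, root):
        self.root = root
        self._thread = None

    def final_dir(self, step):
        # Directory that restore() looks for, e.g. <root>/0
        return os.path.join(self.root, str(step))

    def save(self, step, text, delay=0.5):
        # The temp dir is created immediately, with a timestamp suffix (hypothetical naming).
        tmp = f"{self.final_dir(step)}.toy-tmp-{time.time_ns()}"
        os.makedirs(tmp)

        def _commit():
            with open(os.path.join(tmp, "data.txt"), "w") as f:
                f.write(text)
            time.sleep(delay)  # stand-in for slow serialization
            os.rename(tmp, self.final_dir(step))  # finalization step

        self._thread = threading.Thread(target=_commit)
        self._thread.start()

    def wait_until_finished(self):
        # Block until the background write has been finalized.
        if self._thread is not None:
            self._thread.join()
            self._thread = None

    def restore(self, step):
        path = os.path.join(self.final_dir(step), "data.txt")
        if not os.path.exists(path):
            raise FileNotFoundError(f"Checkpoint at {self.final_dir(step)} not found.")
        with open(path) as f:
            return f.read()


def demo():
    with tempfile.TemporaryDirectory() as root:
        ckpt = ToyAsyncCheckpointer(root)
        ckpt.save(0, "weights")
        try:
            ckpt.restore(0)  # too early: only the temp dir exists
            early = "restored"
        except FileNotFoundError:
            early = "not found"
        ckpt.wait_until_finished()  # after this, the final dir exists
        return early, ckpt.restore(0)
```

In the code from the issue, the corresponding change would be to call mngr.wait_until_finished() between mngr.save(...) and mngr.restore(...).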
Thanks! I didn't notice that CheckpointManager has its own wait_until_finished method.