google/orbax

New interface does not support empty dicts in pytrees


    import orbax.checkpoint as ocp

    pytree = {'a': 0, 'b': {}}
    options = ocp.CheckpointManagerOptions()
    mngr = ocp.CheckpointManager(
      ocp.test_utils.create_empty('/tmp/ckpt22/'),
      options=options,
    )

    mngr.save(0, args=ocp.args.StandardSave(pytree))
    # mngr.save(0, args=ocp.args.PyTreeSave(pytree)) also warns

This raises a warning that aggregate is no longer supported.

However, I don't think it should, since the pytree contains only a scalar and an empty dict.

This is probably because pytree_checkpoint_handler.py passes utils.is_empty_or_leaf as the is_leaf function when it fills in the default save args:

    # Because of empty states, the user-provided args may not contain
    # all necessary arguments. These should be filled in with default args.
    save_args = jax.tree_util.tree_map(
        _maybe_set_default_save_args,
        item,
        item if save_args is None else save_args,
        is_leaf=utils.is_empty_or_leaf,
    )
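A minimal pure-Python sketch of why is_leaf matters in the snippet above. It assumes jax.tree_util.tree_map semantics: an empty dict is a pytree node with no children, so the mapped function never sees it unless is_leaf marks it as a leaf. The tree_map and is_empty_or_leaf functions below are hypothetical, dict-only stand-ins, not the real JAX or orbax implementations.

```python
# Hypothetical stand-in for orbax's utils.is_empty_or_leaf (assumption):
# empty containers and non-container values are both treated as leaves.
def is_empty_or_leaf(x):
    if isinstance(x, (dict, list, tuple)):
        return len(x) == 0
    return True


def tree_map(fn, tree, is_leaf=None):
    """Dict-only tree_map that mimics jax.tree_util.tree_map semantics (sketch)."""
    if is_leaf is not None and is_leaf(tree):
        return fn(tree)
    if isinstance(tree, dict):
        # An empty dict yields an empty dict here: fn is never called on it.
        return {k: tree_map(fn, v, is_leaf) for k, v in tree.items()}
    return fn(tree)


def visited_leaves(tree, is_leaf=None):
    """Return every value the mapped function is called on, in visit order."""
    seen = []

    def record(x):
        seen.append(x)
        return x

    tree_map(record, tree, is_leaf=is_leaf)
    return seen


pytree = {'a': 0, 'b': {}}
print(visited_leaves(pytree))                            # default: only the scalar
print(visited_leaves(pytree, is_leaf=is_empty_or_leaf))  # the empty dict is visited too
```

Without is_leaf only 0 is visited; with is_empty_or_leaf the empty dict is also passed to the mapped function (in the handler, _maybe_set_default_save_args), so it is handled as if it were a leaf value rather than skipped.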

A fix has been pushed. It should be available in release 0.5.7. Please let us know if the issue persists.