New interface does not support empty dicts in pytrees
PhilipVinc commented
```python
import orbax.checkpoint as ocp

pytree = {'a': 0, 'b': {}}
options = ocp.CheckpointManagerOptions()
mngr = ocp.CheckpointManager(
    ocp.test_utils.create_empty('/tmp/ckpt22/'),
    options=options,
)
mngr.save(0, args=ocp.args.StandardSave(pytree))
# mngr.save(0, args=ocp.args.PyTreeSave(pytree)) also warns
```
This raises a warning that `aggregate` is no longer supported.
However, I suspect this warning should not be raised here.
This is probably because `utils.is_empty_or_leaf` is used as the `is_leaf` predicate in the code that fills in the save args inside `pytree_checkpoint_handler.py`:
```python
# Because of empty states, the user-provided args may not contain
# all necessary arguments. These should be filled in with default args.
save_args = jax.tree_util.tree_map(
    _maybe_set_default_save_args,
    item,
    item if save_args is None else save_args,
    is_leaf=utils.is_empty_or_leaf,
)
```
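To illustrate why the `is_leaf` predicate matters, the sketch below imitates the shape of that snippet with plain `jax.tree_util.tree_map`. `SaveArgsSketch`, `is_empty_or_leaf_sketch` and `maybe_set_default_sketch` are hypothetical stand-ins, not Orbax's real `SaveArgs`, `utils.is_empty_or_leaf` or `_maybe_set_default_save_args`, so this only shows the traversal, not how Orbax actually emits the warning. Without a predicate, `tree_map` never visits the empty dict under `'b'`; with a predicate that treats empty containers as leaves, `{}` is passed to the default-filling function and receives its own save-args entry.

```python
import dataclasses

import jax


@dataclasses.dataclass(frozen=True)
class SaveArgsSketch:
    """Hypothetical placeholder for per-leaf save args (NOT orbax's SaveArgs)."""
    is_default: bool = True


def is_empty_or_leaf_sketch(x):
    # Hypothetical stand-in for utils.is_empty_or_leaf: an empty
    # dict/list/tuple counts as a leaf, and so does any non-container value.
    if isinstance(x, (dict, list, tuple)):
        return len(x) == 0
    return True


def maybe_set_default_sketch(value, arg):
    # Keep a user-provided SaveArgsSketch; otherwise fall back to a default.
    del value  # not needed in this sketch
    return arg if isinstance(arg, SaveArgsSketch) else SaveArgsSketch()


def fill_save_args(item, save_args=None, is_leaf=None):
    # Same call shape as the snippet above: map over `item` together with
    # either the user-provided args or `item` itself.
    return jax.tree_util.tree_map(
        maybe_set_default_sketch,
        item,
        item if save_args is None else save_args,
        is_leaf=is_leaf,
    )


pytree = {'a': 0, 'b': {}}

# Default traversal: '{}' has no leaves, so 'b' is left untouched.
print(fill_save_args(pytree))
# Empty containers as leaves: 'b' also receives a default save-args entry.
print(fill_save_args(pytree, is_leaf=is_empty_or_leaf_sketch))
```

In the second call the empty dict is treated like any other leaf, so whatever logic runs on default save args also runs for it, even though it holds no data.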
niketkumar commented
A fix has been pushed and should be available in release 0.5.7. Please let us know if the issue persists.