Parse structure of a saved PyTree checkpoint
minotru opened this issue · 1 comments
Hi,
Is there a way to parse the structure of a saved PyTree checkpoint?
I found AbstractCheckpointer.structure, but it is deprecated.
CONTEXT:
I have a checkpoint saved with orbax's PyTreeCheckpointHandler that contains sharded jax.Array values. I am trying to load it on a machine with a single CPU device, and orbax fails because the saved sharding requires 8 devices.
Here is where it fails:
>>> checkpoint_manager.restore(checkpoint_manager.latest_step(), items={"state" : None})
...
File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/checkpoint_manager.py", line 472, in restore
restored_items = self._restore_impl(
File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/checkpoint_manager.py", line 504, in _restore_impl
restored[item_name] = self._checkpointers[item_name].restore(
File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/checkpointer.py", line 99, in restore
restored = self._handler.restore(directory, *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 1065, in restore
restored_item = asyncio.run(
File "/usr/lib/python3.10/asyncio/runners.py", line 44, in run
return loop.run_until_complete(main)
File "/usr/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
return future.result()
File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/pytree_checkpoint_handler.py", line 890, in _maybe_deserialize
deserialized_batches += await asyncio.gather(*deserialized_batches_ops)
File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/type_handlers.py", line 1260, in deserialize
_deserialize_sharding_from_json_string(serialized_string.item())
File "/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/type_handlers.py", line 135, in _deserialize_sharding_from_json_string
np.array(jax.devices()).reshape(shape), axis_names=axis_names
ValueError: cannot reshape array of size 1 into shape (4,2)
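For anyone reproducing this without orbax installed, the last frame comes down to a plain NumPy reshape. The sketch below is an illustration, not orbax code: `available_devices` is a made-up stand-in for what `jax.devices()` returns on a CPU-only host, and `build_device_grid` is a hypothetical helper that mirrors the failing line.

```python
import numpy as np

# Assumption: a stand-in for jax.devices() on a CPU-only host (one device).
available_devices = ["cpu:0"]
# Mesh shape taken from the error message in the traceback above.
saved_mesh_shape = (4, 2)


def build_device_grid(devices, mesh_shape):
    """Arrange the devices into a mesh-shaped grid, like the failing line does."""
    return np.array(devices).reshape(mesh_shape)


try:
    build_device_grid(available_devices, saved_mesh_shape)
    reshape_error = None
except ValueError as err:
    # One device cannot fill a 4x2 grid, so NumPy refuses to reshape.
    reshape_error = str(err)

print(reshape_error)

# With 8 devices the same call succeeds.
eight_device_grid = build_device_grid([f"dev:{i}" for i in range(8)], saved_mesh_shape)
print(eight_device_grid.shape)
```

So the error is independent of the checkpoint contents: the saved mesh shape simply needs 4 * 2 = 8 devices, and only 1 is available.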
To load a checkpoint with a different sharding, I need to pass restore_args: a tree of ArrayRestoreArgs with the same structure as the saved checkpoint.
The problem is that I do not know the structure of the saved checkpoint, so I cannot create restore_args with the proper structure.
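To make "same structure" concrete, here is a minimal sketch in plain Python of what I would do if the structure were known. Everything in it is an assumption for illustration: `saved_structure` is an invented example tree, `RestoreArgsPlaceholder` is a hypothetical stand-in for ArrayRestoreArgs, and `mirror_tree` is a hand-written helper (in real code, `jax.tree_util.tree_map` would do the same mapping).

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class RestoreArgsPlaceholder:
    """Hypothetical stand-in for orbax's ArrayRestoreArgs (not the real class)."""
    sharding: str


def mirror_tree(structure: Any, make_leaf: Callable[[Any], Any]) -> Any:
    """Return a tree shaped like `structure`, with every leaf replaced by make_leaf(leaf)."""
    if isinstance(structure, dict):
        return {key: mirror_tree(value, make_leaf) for key, value in structure.items()}
    if isinstance(structure, (list, tuple)):
        # Handles plain lists and tuples only; namedtuples are out of scope here.
        return type(structure)(mirror_tree(value, make_leaf) for value in structure)
    return make_leaf(structure)


# Assumption: an invented structure. The real one is exactly what is unknown here.
saved_structure = {
    "params": {"dense": {"kernel": None, "bias": None}},
    "step": None,
}

# One restore-args leaf per checkpoint leaf, all targeting a single device.
restore_args = mirror_tree(
    saved_structure,
    lambda _leaf: RestoreArgsPlaceholder(sharding="single-device"),
)
print(restore_args)
```

The mapping itself is trivial. The missing piece is a public way to obtain `saved_structure` from the checkpoint on disk.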
Digging into orbax's source code, I found that PyTreeCheckpointHandler uses _get_internal_metadata to get the item structure, but it is a private method.
So:
- What is the right way to load a checkpoint with a different sharding without knowing the checkpoint structure?
- Is there a public method to parse the checkpoint structure?
Thanks!