Basic save/restore using OCDBT not working
Closed this issue · 3 comments
When running this code that does a dead-simple save and restore:
import orbax.checkpoint as ocp
from etils import epath
import jax.numpy as jnp
state = {
    "a": jnp.array(5)
}
ocp.StandardCheckpointHandler().save(epath.Path("orbax-test").resolve(), state)
ocp.StandardCheckpointHandler().restore(epath.Path("orbax-test").resolve())
I get:
ValueError: NOT_FOUND: Error opening "zarr" driver: Metadata at local file "/home/black/monopi/orbax-test/a/.zarray" does not exist [tensorstore_spec='{\"context\":{\"cache_pool\":{},\"data_copy_concurrency\":{},\"file_io_concurrency\":{\"limit\":128},\"file_io_sync\":true},\"driver\":\"zarr\",\"kvstore\":{\"driver\":\"file\",\"path\":\"/home/black/monopi/orbax-test/a/\"}}'] [source locations='tensorstore/driver/kvs_backed_chunk_driver.cc:1255\ntensorstore/driver/driver.cc:109']
It seems like the root cause is that ocp.type_handlers.is_ocdbt_checkpoint(epath.Path("orbax-test").resolve()) returns False, even though OCDBT was used during saving. That is because the saved directory structure looks like this:
orbax-test/
    _METADATA
    _sharding
    checkpoint
    ocdbt.process_0/
        d/
        manifest.ocdbt
But ocp.type_handlers.is_ocdbt_checkpoint checks for the existence of manifest.ocdbt at the top level, rather than looking inside ocdbt.process_0.
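The mismatch described above can be reproduced without Orbax. The sketch below uses only the standard library: it recreates the reported layout in a temporary directory and compares a top-level check with one that also looks inside per-process subdirectories. Both helpers (`has_top_level_manifest`, `has_manifest_in_process_dir`) are hypothetical stand-ins written for illustration, assuming the check is a plain existence test for manifest.ocdbt; they are not Orbax's implementation.

```python
import tempfile
from pathlib import Path

MANIFEST = "manifest.ocdbt"


def has_top_level_manifest(directory: Path) -> bool:
    # Hypothetical stand-in: only looks for the manifest directly under `directory`.
    return (directory / MANIFEST).exists()


def has_manifest_in_process_dir(directory: Path) -> bool:
    # Hypothetical alternative: also accepts a manifest inside ocdbt.process_* subdirectories.
    return any(p.is_file() for p in directory.glob(f"ocdbt.process_*/{MANIFEST}"))


def make_reported_layout(root: Path) -> Path:
    # Recreate the directory structure from the report.
    # Assumption: _METADATA, _sharding and checkpoint are plain files.
    ckpt = root / "orbax-test"
    process_dir = ckpt / "ocdbt.process_0"
    (process_dir / "d").mkdir(parents=True)
    (process_dir / MANIFEST).touch()
    for name in ("_METADATA", "_sharding", "checkpoint"):
        (ckpt / name).touch()
    return ckpt


with tempfile.TemporaryDirectory() as tmp:
    ckpt = make_reported_layout(Path(tmp))
    top_level_found = has_top_level_manifest(ckpt)
    nested_found = has_manifest_in_process_dir(ckpt)
    print("top-level check:", top_level_found)
    print("per-process check:", nested_found)
```

With this layout the top-level check fails while the per-process check succeeds, which matches the symptom in the report: the restore path does not recognize the directory as an OCDBT checkpoint and falls back to the zarr driver.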
I'm on orbax-checkpoint 0.5.15 and tensorstore 0.1.60.
Ah, I figured it out: you have to call finalize().
Hi Kevin, please don't use any of the CheckpointHandlers directly like you're doing.
From the docs:
Crucially a CheckpointHandler instance should not be used in isolation, but should always be used in conjunction with a Checkpointer (see below). Otherwise, save operations will not be atomic and async operations cannot be waited upon. This means that in most cases, you will be working with Checkpointer APIs rather than CheckpointHandler APIs.
You want to use ocp.StandardCheckpointer() to save and restore.
If you have a strong reason for using a CheckpointHandler in isolation, please let us know what it is.
The reason I was working directly with the StandardCheckpointHandler is that I was implementing my own custom CheckpointHandler that serializes extra metadata on top of StandardCheckpointHandler. What I was missing was an implementation of finalize() that also calls the StandardCheckpointHandler's finalize() method.
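For anyone hitting the same thing, here is a minimal sketch of the delegation pattern, written with plain-Python stand-ins rather than Orbax classes. `InnerHandler` and `MetadataHandler` are hypothetical names, and the stand-in's finalize step is a toy (it just moves a file); it only models the idea that the wrapped handler still has work to do after save(), so a wrapper that adds its own data must forward finalize() as well.

```python
import json
import tempfile
from pathlib import Path


class InnerHandler:
    """Hypothetical stand-in for the wrapped handler; not an Orbax class."""

    def save(self, directory: Path, item: dict) -> None:
        # Write into a per-process staging directory first.
        staging = directory / "ocdbt.process_0"
        staging.mkdir(parents=True, exist_ok=True)
        (staging / "manifest.ocdbt").write_text(json.dumps(item))

    def finalize(self, directory: Path) -> None:
        # Toy finalization: move the staged file to where restore() expects it.
        staged = directory / "ocdbt.process_0" / "manifest.ocdbt"
        staged.rename(directory / "manifest.ocdbt")

    def restore(self, directory: Path) -> dict:
        return json.loads((directory / "manifest.ocdbt").read_text())


class MetadataHandler:
    """Hypothetical wrapper that stores extra metadata next to the inner item."""

    def __init__(self) -> None:
        self._inner = InnerHandler()

    def save(self, directory: Path, item: dict, extra: dict) -> None:
        self._inner.save(directory, item)
        (directory / "extra_metadata.json").write_text(json.dumps(extra))

    def finalize(self, directory: Path) -> None:
        # The step that was missing: forward finalize() to the wrapped handler.
        self._inner.finalize(directory)

    def restore(self, directory: Path):
        item = self._inner.restore(directory)
        extra = json.loads((directory / "extra_metadata.json").read_text())
        return item, extra


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "orbax-test"
    handler = MetadataHandler()
    handler.save(ckpt, {"a": 5}, {"note": "extra"})

    # Restoring before finalize() fails because the staged file was never moved.
    try:
        handler.restore(ckpt)
        failed_before_finalize = False
    except FileNotFoundError:
        failed_before_finalize = True

    handler.finalize(ckpt)
    restored_item, restored_extra = handler.restore(ckpt)
```

In real code a Checkpointer would normally drive save and finalize, as recommended earlier in the thread; the wrapper's only job here is to avoid swallowing the inner handler's finalize step.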