google/orbax

Basic save/restore using OCDBT not working

Closed this issue · 3 comments

When running this code that does a dead-simple save and restore:

import orbax.checkpoint as ocp
from etils import epath
import jax.numpy as jnp

state = {
    "a": jnp.array(5)
}
ocp.StandardCheckpointHandler().save(epath.Path("orbax-test").resolve(), state)
ocp.StandardCheckpointHandler().restore(epath.Path("orbax-test").resolve())

I get:

ValueError: NOT_FOUND: Error opening "zarr" driver: Metadata at local file "/home/black/monopi/orbax-test/a/.zarray" does not exist [tensorstore_spec='{\"context\":{\"cache_pool\":{},\"data_copy_concurrency\":{},\"file_io_concurrency\":{\"limit\":128},\"file_io_sync\":true},\"driver\":\"zarr\",\"kvstore\":{\"driver\":\"file\",\"path\":\"/home/black/monopi/orbax-test/a/\"}}'] [source locations='tensorstore/driver/kvs_backed_chunk_driver.cc:1255\ntensorstore/driver/driver.cc:109']

It seems like the root cause is that ocp.type_handlers.is_ocdbt_checkpoint(epath.Path("orbax-test").resolve()) == False, even though OCDBT was used during saving. That is because the saved directory structure looks like this:

orbax-test/
    _METADATA
    _sharding
    checkpoint
    ocdbt.process_0/
        d/
        manifest.ocdbt

But ocp.type_handlers.is_ocdbt_checkpoint checks for the existence of manifest.ocdbt at the top level, rather than looking inside ocdbt.process_0.

I'm on orbax-checkpoint 0.5.15 and tensorstore 0.1.60.

Ah, I figured it out: you have to call finalize() after save().

Hi Kevin, please don't directly use any of the CheckpointHandlers like you're doing.

From the docs:

Crucially a CheckpointHandler instance should not be used in isolation, but should always be used in conjunction with a Checkpointer (see below). Otherwise, save operations will not be atomic and async operations cannot be waited upon. This means that in most cases, you will be working with Checkpointer APIs rather than CheckpointHandler APIs.

You want to use ocp.StandardCheckpointer() to save and restore.

If you have a strong reason to use CheckpointHandler in isolation, please let us know what it is.

I was working directly with StandardCheckpointHandler because I was implementing my own custom CheckpointHandler that serializes extra metadata on top of StandardCheckpointHandler. What I was missing was an implementation of finalize() that also calls the StandardCheckpointHandler's finalize() method.
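
To illustrate the delegation pattern described above, here is a rough pure-Python sketch. InnerHandler and MetadataHandler are hypothetical stand-ins, not Orbax classes, and the "move the manifest to the top level" step is only a simplified model of the directory layout shown earlier in this issue, not what Orbax actually does internally. The point is just that a wrapping handler has to forward finalize() to the handler it wraps.

```python
import json
import tempfile
from pathlib import Path


class InnerHandler:
    """Hypothetical stand-in for the wrapped handler (not Orbax code)."""

    def save(self, directory: Path, state: dict) -> None:
        # Write into a per-process subdirectory, mirroring the layout above.
        staging = directory / "ocdbt.process_0"
        staging.mkdir(parents=True, exist_ok=True)
        (staging / "manifest.ocdbt").write_text(json.dumps(state))

    def finalize(self, directory: Path) -> None:
        # Simplified model: promote the manifest to the top level.
        staged = directory / "ocdbt.process_0" / "manifest.ocdbt"
        staged.replace(directory / "manifest.ocdbt")


class MetadataHandler:
    """Hypothetical wrapper that stores extra metadata next to the state."""

    def __init__(self) -> None:
        self._inner = InnerHandler()

    def save(self, directory: Path, state: dict, extra: dict) -> None:
        self._inner.save(directory, state)
        (directory / "extra_metadata.json").write_text(json.dumps(extra))

    def finalize(self, directory: Path) -> None:
        # The step that is easy to forget: forward to the wrapped handler.
        self._inner.finalize(directory)


root = Path(tempfile.mkdtemp())
handler = MetadataHandler()
handler.save(root, {"a": 5}, {"note": "hypothetical metadata"})

found_before = (root / "manifest.ocdbt").exists()
handler.finalize(root)
found_after = (root / "manifest.ocdbt").exists()
print(found_before, found_after)
```

Without the forwarding finalize(), the top-level manifest never appears in this model, which matches the symptom in the original report.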