
Is there a good way to save/load & compress/decompress model weights?

Hey- This is Chris.
I'm using this open-source library for my project.

Since I'm new to JAX and Haiku, I have some questions.

Is there a good way to save/load & compress/decompress & serialize model weights?

  • save/load model (network only or weight only)
  • compress/decompress weights
  • serialize

I think serialization is an important issue for scalability. Can you give me some keywords or hints on this?


Hey, we're intentionally un-opinionated here. I will note:

  1. Haiku params (and network state) are transparent dictionaries of JAX jnp.ndarrays.
  2. jnp.ndarray converts to np.ndarray, so when using non-bfloat16 types, anything that works to save NumPy will work here.

There are a few options we've seen work well:

  • Directly pickle the params dict. Upside: it just works, downside: may not be totally efficient, and has usual pickle caveats.
  • Use np.save or np.savez to store the ndarrays in a flat format, and save the tree structure via either pickle or a stable serialized format (protobuf, JSON, YAML, you name it).
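As a rough illustration of the second option, here is a minimal sketch that stores the arrays with np.savez (or np.savez_compressed, which has the same call signature and also covers the compress/decompress part of the question) and keeps the tree structure as JSON. The params dict is a hand-written stand-in for Haiku params, assuming the usual two-level {module_name: {param_name: array}} layout; save_params and load_params are hypothetical helper names, not part of Haiku or NumPy.

```python
import io
import json

import numpy as np

# Stand-in for Haiku params: a two-level dict {module_name: {param_name: array}}.
params = {
    "mlp/~/linear_0": {"w": np.ones((3, 4), np.float32),
                       "b": np.zeros((4,), np.float32)},
    "mlp/~/linear_1": {"w": np.full((4, 2), 0.5, np.float32),
                       "b": np.zeros((2,), np.float32)},
}


def save_params(params, arrays_file, compress=False):
    """Writes the arrays to arrays_file and returns the tree structure as JSON."""
    paths = [(module, name) for module, sub in params.items() for name in sub]
    leaves = [np.asarray(params[module][name]) for module, name in paths]
    saver = np.savez_compressed if compress else np.savez
    saver(arrays_file, *leaves)  # positional arrays are stored as arr_0, arr_1, ...
    return json.dumps(paths)


def load_params(arrays_file, structure_json):
    """Rebuilds the nested dict from the saved arrays and the JSON structure."""
    paths = json.loads(structure_json)
    restored = {}
    with np.load(arrays_file, allow_pickle=False) as npz:
        for i, (module, name) in enumerate(paths):
            restored.setdefault(module, {})[name] = npz[f"arr_{i}"]
    return restored


buffer = io.BytesIO()  # a file opened with "wb" works the same way
structure = save_params(params, buffer, compress=True)
buffer.seek(0)
restored = load_params(buffer, structure)
```

Because the structure is plain JSON and the arrays are loaded with allow_pickle=False, nothing in this path depends on pickle. Real Haiku params hold JAX arrays, which np.asarray converts for non-bfloat16 types as noted above.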

I'll look into extending either the Transformer or ResNet example with checkpointing, so we have a concrete piece of code that we can point people to as an example.

I'll leave this bug open conditioned on that - hope this helps!

Thanks for the help! @trevorcai

Your advice helped me a lot!

I'm planning to try serialization via protobuf over gRPC.
As for checkpointing, I'll wait for your examples :)
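
One possible way to get a single array into a protobuf message is to encode it as .npy bytes and place the result in a bytes field. The sketch below shows only the array-to-bytes step with NumPy; the message definition and gRPC service are not shown in this thread, and array_to_bytes / bytes_to_array are hypothetical helper names.

```python
import io

import numpy as np


def array_to_bytes(x):
    """Encodes an array in .npy format (dtype and shape included) as bytes."""
    buf = io.BytesIO()
    np.save(buf, np.asarray(x), allow_pickle=False)
    return buf.getvalue()


def bytes_to_array(data):
    """Decodes bytes produced by array_to_bytes back into an np.ndarray."""
    return np.load(io.BytesIO(data), allow_pickle=False)


weights = np.arange(6, dtype=np.float32).reshape(2, 3)
payload = array_to_bytes(weights)  # could be assigned to a protobuf `bytes` field
decoded = bytes_to_array(payload)
```

The .npy header carries dtype and shape, so the receiver does not need them sent separately.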

Thanks! @trevorcai

I made an encoder and decoder that pack the Haiku model weights and trajectories into gRPC protobuf messages.
I noticed that you use frozendict as the data structure for model weights,
and there is a comment on this data type.

Is this data type going to be deprecated?

# TODO(lenamartens) Deprecate type

Hey, nice job! That's correct, we'd like to replace it with the FlatMapping class below it.
The bit is ready to be flipped, we'll look to flip it soon if we can.

Quick update - it turns out the bit is not ready to be flipped, there are a couple edge cases that need to be fixed. We don't really have the time to look into this for now, so don't expect it to flip in the near future.

I've been using which seems to be working well.

What library is recommended for directly serializing the params dict? What are the caveats? I think adding these to the docs will be a nice addition, or at least links to other good docs on serialization in Python.

If you use HAIKU_FLATMAPPING=0, then Haiku checkpointing is as simple as serializing dicts of np.ndarrays; any solution that works for that will work for Haiku.

The transformer example is a simple demonstration of pickle-ing the entire state:
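
The example's own code is not reproduced in this thread. As a minimal sketch of the same idea (not the example's actual code), assuming the state is a nested dict of np.ndarrays plus a step counter, with save_state and load_state as hypothetical helper names:

```python
import os
import pickle
import tempfile

import numpy as np

# Stand-in for a training state: params as nested dicts of arrays, plus a step.
state = {
    "params": {"linear": {"w": np.ones((2, 3), np.float32),
                          "b": np.zeros((3,), np.float32)}},
    "step": 10,
}


def save_state(path, state):
    """Pickles the whole state object to a single file."""
    with open(path, "wb") as f:
        pickle.dump(state, f)


def load_state(path):
    """Loads a pickled state. Only use this on files you trust."""
    with open(path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pkl")
    save_state(path, state)
    restored = load_state(path)
```

The usual pickle caveats apply: unpickling can execute arbitrary code, so checkpoints from untrusted sources should not be loaded this way, and a pickle may stop loading if the classes it refers to are renamed or moved.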

Two years on, the vast majority of people at DeepMind use np.save to store the np.ndarrays in a flat format, and save the tree structure separately through pickle or a specialized internal format (that I don't know the details of because I use pickle).

@trevorcai Is there an example of saving the tree structure and then loading the np.ndarrays back into it?

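import os
import pickle

import jax
import numpy as np

def save(ckpt_dir: str, state) -> None:
  # Write every leaf array back-to-back into one file, in tree-flattening order.
  with open(os.path.join(ckpt_dir, "arrays.npy"), "wb") as f:
    for x in jax.tree_util.tree_leaves(state):
      np.save(f, x, allow_pickle=False)

  # Save only the tree structure: same nesting, with every leaf replaced by 0.
  tree_struct = jax.tree_util.tree_map(lambda t: 0, state)
  with open(os.path.join(ckpt_dir, "tree.pkl"), "wb") as f:
    pickle.dump(tree_struct, f)

def restore(ckpt_dir):
  with open(os.path.join(ckpt_dir, "tree.pkl"), "rb") as f:
    tree_struct = pickle.load(f)

  # Read the arrays back in the same order they were written.
  leaves, treedef = jax.tree_util.tree_flatten(tree_struct)
  with open(os.path.join(ckpt_dir, "arrays.npy"), "rb") as f:
    flat_state = [np.load(f) for _ in leaves]

  return jax.tree_util.tree_unflatten(treedef, flat_state)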

Typed out in the comment box, so it may need some adjustment to actually run.