google/orbax

Strange behavior of saving sharded trainstate in GCP.

chiamp opened this issue · 3 comments

chiamp commented

A user posted in the Flax discussions about an orbax discrepancy between different zones in GCE. Do different zones have different orbax versions?

==================================================================

What happened

When I save my sharded state in asia-northeast3-a in GCE with orbax, orbax creates a /tmp/orbax_ckpt/0/_sharding file which starts with

{"dropout_rng":"{\"sharding_type\": \"NamedSharding\", \"shape\": [2, 1], \"axis_names\": [\"data\", \"model\"], \"partition_spec\": []}","opt_state.0.0.count":"{\"sharding_type\": \"NamedSharding\", \"shape\": [2, 1], \"axis_names\": [\"data\", \"model\"], \"partition_spec\": []}",
...

My sharded state has a "dropout_rng" entry, so the above file makes sense.

However, when I run the same script in another region such as asia-southeast1-b, orbax creates a _sharding file without readable layer names, for example,

{"ZHJvcG91dF9ybmc=":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 1], \"axis_names\": [\"data\", \"model\"], \"partition_spec\": []}","b3B0X3N0YXRlLjAuMC5jb3VudA==":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 1], \"axis_names\": [\"data\", \"model\"], \"partition_spec\": []}",
...

Theory

I suspect that this is related to OCDBT, because the only difference between the terminal outputs is that OCDBT is initialized in asia-northeast3-a, while the other regions do not print this message:
type_handlers.py:223] OCDBT is initialized successfully..

I confirmed that tensorstore==0.1.51 is installed in all regions.

Can anyone help me, please?

Thank you.

Originally posted by @sw32-seo in google/flax#3538

@sw32-seo Can you please check the Orbax version in all regions?

@niketkumar I used the same Docker image, which has orbax-checkpoint==0.4.7, for all regions.
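For anyone repeating the version check on each VM, here is a minimal sketch that prints the installed versions of the two packages named in this thread. The helper name installed_versions is hypothetical; it only relies on the standard-library importlib.metadata module and reports None for a package that is not installed.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package_name: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The package is missing from this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Package names taken from the thread; extend the list as needed.
    for name, version in installed_versions(["orbax-checkpoint", "tensorstore"]).items():
        print(f"{name}=={version}")
```

Running this in each region and comparing the output would confirm whether the environments really match.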

Hi @chiamp, thanks for raising the issue. We had to change how pytree key names are encoded, since special characters like '~' were not handled properly by the previous encoding.

The names are still correct: they are now stored in URL-safe base64 form, so for example ZHJvcG91dF9ybmc= decodes to dropout_rng. Could you re-run your script in asia-northeast3-a and see whether the sharding file changes? I suspect it will, and that it will then match the base64-encoded file you posted.

Thanks!
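To check that the encoded keys still carry the original names, here is a minimal sketch that decodes a _sharding file's keys with Python's standard base64 module. It assumes the file is a JSON object whose keys are URL-safe base64 strings and whose values are JSON-encoded strings, as in the excerpt above. The helper name decode_sharding_keys is hypothetical and is not part of Orbax, and the sample is rebuilt from the two entries quoted in the issue rather than read from a real checkpoint.

```python
import base64
import json


def decode_sharding_keys(sharding_json: str) -> dict:
    """Decode URL-safe base64 keys and parse each nested JSON value."""
    raw = json.loads(sharding_json)
    return {
        base64.urlsafe_b64decode(key).decode("utf-8"): json.loads(value)
        for key, value in raw.items()
    }


# Sample built from the two entries quoted in the issue (assumption: the
# real file has the same layout, with many more entries).
spec = json.dumps({
    "sharding_type": "NamedSharding",
    "shape": [1, 1],
    "axis_names": ["data", "model"],
    "partition_spec": [],
})
sample = json.dumps({
    "ZHJvcG91dF9ybmc=": spec,
    "b3B0X3N0YXRlLjAuMC5jb3VudA==": spec,
})

decoded = decode_sharding_keys(sample)
for name, sharding in decoded.items():
    print(name, sharding["shape"], sharding["axis_names"])
```

For a real checkpoint, the contents of /tmp/orbax_ckpt/0/_sharding could be read with open(...).read() and passed to the same helper.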