`Incompatible checkpoints` error when running `slim_model.py`
danyaljj opened this issue · 1 comment
danyaljj commented
```
$ python3 slim_model.py --config configs/6B_roto_256.json
WARNING: Logging before InitGoogle() is written to STDERR
I0321 19:57:29.230169 28674 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax devices: 1
jax runtime initialized in 0.0365555s
using checkpoint 383500
/home/danielk/.local/lib/python3.8/site-packages/jax/experimental/maps.py:412: UserWarning: xmap is an experimental feature and probably has bugs!
  warn("xmap is an experimental feature and probably has bugs!")
key shape (1, 2)
in shape (1, 2048)
dp 1
mp 1
Total parameters: 6050886880
read from disk/gcs in 129.839s
Traceback (most recent call last):
  File "/home/danielk/mesh-transformer-jax-master/mesh_transformer/checkpoint.py", line 164, in read_ckpt
    unsharded = _unshard(shards, old_flattened)
  File "/home/danielk/mesh-transformer-jax-master/mesh_transformer/checkpoint.py", line 161, in _unshard
    assert x.shape == old.shape, f"Incompatible checkpoints {x.shape} vs {old.shape}"
AssertionError: Incompatible checkpoints (1, 6300, 4096) vs (1, 50400, 4096)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "slim_model.py", line 69, in <module>
    network.state = read_ckpt(network.state, f"gs://{bucket}/{model_dir}/step_{ckpt_step}/", devices.shape[1])
  File "/home/danielk/mesh-transformer-jax-master/mesh_transformer/checkpoint.py", line 169, in read_ckpt
    unsharded = _unshard(shards, old_flattened)
  File "/home/danielk/mesh-transformer-jax-master/mesh_transformer/checkpoint.py", line 161, in _unshard
    assert x.shape == old.shape, f"Incompatible checkpoints {x.shape} vs {old.shape}"
AssertionError: Incompatible checkpoints (1,) vs (1, 4096)
```
These are the shards that I downloaded from here: https://mystic.the-eye.eu/public/AI/GPT-J-6B/previous_checkpoints/step_384500/
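For what it's worth, the numbers in the first assertion are consistent with a checkpoint that was sharded across 8 TPU cores being read with only 1 visible device. A minimal sketch of that arithmetic (the shapes are taken from the traceback; the 8-way sharding is an assumption, not confirmed by the logs):

```python
# Shapes from the assertion "Incompatible checkpoints (1, 6300, 4096) vs (1, 50400, 4096)":
shard_vocab_dim = 6300    # dimension found in a single checkpoint shard
full_vocab_dim = 50400    # dimension the freshly initialized network expects

# If each shard holds an equal slice of the vocab dimension, the shard count is:
shard_count = full_vocab_dim // shard_vocab_dim
print(shard_count)  # -> 8, suggesting an 8-way sharded checkpoint loaded on 1 device
```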
danyaljj commented
Update: based on this warning in the log:

```
I0321 19:57:29.230169 28674 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
jax devices: 1
```

it seems that some other process (probably one of my earlier runs) was still holding the TPU, so JAX fell back to a single CPU device. I killed the process reported by `sudo lsof -w /dev/accel0`.
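The cleanup step above can be sketched as follows. The sample `lsof` output line and the PID in it are illustrative only, not captured from a real run:

```shell
# On the TPU VM, list processes holding the accelerator device node
# (needs sudo and an existing /dev/accel0):
# sudo lsof -w /dev/accel0

# lsof prints the PID in its second column; given a captured line,
# the PID can be extracted and passed to kill. Illustrative sample:
sample_line="python3 28674 danielk mem CHR 245,0 /dev/accel0"
pid=$(echo "$sample_line" | awk '{print $2}')
echo "$pid"         # PID of the process to terminate
# sudo kill "$pid"  # then rerun slim_model.py
```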