Error while converting checkpoints to Flax format
Opened this issue · 3 comments
The directory with the official LLaMA 2 weights consists of checklist.chk, consolidated.00.pth, and params.json. I want to use it to train a CoH model, so I first try to convert the .pth model to JAX weights using your script:
python3 -m coh.scripts.convert_checkpoint \
--load_checkpoint='params::llama-2-7b/consolidated.00.pth' \
--output_file='llama-2-7b-jax/' \
--streaming=True
But it leads to the following error:
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/app/src/coh/scripts/convert_checkpoint.py", line 37, in <module>
    utils.run(main)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/app/src/coh/scripts/convert_checkpoint.py", line 22, in main
    params = StreamingCheckpointer.load_trainstate_checkpoint(
  File "/app/src/coh/tools/checkpoint.py", line 191, in load_trainstate_checkpoint
    restored_params = cls.load_checkpoint(
  File "/app/src/coh/tools/checkpoint.py", line 107, in load_checkpoint
    for key, value in unpacker:
TypeError: cannot unpack non-iterable int object
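The TypeError makes sense once you look at what the unpacker is actually reading: consolidated.00.pth is a PyTorch zip-format checkpoint, not a msgpack stream. In msgpack, any byte in the range 0x00-0x7f decodes as a positive fixint, so iterating over the unpacker yields bare integers instead of (key, value) pairs. A minimal stdlib-only sketch (the helper name is mine) of what msgpack sees at the start of the file:

```python
# A .pth checkpoint is a zip archive, so it begins with the zip magic
# bytes b"PK\x03\x04". Per the msgpack spec, each byte <= 0x7f decodes
# as a positive fixint, i.e. a plain int -- which then fails to unpack
# as a (key, value) pair in `for key, value in unpacker`.
def decode_as_fixints(data):
    """Decode leading positive-fixint bytes the way msgpack would."""
    out = []
    for b in data:
        if b <= 0x7F:  # positive fixint range in msgpack
            out.append(b)
        else:
            break
    return out

print(decode_as_fixints(b"PK\x03\x04"))  # [80, 75, 3, 4]
```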
I created the conda environment using your .yml file.
@yaraksen
Not sure if this is exactly the solution, but I was able to get something working by changing the StreamingCheckpointer.load_checkpoint() method in coh/tools/checkpoint.py:
@staticmethod
def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None):
    # NOTE: requires `import torch` at the top of coh/tools/checkpoint.py
    if shard_fns is not None:
        shard_fns = flatten_dict(
            to_state_dict(shard_fns)
        )
    if remove_dict_prefix is not None:
        remove_dict_prefix = tuple(remove_dict_prefix)
    flattened_train_state = {}
    with utils.open_file(path) as fin:
        # 83886080 bytes = 80 MB, which is 16 blocks on GCS
        # unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=0,
        #                             use_list=False)
        # for key, value in unpacker:
        #     key = tuple(key)  # not sure why this is there
        # TODO: bug here where the unpacker returns a stream of integers, but
        # the code expects (key, value) pairs of parameter names/values.
        # This is not a catch-all solution; instead we load everything into
        # memory with torch.load() and then iterate through the state dict.
        weight_dict = torch.load(fin)
        for key, value in weight_dict.items():
            # split the dotted PyTorch key into a tuple first, so the
            # comparison against the (tuple) prefix can actually match
            key = tuple(key.split("."))
            if remove_dict_prefix is not None:
                if key[:len(remove_dict_prefix)] == remove_dict_prefix:
                    key = key[len(remove_dict_prefix):]
                else:
                    continue
            # tensor = from_bytes(None, buff)
            tensor = value.tolist()  # tensor -> List[float]
            if shard_fns is not None:
                tensor = shard_fns[key](tensor)
            flattened_train_state[key] = tensor
    if target is not None:
        flattened_target = flatten_dict(
            to_state_dict(target), keep_empty_nodes=True
        )
        for key, value in flattened_target.items():
            if key not in flattened_train_state and value == empty_node:
                flattened_train_state[key] = value
    train_state = unflatten_dict(flattened_train_state)
    if target is None:
        return train_state
    return from_state_dict(target, train_state)
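As a quick sanity check on the key handling above, here is a hypothetical standalone helper (not part of coh) mirroring how the dotted PyTorch state-dict keys become the tuple keys that flatten_dict/unflatten_dict expect, with an optional prefix stripped:

```python
# Hypothetical helper sketching the key transformation: dotted PyTorch
# state-dict keys become tuples so unflatten_dict() can rebuild the
# nested parameter tree; keys not matching the prefix are filtered out.
def to_tuple_key(key, remove_dict_prefix=None):
    parts = tuple(key.split("."))
    if remove_dict_prefix is not None:
        prefix = tuple(remove_dict_prefix)
        if parts[:len(prefix)] != prefix:
            return None  # key does not match the prefix; skip it
        parts = parts[len(prefix):]
    return parts

print(to_tuple_key("model.layers.0.attention.wq.weight", ("model",)))
# ('layers', '0', 'attention', 'wq', 'weight')
print(to_tuple_key("lm_head.weight", ("model",)))
# None
```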
From here I can at least get the weights saved, but I am not entirely sure they are saved in the correct format, since I am running into other JAX-related issues (#18, and even more issues further down the line). If you can confirm that this weight format is correct, that might help both of us.