forhaoliu/chain-of-hindsight

Error while converting checkpoints to Flax format


The directory with the official LLaMA 2 weights contains checklist.chk, consolidated.00.pth, and params.json. I want to use it to train a CoH model, so as a first step I tried to convert the .pth model to JAX weights using your script:

python3 -m coh.scripts.convert_checkpoint \
    --load_checkpoint='params::llama-2-7b/consolidated.00.pth' \
    --output_file='llama-2-7b-jax/' \
    --streaming=True
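The `params::` prefix appears to make the script treat the file as a streaming (msgpack) checkpoint. In the traceback below, the load goes through `StreamingCheckpointer.load_checkpoint`, which iterates a msgpack unpacker. `consolidated.00.pth`, however, is written by `torch.save`. Below is a minimal sketch for checking which kind of file you are passing. It assumes the `.pth` was written by PyTorch 1.6 or later, whose `torch.save` produces a zip archive by default. `sniff_checkpoint_format` is a hypothetical helper, not part of the CoH repo.

```python
# Hypothetical helper: tell a zip-based torch.save() file apart from other formats.
# Assumption: PyTorch >= 1.6 writes .pth files as zip archives by default.
# The msgpack streaming format has no magic header, so it can only be reported
# here as "not-zip".
import os
import tempfile
import zipfile


def sniff_checkpoint_format(path: str) -> str:
    if zipfile.is_zipfile(path):
        # torch.save() output: load it with torch.load(), not a msgpack unpacker.
        return "torch-zip"
    # Could be a msgpack streaming checkpoint or a legacy pickle-format .pth.
    return "not-zip"


# Usage with a stand-in zip file (not a real checkpoint):
with tempfile.TemporaryDirectory() as tmp:
    fake_pth = os.path.join(tmp, "consolidated.00.pth")
    with zipfile.ZipFile(fake_pth, "w") as zf:
        zf.writestr("archive/data.pkl", b"placeholder")
    print(sniff_checkpoint_format(fake_pth))  # torch-zip
```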

Running this command fails with the following error:

Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/app/src/coh/scripts/convert_checkpoint.py", line 37, in <module>
    utils.run(main)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/app/src/coh/scripts/convert_checkpoint.py", line 22, in main
    params = StreamingCheckpointer.load_trainstate_checkpoint(
  File "/app/src/coh/tools/checkpoint.py", line 191, in load_trainstate_checkpoint
    restored_params = cls.load_checkpoint(
  File "/app/src/coh/tools/checkpoint.py", line 107, in load_checkpoint
    for key, value in unpacker:
TypeError: cannot unpack non-iterable int object

I created the conda environment using your .yml file.
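The `TypeError` at the end of the traceback is consistent with the file being read as msgpack. msgpack encodes the integers 0x00 to 0x7f as a single byte (a "positive fixint"). A zip-format `.pth` file starts with the bytes `PK\x03\x04`, and every one of them falls in that range. An unpacker fed this file would therefore yield plain ints, and `for key, value in unpacker` cannot unpack an int into two names. The minimal sketch below models only the fixint rule and does not use the msgpack library. It also assumes the `.pth` is a zip archive.

```python
# Model of msgpack's "positive fixint" rule only (bytes 0x00-0x7f decode to that int).
# Assumption: the checkpoint is a zip archive, so it begins with b"PK\x03\x04".
from typing import List


def decode_positive_fixints(data: bytes) -> List[int]:
    """Decode leading bytes that are positive fixints; stop at the first one that is not."""
    values = []
    for byte in data:
        if byte > 0x7F:
            break
        values.append(byte)
    return values


zip_header = b"PK\x03\x04"
items = decode_positive_fixints(zip_header)
print(items)  # [80, 75, 3, 4]

try:
    key, value = items[0]  # what the loop in load_checkpoint does for each item
except TypeError as exc:
    print(exc)  # cannot unpack non-iterable int object
```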

Same issue here. @yaraksen, were you able to get past the error?

@PootieT, unfortunately not; I'm waiting for the author's response.

@yaraksen, I'm not sure this is exactly the solution, but I was able to get something working by changing the StreamingCheckpointer.load_checkpoint() method in coh/tools/checkpoint.py so that it reads the file with torch.load() instead of the msgpack unpacker:

    # Requires `import torch` at the top of coh/tools/checkpoint.py.
    @staticmethod
    def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None):
        if shard_fns is not None:
            shard_fns = flatten_dict(
                to_state_dict(shard_fns)
            )
        if remove_dict_prefix is not None:
            remove_dict_prefix = tuple(remove_dict_prefix)
        flattend_train_state = {}
        with utils.open_file(path) as fin:
            # 83886080 bytes = 80 MB, which is 16 blocks on GCS
            # unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=0,
            #                             use_list=False)
            # for key, value in unpacker:
            #     key = tuple(key)  # not sure why this is there
            # TODO: bug here: the unpacker returns a stream of integers, but the
            # code expects (parameter name, serialized value) pairs.
            # This is not a general-purpose fix: instead we load the whole state
            # dict into memory with torch.load(), then iterate through it.
            weight_dict = torch.load(fin, map_location="cpu")
            for key, value in weight_dict.items():
                # PyTorch keys are dotted strings such as "layers.0.attention.wq.weight".
                # Split first so the tuple prefix comparison below can actually match.
                key = tuple(key.split("."))
                if remove_dict_prefix is not None:
                    if key[:len(remove_dict_prefix)] == remove_dict_prefix:
                        key = key[len(remove_dict_prefix):]
                    else:
                        continue

                # Convert to a float32 NumPy array. NumPy has no bfloat16, and
                # .tolist() would build enormous nested Python lists.
                tensor = value.float().numpy()
                if shard_fns is not None:
                    tensor = shard_fns[key](tensor)
                flattend_train_state[key] = tensor

        if target is not None:
            flattened_target = flatten_dict(
                to_state_dict(target), keep_empty_nodes=True
            )
            for key, value in flattened_target.items():
                if key not in flattend_train_state and value == empty_node:
                    flattend_train_state[key] = value
        train_state = unflatten_dict(flattend_train_state)
        if target is None:
            return train_state

        return from_state_dict(target, train_state)
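The key handling in the patched loop can be tested in isolation. The minimal sketch below uses plain Python and assumes dotted PyTorch key names. `to_key_tuple` and `unflatten` are hypothetical helpers; `unflatten` stands in for `flax.traverse_util.unflatten_dict` so the example runs without Flax. The sketch also shows why the prefix must be compared only after splitting the key: a slice of a `str` never equals a `tuple`.

```python
# Hypothetical helpers mirroring the key handling in the patched load_checkpoint.
# unflatten() is a plain-Python stand-in for flax.traverse_util.unflatten_dict.
from typing import Any, Dict, Optional, Tuple


def to_key_tuple(name: str, remove_prefix: Optional[Tuple[str, ...]] = None) -> Optional[Tuple[str, ...]]:
    """Split a dotted state-dict key; strip a tuple prefix, or return None if it is absent."""
    key = tuple(name.split("."))
    if remove_prefix is not None:
        if key[:len(remove_prefix)] != remove_prefix:
            return None  # caller skips keys without the prefix
        key = key[len(remove_prefix):]
    return key


def unflatten(flat: Dict[Tuple[str, ...], Any]) -> Dict[str, Any]:
    """Rebuild a nested dict from tuple keys."""
    nested: Dict[str, Any] = {}
    for path, value in flat.items():
        node = nested
        for part in path[:-1]:
            node = node.setdefault(part, {})
        node[path[-1]] = value
    return nested


# Comparing before splitting never matches: a str slice is never equal to a tuple.
print("layers.0.attention.wq.weight"[:1] == ("layers",))  # False

state = {
    "layers.0.attention.wq.weight": "wq0",
    "layers.0.feed_forward.w1.weight": "w1_0",
    "norm.weight": "norm",  # dropped below: it lacks the "layers" prefix
}
flat = {}
for name, value in state.items():
    key = to_key_tuple(name, remove_prefix=("layers",))
    if key is not None:
        flat[key] = value
print(unflatten(flat))
# {'0': {'attention': {'wq': {'weight': 'wq0'}}, 'feed_forward': {'w1': {'weight': 'w1_0'}}}}
```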

With this change I can at least save the weights, but I'm not entirely sure they are saved in the correct format, since I'm running into other JAX-related issues (#18, plus more further down the line). If you can confirm that this weight format is correct, that would help both of us going forward.
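One cheap sanity check on the converted weights is to compare tensor shapes. PyTorch's `nn.Linear` stores `weight` as `(out_features, in_features)`, while a `flax.linen.Dense` kernel is `(in_features, out_features)`. Linear weights copied straight from the state dict therefore need a transpose. The attention projections are square, so a missing transpose there does not show up in a shape check. The patched loader also keeps Meta's key names (e.g. `layers.0.attention.wq.weight`), and these may differ from the parameter names the CoH Flax model defines.

The minimal sketch below assumes LLaMA-2-7B values from `params.json`: dim 4096, multiple_of 256, no `ffn_dim_multiplier`, and no grouped-query attention. Check these against your own file. The FFN width formula follows `FeedForward` in Meta's `llama/model.py`, and the helper names are hypothetical.

```python
# Expected per-layer linear weight shapes for LLaMA-2-7B (assumed params.json values),
# as nn.Linear stores them (out, in) and as flax.linen.Dense kernels (in, out).
from typing import Dict, Optional, Tuple


def ffn_hidden_dim(dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float] = None) -> int:
    """FFN width: 2/3 of 4*dim, optionally scaled, rounded up to a multiple of multiple_of."""
    hidden = int(2 * (4 * dim) / 3)
    if ffn_dim_multiplier is not None:
        hidden = int(ffn_dim_multiplier * hidden)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)


def layer_weight_shapes(dim: int, multiple_of: int) -> Dict[str, Tuple[int, int]]:
    """Shapes of one layer's linear weights in the PyTorch checkpoint (out, in).

    Assumes n_kv_heads == n_heads, so wk/wv are square like wq."""
    hidden = ffn_hidden_dim(dim, multiple_of)
    return {
        "attention.wq.weight": (dim, dim),
        "attention.wk.weight": (dim, dim),
        "attention.wv.weight": (dim, dim),
        "attention.wo.weight": (dim, dim),
        "feed_forward.w1.weight": (hidden, dim),
        "feed_forward.w2.weight": (dim, hidden),
        "feed_forward.w3.weight": (hidden, dim),
    }


def flax_kernel_shape(torch_shape: Tuple[int, int]) -> Tuple[int, int]:
    """A Dense kernel is the transpose of the corresponding nn.Linear weight."""
    out_features, in_features = torch_shape
    return (in_features, out_features)


shapes = layer_weight_shapes(dim=4096, multiple_of=256)
for name, shape in shapes.items():
    print(f"layers.0.{name}: torch {shape} -> flax kernel {flax_kernel_shape(shape)}")
```

If the saved `w1` kernel comes out as `(11008, 4096)` rather than `(4096, 11008)`, it was most likely copied without the transpose.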