NVIDIA/Megatron-LM

[QUESTION] vicuna-7b-v1.5 weight conversion from huggingface to megatron-lm format


I am trying to convert the weights of vicuna-7b-v1.5 from Hugging Face transformers (https://huggingface.co/lmsys/vicuna-7b-v1.5) for use with Megatron-LM.
I am using tools/checkpoint/convert.py for the conversion, with the following command:

python tools/checkpoint/convert.py \
  --model-type GPT \
  --loader llama2_hf \
  --saver megatron \
  --target-tensor-parallel-size 2 \
  --target-pipeline-parallel-size 2 \
  --load-dir ${HF_CHECKPOINT_DIR} \
  --save-dir ${MEGATRON_CHECKPOINT_DIR} \
  --tokenizer-model ${TOKENIZER_MODEL}

When I run it, I get the following error:

Traceback (most recent call last):
  File "[...]/Megatron-LM/tools/checkpoint/convert.py", line 158, in <module>
    main()
  File "[...]/Megatron-LM/tools/checkpoint/convert.py", line 151, in main
    loader.load_checkpoint(queue, args)
  File "[...]/Megatron-LM/tools/checkpoint/loader_llama2_hf.py", line 370, in load_checkpoint
    _load_checkpoint(queue, args)
  File "[...]/Megatron-LM/tools/checkpoint/loader_llama2_hf.py", line 280, in _load_checkpoint
    model = load_checkpoint_to_model(margs)
  File "[...]/Megatron-LM/tools/checkpoint/loader_llama2_hf.py", line 140, in load_checkpoint_to_model
    model = model_provider(True, True).to(args.params_dtype)
  File "[...]/Megatron-LM/pretrain_gpt.py", line 84, in model_provider
    model = megatron.legacy.model.GPTModel(
  File "[...]/Megatron-LM/megatron/legacy/model/gpt_model.py", line 61, in __init__
    self.language_model, self._language_model_key = get_language_model(
  File "[...]/Megatron-LM/megatron/legacy/model/language_model.py", line 67, in get_language_model
    language_model = TransformerLanguageModel(
  File "[...]/Megatron-LM/megatron/legacy/model/language_model.py", line 387, in __init__
    self.encoder = ParallelTransformer(
  File "[...]/Megatron-LM/megatron/legacy/model/transformer.py", line 1579, in __init__
    [build_layer(i + 1 + offset) for i in range(self.num_layers)])
  File "[...]/Megatron-LM/megatron/legacy/model/transformer.py", line 1579, in <listcomp>
    [build_layer(i + 1 + offset) for i in range(self.num_layers)])
  File "[...]/Megatron-LM/megatron/legacy/model/transformer.py", line 1519, in build_layer
    tp_group=mpu.get_tensor_model_parallel_group(),
  File "[...]/Megatron-LM/megatron/core/parallel_state.py", line 567, in get_tensor_model_parallel_group
    assert (
AssertionError: tensor model parallel group is not initialized

Looking into it, the assertion seems to come from this function in megatron/core/parallel_state.py:

def get_tensor_model_parallel_group(check_initialized=True):
    """Get the tensor model parallel group the caller rank belongs to."""
    if check_initialized:
        assert (
            _TENSOR_MODEL_PARALLEL_GROUP is not None
        ), 'tensor model parallel group is not initialized'
    return _TENSOR_MODEL_PARALLEL_GROUP

because _TENSOR_MODEL_PARALLEL_GROUP has not been set (it is still None).

However, the only place in the whole codebase where _TENSOR_MODEL_PARALLEL_GROUP is assigned is this line:

_TENSOR_MODEL_PARALLEL_GROUP = group

This assignment is inside initialize_model_parallel, and that function does not seem to be called during the weight conversion.
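A simplified, self-contained model of this pattern may help show why the assertion fires. The getter and its assertion message are copied from the snippet above; initialize_model_parallel and build_layer here are hypothetical stand-ins (the real ones take more arguments and create actual torch.distributed process groups, which this sketch replaces with a placeholder string).

```python
# Hypothetical, simplified stand-in for megatron/core/parallel_state.py.
# Real Megatron stores a torch.distributed process group; a string is used here.

_TENSOR_MODEL_PARALLEL_GROUP = None  # module-level global, starts unset


def initialize_model_parallel(tensor_model_parallel_size=1):
    """The only place the global gets assigned (mirrors the finding above)."""
    global _TENSOR_MODEL_PARALLEL_GROUP
    _TENSOR_MODEL_PARALLEL_GROUP = f"tp_group(size={tensor_model_parallel_size})"


def get_tensor_model_parallel_group(check_initialized=True):
    """Get the tensor model parallel group the caller rank belongs to."""
    if check_initialized:
        assert (
            _TENSOR_MODEL_PARALLEL_GROUP is not None
        ), 'tensor model parallel group is not initialized'
    return _TENSOR_MODEL_PARALLEL_GROUP


def build_layer():
    """Hypothetical layer constructor that, like the traceback's build_layer, needs the TP group."""
    return {"tp_group": get_tensor_model_parallel_group()}


def try_build():
    """Build a layer, returning the assertion message instead of raising."""
    try:
        return build_layer()
    except AssertionError as err:
        return f"AssertionError: {err}"


if __name__ == "__main__":
    # Building before initialization reproduces the error from the traceback.
    print(try_build())
    # Once initialize_model_parallel has run, the lookup succeeds.
    initialize_model_parallel(tensor_model_parallel_size=2)
    print(try_build())
```

Run as a script, the first call prints the same AssertionError as in the traceback and the second prints the placeholder group. In other words, any code path that constructs the legacy model without first running the initializer will hit this assertion.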

How can I correctly do the weight conversion?

I'm also interested in this, and more generally in how Megatron-LM can be used to convert a model from HF, continue pretraining it, and convert it back to HF.

Same issue with a different model.

My understanding is that the megatron model type (which uses --transformer-impl local) is deprecated. Consider using the mcore model type instead (which uses --transformer-impl transformer_engine):

--saver mcore

If you do need the megatron model type, try converting to mcore first and then converting that mcore checkpoint to megatron. Last time I checked, that worked.
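The two-step route suggested here (HF to mcore, then mcore to megatron) could be scripted roughly as follows. This is a sketch, not an official tool: the helper names are hypothetical, it reuses only the convert.py flags from the original command, and it assumes your checkout also provides an mcore loader (selected with --loader mcore) alongside the mcore saver. The second step omits --tokenizer-model on the assumption that the mcore checkpoint already records the tokenizer settings; add it back if your version requires it.

```python
import shlex
import subprocess
import sys


def convert_cmd(loader, saver, load_dir, save_dir, tp=2, pp=2, tokenizer_model=None):
    """Build an argv list for tools/checkpoint/convert.py (hypothetical helper)."""
    cmd = [
        sys.executable, "tools/checkpoint/convert.py",
        "--model-type", "GPT",
        "--loader", loader,
        "--saver", saver,
        "--target-tensor-parallel-size", str(tp),
        "--target-pipeline-parallel-size", str(pp),
        "--load-dir", load_dir,
        "--save-dir", save_dir,
    ]
    if tokenizer_model is not None:
        # In the original command only the HF loader step is given a tokenizer model.
        cmd += ["--tokenizer-model", tokenizer_model]
    return cmd


def two_step_commands(hf_dir, mcore_dir, megatron_dir, tokenizer_model, tp=2, pp=2):
    """Step 1: HF -> mcore. Step 2 (only if the legacy format is needed): mcore -> megatron."""
    step1 = convert_cmd("llama2_hf", "mcore", hf_dir, mcore_dir, tp, pp, tokenizer_model)
    # Assumption: an mcore loader exists in this checkout of tools/checkpoint.
    step2 = convert_cmd("mcore", "megatron", mcore_dir, megatron_dir, tp, pp)
    return [step1, step2]


def run_all(commands, dry_run=True):
    """Print each command; execute them in order only when dry_run is False."""
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            # Run from the Megatron-LM repository root so the relative script path resolves.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(two_step_commands("hf_ckpt", "mcore_ckpt", "megatron_ckpt", "tokenizer.model"))
```

With the default dry_run=True the script only prints the two command lines, so they can be checked before launching a real conversion. If the mcore format is all you need, the first command alone is the suggested --saver mcore conversion.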