OpenMOSS/CoLLiE

save_16bit_model does not save the proper state_dict

Closed this issue · 11 comments

Dear authors,

On V100, the torch.cat implementation of saving the state_dict is memory-demanding and can cause OOM for LLaMA-7B when gathering weights to a single GPU, so I am trying to save the state_dict using trainer.engine.save_16bit_model instead. However, loading the resulting state_dict via huggingface's standard from_pretrained interface raises ValueError: The state dictionary of the model you are trying to load is corrupted. Are you sure it was properly saved? Does save_parallel_state_dict in the collie implementation of LLaMA use the same set of state_dict keys as the huggingface implementation of save_pretrained (for LLaMA)?
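
A minimal sketch of the failing round trip (trainer is collie's Trainer from my training script; ./ckpt is a placeholder directory assumed to already contain the huggingface config.json and tokenizer files):

```python
from transformers import LlamaForCausalLM

# save fp16 weights through the deepspeed engine wrapped by collie's trainer
trainer.engine.save_16bit_model("./ckpt", "pytorch_model.bin")

# loading them back with the standard huggingface interface then fails:
model = LlamaForCausalLM.from_pretrained("./ckpt")
# ValueError: The state dictionary of the model you are trying to load is corrupted. ...
```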

  1. We're sorry, but save_16bit_model is not supported yet, since collie has to process the state dict before saving to fit the huggingface format. We are working on a solution that avoids OOM when using torch.cat.
  2. save_parallel_state_dict uses the same set of state_dict keys as the huggingface implementation of save_pretrained, so you can load the state dict saved by save_parallel_state_dict with huggingface's from_pretrained (see the sketch below).
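
A minimal sketch of that round trip, assuming the import path and the exact signature of save_parallel_state_dict match the description above (model and config are the collie model and config from the training script; ./hf_ckpt is a placeholder):

```python
from collie.models import LlamaForCausalLM as CollieLlama  # import path assumed
from transformers import LlamaForCausalLM as HFLlama

state_dict = model.state_dict()  # state dict of the (possibly parallel) collie model

# collie merges split weights and renames keys into the huggingface layout
CollieLlama.save_parallel_state_dict(state_dict, "./hf_ckpt", config)

# the result can then be loaded with the standard huggingface interface
hf_model = HFLlama.from_pretrained("./hf_ckpt")
```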

What if I save with process_exclusion=True?

I don't quite understand this question. As far as I understand, process_exclusion is an option for saving CPU memory.

I find that the keys used by save_parallel_state_dict do not have the model. prefix that the huggingface implementation uses (except for the lm_head.weight key, which does not have the model. prefix in the huggingface implementation either).

BTW, process_exclusion=True cannot alleviate the OOM issue according to my observation.

Sorry that the previous reply was not precise enough.

  • Collie's llama structure is different from huggingface's, so a checkpoint saved by save_16bit_model cannot be loaded directly.
  • save_parallel_state_dict takes a state_dict parameter, usually obtained through model.state_dict(); the function then processes that state dict (merging split weights and renaming keys) before saving, so a checkpoint saved by save_parallel_state_dict can be loaded with huggingface's from_pretrained.
  • Conversely, load_parallel_state_dict processes a checkpoint (splitting weights and renaming keys) to fit collie's model, so collie's model is able to load a checkpoint from huggingface (see the sketch below).
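
The reverse direction, as a sketch; the import path and the signature of load_parallel_state_dict are assumptions based on the bullet above, and path/to/hf_llama is a placeholder:

```python
from collie.models import LlamaForCausalLM as CollieLlama  # import path assumed

# collie splits weights and renames keys so its own (parallel) model can consume them
state_dict = CollieLlama.load_parallel_state_dict("path/to/hf_llama", config)
model.load_state_dict(state_dict)
```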

We are now trying to fix OOM caused by torch.cat, which may take some time.

I checked the keys in the state_dict saved by save_16bit_model and manually added a model. prefix to all the keys except lm_head.weight. After that I could load the finetuned state_dict with the usual from_pretrained in huggingface, but the test performance is extremely bad, even worse than the vanilla LLaMA-7B model. Is there any clue to this phenomenon? Or do the items in the state_dicts of collie's llama and hf's llama implementations just happen to have keys with a common suffix?
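
The renaming described above, as a sketch (paths and the checkpoint filename are placeholders):

```python
import torch

sd = torch.load("./ckpt/pytorch_model.bin", map_location="cpu")

# add the "model." prefix huggingface expects, keeping lm_head.weight as-is
remapped = {(k if k == "lm_head.weight" else f"model.{k}"): v for k, v in sd.items()}

torch.save(remapped, "./hf_ckpt/pytorch_model.bin")
```

Given the structural differences mentioned above, a plain key rename can still leave fused or split weights in the wrong shape, which would be consistent with the degraded results.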

What are your pp_size, tp_size, and dp_size? In some situations some weights may be split (ZeRO-3 or tensor parallelism) and cannot be loaded correctly. That's why save_parallel_state_dict is necessary.

On 8 * V100 (standalone setting of torchrun), config.dp_size is set to 8 and the other two are left as default. I also wonder whether these three params have any effect when using ZeRO-3 via deepspeed?
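
A rough sketch of the setup described above; the import path, the ds_config layout, and the model path are assumptions or placeholders:

```python
from collie import CollieConfig  # import path assumed

config = CollieConfig.from_pretrained("path/to/llama-7b-hf")  # placeholder model path
config.dp_size = 8   # 8 * V100, pure data parallelism
config.tp_size = 1   # left as default
config.pp_size = 1   # left as default
config.ds_config = {
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 3},  # ZeRO-3 through deepspeed
}
```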

  • ZeRO-3 is incompatible with pipeline parallelism; the other two params are fine.
  • That's a bit weird, since data parallelism has nothing to do with torch.cat and the saved weights should be correct in theory...

Can you try the dev branch to run your code? The OOM is fixed on that branch, and you can save with CheckpointCallback or trainer.save_model.
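
A sketch of that suggestion; the import path for CheckpointCallback and its arguments are assumptions and should be checked against the dev branch:

```python
from collie import Trainer
from collie.callbacks import CheckpointCallback  # import path assumed

callbacks = [CheckpointCallback(folder="./ckpt", every_n_epochs=1, model_only=True)]
trainer = Trainer(model=model, config=config, callbacks=callbacks)  # plus the usual dataset/optimizer args

trainer.train()
trainer.save_model("./ckpt")  # or save explicitly once training is done
```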