save_16bit_model does not save the proper state_dict
Closed this issue · 11 comments
Dear authors,
On V100, the `torch.cat` implementation of saving the state_dict is memory-demanding and can cause OOM for LLaMA 7B when gathering weights onto a single GPU, so I am trying to save the state_dict using `trainer.engine.save_16bit_model`. However, loading the state_dict via the standard `from_pretrained` interface of huggingface raises `ValueError: The state dictionary of the model you are trying to load is corrupted. Are you sure it was properly saved?`
Does `save_parallel_state_dict` in the collie implementation of LLaMA use the same set of state_dict keys as those in the huggingface implementation of `save_pretrained` (of LLaMA)?
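For anyone hitting the same error, a quick way to localize a key mismatch is to diff the key names of the saved checkpoint against the keys the huggingface model expects. This is a minimal stdlib-only sketch; the key names in the usage note are made up for illustration and are not taken from either library's actual output.

```python
def diff_keys(saved_keys, expected_keys):
    """Report which expected keys are missing from a saved state_dict
    and which saved keys the target model does not expect."""
    saved, expected = set(saved_keys), set(expected_keys)
    return {
        "missing": sorted(expected - saved),
        "unexpected": sorted(saved - expected),
    }
```

In practice one would pass something like the keys of the loaded checkpoint versus `model.state_dict().keys()` of a freshly constructed huggingface model; any entries in either list point at a naming (or sharding) mismatch.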
- We're sorry that `save_16bit_state` is not supported yet, since collie has to process the state dict before saving to fit the huggingface format. We are now working on a solution to avoid OOM when using `torch.cat`.
- `save_parallel_state_dict` uses the same set of state_dict keys as the huggingface implementation of `save_pretrained`, so you can load the state dict saved by `save_parallel_state_dict` with huggingface's `from_pretrained`.
What if saving with `process_exclusion=True`?
> What if saving with `process_exclusion=True`?

I cannot easily understand this question. According to my understanding, `process_exclusion` is an option for saving CPU memory.
> `save_parallel_state_dict` uses the same set of state_dict keys as those in the huggingface implementation of `save_pretrained`, you can load with huggingface's `from_pretrained` from the state dict saved by `save_parallel_state_dict`.

I find that the keys used by `save_parallel_state_dict` do not have the `model.` prefix compared with the huggingface implementation (except for the `lm_head.weight` key, which also does not have the `model.` prefix in the huggingface implementation).
> What if saving with `process_exclusion=True`?
>
> I cannot easily understand this question. According to my understanding, `process_exclusion` is an option for saving CPU memory.

BTW, `process_exclusion=True` cannot alleviate the OOM issue according to my observation.
> I find that the keys used by `save_parallel_state_dict` do not have the `model.` prefix compared with the huggingface implementation (except for the `lm_head.weight` key, which also does not have the `model.` prefix in the huggingface implementation).
Sorry that the previous reply was not precise enough.
- Collie's llama structure is different from huggingface's, so a checkpoint saved by `save_16bit_model` cannot be loaded directly. `save_parallel_state_dict` has a parameter `state_dict`, which is usually obtained through `model.state_dict()`; the function then processes this state dict (merging split weights and renaming keys) before saving. So a checkpoint saved by `save_parallel_state_dict` can be loaded using huggingface's `from_pretrained`.
- Vice versa, `load_parallel_state_dict` processes the checkpoint (splitting weights and renaming keys) to fit collie's model, so collie's model is able to load checkpoints from huggingface.

We are now trying to fix the OOM caused by `torch.cat`, which may take some time.
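For reference, the general idea behind avoiding a single-device `torch.cat` can be sketched without any real collie or torch APIs (every name below is hypothetical): move each shard off the device before concatenating, so the peak device-memory footprint stays at roughly one shard instead of the full merged tensor.

```python
def merge_shards(shards, to_host=list):
    """Merge sharded weight pieces one at a time.

    `to_host` stands in for a device-to-host copy (e.g. moving a tensor
    to CPU memory); here it just materializes each shard as a list, so
    only one shard is 'resident' per loop iteration.
    """
    merged = []
    for shard in shards:
        merged.extend(to_host(shard))  # copy this shard, then release it
    return merged
```

Whether collie ends up fixing the OOM this way is unknown to me; this only illustrates the streaming-merge pattern the discussion is about.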
I checked the keys in the state_dict saved by `save_16bit_model` and manually added a `model.` prefix to all the keys except `lm_head.weight`. Then I could load the finetuned state_dict with the usual `from_pretrained` in huggingface, but the test performance is extremely bad, even worse than the vanilla LLaMA-7B model. Is there any clue for this phenomenon? Or do the items in the state_dict of Collie's llama and huggingface's llama implementation just happen to have keys with a common suffix?
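The manual renaming described above can be sketched as a plain dict remap (a sketch under the assumption that only the `model.` prefix differs; the layer key names in the test are invented for illustration):

```python
def to_hf_keys(state_dict):
    """Add the `model.` prefix to every key except `lm_head.weight`,
    mimicking the manual fix-up described above. Assumes the only
    difference between the two key sets is this prefix, which, given
    the bad test performance, may well not be the whole story."""
    remapped = {}
    for key, value in state_dict.items():
        if key == "lm_head.weight" or key.startswith("model."):
            remapped[key] = value
        else:
            remapped["model." + key] = value
    return remapped
```

As the thread's later replies suggest, matching key names is not sufficient: if the underlying weights are still split (ZeRO-3 or tensor parallelism), the renamed tensors have the right names but the wrong contents, which would explain the degraded performance.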
What are your pp_size, tp_size, and dp_size? In some situations some weights may be split (ZeRO-3 or tensor parallelism) and cannot be loaded correctly. That's why `save_parallel_state_dict` is necessary.
On 8 * V100 (standalone setting of torchrun), `config.dp_size` is set to 8 and the other two are left as default. I also wonder whether these 3 params have any effect when using ZeRO-3 via deepspeed?
- ZeRO-3 is incompatible with pipeline parallelism, and the other 2 params are OK.
- That's a bit weird, since data parallelism has nothing to do with `torch.cat` and the weights saved should be correct in theory...

Can you try branch `dev` to run your code? OOM is fixed on that branch, and you can save with `CheckpointCallback` or `trainer.save_model`.