Cannot restore FSDP checkpoint with LOCAL_STATE_DICT
helloworld1 commented
System Info
- transformers version: 4.40.1
- Platform: Linux-5.15.148.2-2.cm2-x86_64-with-glibc2.35
- Python version: 3.10.2
- Huggingface_hub version: 0.23.0
- Safetensors version: 0.4.2
- Accelerate version: 0.29.1
- Accelerate config: not found
- PyTorch version (GPU?): 2.1.2.1+gita8e7c98 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: Yes
- Using distributed or parallel set-up in script?: FSDP
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
I trained using FSDP with fsdp_state_dict_type = LOCAL_STATE_DICT.
The accelerate config is as follows:
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: LOCAL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
main_training_function: main
mixed_precision: bf16
rdzv_backend: c10d
same_network: true
num_machines: 1
num_processes: 1
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
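For context, fsdp_state_dict_type corresponds to PyTorch's StateDictType enum; my understanding of the three modes (worth double-checking against the PyTorch FSDP docs):

from torch.distributed.fsdp import StateDictType

# FULL_STATE_DICT: the full, unflattened state dict is gathered (typically rank 0 saves one file)
# LOCAL_STATE_DICT: each rank saves only its own flattened shard
# SHARDED_STATE_DICT: unflattened shards are saved via torch.distributed.checkpoint
print(list(StateDictType))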
The checkpoint directory has the following structure:
./optimizer_0
./optimizer_0/.metadata
./optimizer_0/__0_0.distcp
./optimizer_0/__1_0.distcp
./optimizer_0/__2_0.distcp
./optimizer_0/__3_0.distcp
./optimizer_0/__4_0.distcp
./optimizer_0/__5_0.distcp
./optimizer_0/__6_0.distcp
./optimizer_0/__7_0.distcp
./pytorch_model_fsdp_rank0.bin
./pytorch_model_fsdp_rank1.bin
./pytorch_model_fsdp_rank2.bin
./pytorch_model_fsdp_rank3.bin
./pytorch_model_fsdp_rank4.bin
./pytorch_model_fsdp_rank5.bin
./pytorch_model_fsdp_rank6.bin
./pytorch_model_fsdp_rank7.bin
./rng_state_0.pth
./rng_state_1.pth
./rng_state_2.pth
./rng_state_3.pth
./rng_state_4.pth
./rng_state_5.pth
./rng_state_6.pth
./rng_state_7.pth
./scheduler.pt
./trainer_state.json
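The per-rank pytorch_model_fsdp_rank{N}.bin files appear to come from accelerate's FSDP save path, which under LOCAL_STATE_DICT serializes one flattened shard per process. A paraphrased sketch of that behavior (not a verbatim copy of accelerate; model, rank, and output_dir stand in for the real variables):

import os
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

# Each rank extracts only its own flattened shard of the parameters
# and writes it to a rank-suffixed file (model, rank, output_dir assumed defined).
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
    state_dict = model.state_dict()
torch.save(state_dict, os.path.join(output_dir, f"pytorch_model_fsdp_rank{rank}.bin"))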
When I try to restore the checkpoint with
trainer.train(resume_from_checkpoint="/home/user/checkpoint-10")
I get the following error:
training.py 146 <module>
main()
training.py 125 main
train_results = trainer.train(resume_from_checkpoint=checkpoint)
sft_trainer.py 360 train
output = super().train(*args, **kwargs)
trainer.py 1859 train
return inner_training_loop(
trainer.py 2037 _inner_training_loop
self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)
trainer.py 2431 _load_from_checkpoint
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
ValueError: Can't find a valid checkpoint at /home/user/checkpoint-10
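If I read trainer.py correctly, the failure comes from the checkpoint-detection logic in Trainer._load_from_checkpoint. A paraphrased sketch of that check as of transformers 4.40 (not a verbatim copy):

import os

FSDP_MODEL_NAME = "pytorch_model_fsdp"  # constant from transformers.trainer
resume_from_checkpoint = "/home/user/checkpoint-10"

is_fsdp_ckpt = os.path.isdir(resume_from_checkpoint) and (
    # matches SHARDED_STATE_DICT, which saves into pytorch_model_fsdp_* directories
    any(
        FSDP_MODEL_NAME in folder_name
        for folder_name in os.listdir(resume_from_checkpoint)
        if os.path.isdir(os.path.join(resume_from_checkpoint, folder_name))
    )
    # matches FULL_STATE_DICT, which saves a single pytorch_model_fsdp.bin file
    or os.path.isfile(os.path.join(resume_from_checkpoint, f"{FSDP_MODEL_NAME}.bin"))
)

The per-rank pytorch_model_fsdp_rank{N}.bin files written by LOCAL_STATE_DICT are plain files whose names match neither branch, so is_fsdp_ckpt stays False and the ValueError above is raised.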
If I use SHARDED_STATE_DICT instead, this error does not occur.
Expected behavior
A checkpoint saved with fsdp_state_dict_type = LOCAL_STATE_DICT should be restorable through trainer.train(resume_from_checkpoint=...).
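As a stopgap, loading the shards manually before resuming seems possible. This is only a hypothetical workaround sketch, assuming model is the FSDP-wrapped module (e.g. after accelerator.prepare) and the world size matches the one used at save time:

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

# Each rank loads its own flattened shard and restores it inside a
# LOCAL_STATE_DICT context, mirroring how the files were written.
rank = dist.get_rank()
shard = torch.load(
    f"/home/user/checkpoint-10/pytorch_model_fsdp_rank{rank}.bin",
    map_location="cpu",
)
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
    model.load_state_dict(shard)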