huggingface/transformers

Cannot restore FSDP checkpoint with LOCAL_STATE_DICT

helloworld1 opened this issue

System Info

  • transformers version: 4.40.1
  • Platform: Linux-5.15.148.2-2.cm2-x86_64-with-glibc2.35
  • Python version: 3.10.2
  • Huggingface_hub version: 0.23.0
  • Safetensors version: 0.4.2
  • Accelerate version: 0.29.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.2.1+gita8e7c98 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: FSDP

Who can help?

@pacman100 @muellerzr

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I used FSDP with fsdp_state_dict_type = LOCAL_STATE_DICT. The accelerate config is as follows:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: LOCAL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
main_training_function: main
mixed_precision: bf16
rdzv_backend: c10d
same_network: true
num_machines: 1
num_processes: 1
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
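
As a side note, a quick way to confirm what state-dict type the Accelerate FSDP plugin actually resolves from this config is to inspect it inside the launched script. This is just an illustrative sanity check (it assumes the script is started via accelerate launch with the config above), not part of the reproduction itself:

# Illustrative sanity check, not part of the reproduction.
from accelerate import Accelerator

accelerator = Accelerator()
# FullyShardedDataParallelPlugin resolved from the accelerate config / environment
print(accelerator.state.fsdp_plugin.state_dict_type)  # expected: StateDictType.LOCAL_STATE_DICT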

The checkpoint structure is as follows:

./trainer_state.json
./rng_state_1.pth
./pytorch_model_fsdp_rank1.bin
./pytorch_model_fsdp_rank0.bin
./pytorch_model_fsdp_rank4.bin
./rng_state_5.pth
./rng_state_4.pth
./rng_state_2.pth
./rng_state_3.pth
./pytorch_model_fsdp_rank6.bin
./rng_state_6.pth
./pytorch_model_fsdp_rank2.bin
./scheduler.pt
./rng_state_7.pth
./pytorch_model_fsdp_rank5.bin
./optimizer_0
./optimizer_0/__7_0.distcp
./optimizer_0/__1_0.distcp
./optimizer_0/.metadata
./optimizer_0/__3_0.distcp
./optimizer_0/__0_0.distcp
./optimizer_0/__4_0.distcp
./optimizer_0/__2_0.distcp
./optimizer_0/__6_0.distcp
./optimizer_0/__5_0.distcp
./pytorch_model_fsdp_rank3.bin
./pytorch_model_fsdp_rank7.bin
./rng_state_0.pth
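
For context, the per-rank pytorch_model_fsdp_rank{i}.bin files are what LOCAL_STATE_DICT produces: each rank serializes only its own local shards. Roughly, in raw PyTorch FSDP terms (an illustrative sketch, not the exact accelerate code path; `model` is assumed to be the FSDP-wrapped module from the training run):

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

# Each rank gathers only its own local shards under LOCAL_STATE_DICT...
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
    local_state = model.state_dict()
# ...and writes them to a rank-specific file, matching the listing above.
torch.save(local_state, f"pytorch_model_fsdp_rank{dist.get_rank()}.bin")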

When I try to restore the checkpoint with

trainer.train(resume_from_checkpoint="/home/user/checkpoint-10") 

I get the following error:

training.py 146 <module>
main()

training.py 125 main
train_results = trainer.train(resume_from_checkpoint=checkpoint)

sft_trainer.py 360 train
output = super().train(*args, **kwargs)

trainer.py 1859 train
return inner_training_loop(

trainer.py 2037 _inner_training_loop
self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)

trainer.py 2431 _load_from_checkpoint
raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

ValueError:
Can't find a valid checkpoint at /home/user/checkpoint-10
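
My reading of the failure (a hedged sketch, not the verbatim library code): the validity check around trainer.py 2431 appears to accept either a pytorch_model_fsdp* sub-directory (the SHARDED_STATE_DICT layout) or a single pytorch_model_fsdp.bin file (the FULL_STATE_DICT layout), and the per-rank pytorch_model_fsdp_rank{i}.bin files above match neither:

import os

FSDP_MODEL_NAME = "pytorch_model_fsdp"  # prefix used for FSDP weight files
ckpt = "/home/user/checkpoint-10"

# A pytorch_model_fsdp* sub-directory exists only for SHARDED_STATE_DICT.
has_fsdp_dir = any(
    FSDP_MODEL_NAME in name and os.path.isdir(os.path.join(ckpt, name))
    for name in os.listdir(ckpt)
)
# A single pytorch_model_fsdp.bin exists only for FULL_STATE_DICT.
has_fsdp_bin = os.path.isfile(os.path.join(ckpt, f"{FSDP_MODEL_NAME}.bin"))

print(has_fsdp_dir or has_fsdp_bin)  # False for the listing above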

If I use SHARDED_STATE_DICT instead, I don't get this error.

Expected behavior

The checkpoint saved with LOCAL_STATE_DICT should be restored successfully.