accelerate multi-GPU with gradient_checkpointing throws an error
In the trainer class, `resume_and_prepare` currently calls `init_prepare_models(lr_scheduler=lr_scheduler)` before `init_post_load_freeze()`. This causes an issue with DDP: `init_prepare_models` calls `accelerator.prepare`, which wraps the unet/transformer, so `init_post_load_freeze` then tries to call `enable_gradient_checkpointing` on the wrapped class, which throws an error.
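For anyone hitting this, here is a minimal sketch of the failure mode outside the trainer. `DummyTransformer` is just a stand-in for the unet/transformer (diffusers models expose `enable_gradient_checkpointing()` the same way), and the DDP wrapping mimics what `accelerator.prepare` does on a multi-GPU launch:

```python
import os
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process "gloo" group so DDP can be constructed in a plain script.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

class DummyTransformer(nn.Module):
    """Stand-in for the unet/transformer prepared by the trainer."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        self.gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True

    def forward(self, x):
        return self.proj(x)

wrapped = DDP(DummyTransformer())  # what accelerator.prepare does on a multi-GPU run

# DDP only forwards forward(), not arbitrary methods, so this raises
# AttributeError: 'DistributedDataParallel' object has no attribute
# 'enable_gradient_checkpointing'
wrapped.enable_gradient_checkpointing()
```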
A very quick fix is to call `init_post_load_freeze()` first and then `init_prepare_models()`, which is what I am currently doing to run multi-GPU trainings with accelerate (sketched below).
I can create a quick PR for the above, but I haven't tested it extensively. The proposed fix works for FLUX dev LoRA training on 8 GPUs.
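Continuing the sketch above, the reorder amounts to enabling checkpointing on the bare module before it gets wrapped (this is only the idea, not the repo's actual `resume_and_prepare` body):

```python
model = DummyTransformer()
model.enable_gradient_checkpointing()  # init_post_load_freeze step: runs on the bare module
wrapped = DDP(model)                   # init_prepare_models step: accelerator.prepare / DDP wrapping afterwards
```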
Umm no, it straight away throws an AttributeError, likely because the model class is getting wrapped in DistributedDataParallel.
I think you can just put `unwrap_model` around it then.
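That would look roughly like this (a sketch reusing the dummy model from above; in the trainer it would be the trainer's `accelerator` instance and the prepared unet/transformer):

```python
from accelerate import Accelerator

accelerator = Accelerator()

# unwrap_model strips the DDP wrapper (and is a no-op on an unwrapped model),
# so enable_gradient_checkpointing is reachable either way.
accelerator.unwrap_model(wrapped).enable_gradient_checkpointing()
```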
Yeah, that works as well. It's a minor issue, not much pain; just thought I'd put it out here in case anyone is facing the same issue.
Currently most multi-GPU training is done using quantisation, with PEFT being out of the options as a result of bug 686,
so I guess it might be that you're the first one using PEFT with 8 GPUs. I've been using LyCORIS with 10x 3090s.