bghira/SimpleTuner

accelerate multi gpu with gradient_checkpointing throws an error


In the trainer class, resume_and_prepare currently calls init_prepare_models(lr_scheduler=lr_scheduler) before init_post_load_freeze(). This causes an issue with DDP: init_prepare_models calls accelerator.prepare, which wraps the unet/transformer, so init_post_load_freeze then tries to call enable_gradient_checkpointing on the wrapped class, which throws an error.
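For context, a minimal sketch of the failure mode (names here are illustrative, not SimpleTuner's actual trainer code):

```python
import torch
from accelerate import Accelerator
from diffusers import UNet2DConditionModel

accelerator = Accelerator()
unet = UNet2DConditionModel()  # stand-in for the unet/transformer being trained
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)

# init_prepare_models -> accelerator.prepare: under a multi-GPU accelerate launch
# this returns the model wrapped in torch.nn.parallel.DistributedDataParallel.
unet, optimizer = accelerator.prepare(unet, optimizer)

# init_post_load_freeze then calls this on the wrapped object, which fails with
# AttributeError: 'DistributedDataParallel' object has no attribute
# 'enable_gradient_checkpointing'
unet.enable_gradient_checkpointing()
```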

A very quick fix is to call init_post_load_freeze first and then call init_prepare_models, which is what I am currently doing to run multi-GPU training with accelerate.
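Roughly this ordering (the surrounding method body is assumed; only the relative order of the two calls is the point):

```python
def resume_and_prepare(self, lr_scheduler):
    # ...resume-from-checkpoint logic...

    # Freeze layers / enable gradient checkpointing while the unet/transformer
    # is still the raw diffusers module.
    self.init_post_load_freeze()

    # Only then hand the model to accelerator.prepare, which may wrap it in DDP.
    self.init_prepare_models(lr_scheduler=lr_scheduler)
```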

I can create a quick PR for the above, but I haven't tested it extensively. The proposed fix works for FLUX dev LoRA training on 8 GPUs.

are you talking about #686?

umm no, it immediately throws an AttributeError, likely because the model class is getting wrapped under DistributedDataParallel

I think you can just put unwrap_model around it then
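Something like this (a rough sketch; accelerator.unwrap_model is Accelerate's public API, but the attribute names below are just placeholders):

```python
# Unwrap before calling diffusers-specific methods; self.accelerator and
# self.unet are assumed attribute names, not necessarily SimpleTuner's exact ones.
model = self.accelerator.unwrap_model(self.unet)

# unwrap_model returns the underlying module even after accelerator.prepare has
# wrapped it in DistributedDataParallel, so the diffusers method resolves again.
model.enable_gradient_checkpointing()
```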

yeah that works as well, it's a minor issue, not much pain, just thought I'd put it out here in case anyone else is facing the same issue

currently most multi-GPU training is done using quantisation, with PEFT out of the picture as a result of bug #686

so I guess it might be that you're the first one using PEFT with 8 GPUs; I've been using LyCORIS with 10x 3090s.