Lightning-AI/litgpt

LoRA recipes use a lot of memory because parameters that require gradients are not wrapped in a separate FSDP unit

RuABraun opened this issue · 2 comments

As mentioned on the PyTorch FSDP page, if some parameters inside an FSDP unit require gradients, all parameters in that unit use memory as if they did:

FSDP has some constraints on freezing parameters (i.e. setting param.requires_grad=False). For use_orig_params=False, each FSDP instance must manage parameters that are all frozen or all non-frozen. For use_orig_params=True, FSDP supports mixing frozen and non-frozen, but we recommend not doing so since then the gradient memory usage will be higher than expected (namely, equivalent to not freezing those parameters). This means that ideally, frozen parameters should be isolated into their own nn.Modules and wrapped separately with FSDP.
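To put a rough number on that warning, here is a back-of-the-envelope sketch. It assumes a single hypothetical 4096 x 4096 frozen linear layer with rank-8 LoRA adapters, fp32 gradients, and that a mixed FSDP unit allocates gradient storage for every element in the unit (which is how the docs describe it). Sharding across ranks and alignment padding are ignored, and the helper `grad_numel` is made up for illustration.

```python
def grad_numel(frozen_numel: int, trainable_numel: int, isolated: bool) -> int:
    """Gradient elements allocated for one frozen weight plus its adapters.

    Simplified model (an assumption, not FSDP internals): if frozen and
    trainable params share one FSDP unit, the gradient covers the whole unit;
    if they are wrapped separately, only the trainable unit gets a gradient.
    """
    if trainable_numel == 0:
        return 0  # nothing trainable, no gradient storage
    return trainable_numel if isolated else frozen_numel + trainable_numel


if __name__ == "__main__":
    rank = 8
    frozen = 4096 * 4096          # hypothetical frozen base weight
    trainable = 2 * rank * 4096   # lora_a (rank x 4096) + lora_b (4096 x rank)
    bytes_per_elem = 4            # fp32 gradients

    print(grad_numel(frozen, trainable, isolated=False) * bytes_per_elem)  # 67371008
    print(grad_numel(frozen, trainable, isolated=True) * bytes_per_elem)   # 262144
```

Under these assumptions the mixed unit allocates 257 times as much gradient memory as the isolated adapters, because the frozen weight is 256 times larger than the two adapter matrices combined.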

It seems to me that the LoRA recipe doesn't take this into account; see code.

The right way to do this seems to be the approach shown in the torchtune repo.

Is my understanding correct? Is this an area where litgpt could improve? Doing it correctly seems tricky, since litgpt does not put the LoRA parameters in a separate module.
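A minimal sketch of what separating the adapters could look like. It assumes a hypothetical `LoRALinear` whose trainable `lora_a`/`lora_b` projections are their own `nn.Linear` submodules (unlike litgpt's current layout, per the question above), plus a hypothetical `lora_wrap_policy` callable with the `(module, recurse, nonwrapped_numel)` signature that FSDP's `auto_wrap_policy` argument accepts. It is loosely in the spirit of the torchtune approach, not a copy of its code.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Hypothetical LoRA layer: the frozen base weight and the trainable
    adapters live in different submodules, so each can be its own FSDP unit."""

    def __init__(self, in_features: int, out_features: int, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.base.weight.requires_grad_(False)  # frozen pretrained weight
        self.lora_a = nn.Linear(in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)  # adapter starts as a no-op
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.lora_b(self.lora_a(x)) * self.scaling


def lora_wrap_policy(module: nn.Module, recurse: bool, nonwrapped_numel: int = 0) -> bool:
    """Hypothetical auto-wrap callable: wrap any module whose own
    (non-recursive) parameters all require grad, i.e. the adapters."""
    if recurse:
        return True  # always keep descending into children
    params = list(module.parameters(recurse=False))
    return bool(params) and all(p.requires_grad for p in params)


if __name__ == "__main__":
    layer = LoRALinear(16, 32, rank=4)
    chosen = [name for name, m in layer.named_modules()
              if lora_wrap_policy(m, recurse=False, nonwrapped_numel=0)]
    print(chosen)  # ['lora_a', 'lora_b']
```

With a process group initialized, this would be passed as `FullyShardedDataParallel(model, auto_wrap_policy=lora_wrap_policy)`. Each adapter then becomes its own unit holding only trainable parameters, while the frozen base weights stay in an enclosing unit with no trainable parameters, which also satisfies the all-frozen-or-all-non-frozen rule for `use_orig_params=False`. A real recipe would likely also wrap each transformer block.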

I believe you are correct. I opened the same issue in #1392.

Ah, I should have searched more. I'll close this.