Outdated utility function: No attribute get_module_class_from_name in FullyShardedDataParallelPlugin
Xirid opened this issue · 1 comment
Xirid commented
System Info
Accelerate & peft from main
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder
- My own task or dataset (give details below)
Reproduction
Should happen with any FSDP+Accelerate+PEFT training that uses fsdp_auto_wrap_policy from peft.
This accelerate commit from two weeks ago moved get_module_class_from_name out of the FullyShardedDataParallelPlugin class.
So every time fsdp_auto_wrap_policy is called (for example in examples/conditional_generation/peft_lora_seq2seq_accelerate_fsdp.py), I get this:
Traceback (most recent call last):
File "/root/miniconda3/envs/finetuning/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/root/miniconda3/envs/finetuning/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/root/zeta/backend/peft_example.py", line 145, in <module>
main()
File "/root/zeta/backend/peft_example.py", line 82, in main
accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)
File "/root/miniconda3/envs/finetuning/lib/python3.10/site-packages/peft/utils/other.py", line 406, in fsdp_auto_wrap_policy
transformer_cls = FullyShardedDataParallelPlugin.get_module_class_from_name(model, layer_class)
AttributeError: type object 'FullyShardedDataParallelPlugin' has no attribute 'get_module_class_from_name'
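For reference, a minimal reproduction sketch (the model name and FSDP setup here are placeholders, not from the example script; it assumes the process is launched via accelerate launch with an FSDP config so that accelerator.state.fsdp_plugin is populated):

    from accelerate import Accelerator
    from peft.utils.other import fsdp_auto_wrap_policy
    from transformers import AutoModelForSeq2SeqLM

    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")  # placeholder model
    accelerator = Accelerator()

    # This call raises the AttributeError on accelerate main, because peft still looks up
    # get_module_class_from_name as a staticmethod of FullyShardedDataParallelPlugin:
    accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)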
Expected behavior
Changing the import like this works (but it depends on the accelerate version):
def fsdp_auto_wrap_policy(model):
    import functools
    import os

    from accelerate.utils.dataclasses import get_module_class_from_name
    from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy

    from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder

    default_transformer_cls_names_to_wrap = (
        ",".join(model._no_split_modules) if getattr(model, "_no_split_modules", None) is not None else ""
    )
    transformer_cls_names_to_wrap = os.environ.get(
        "FSDP_TRANSFORMER_CLS_TO_WRAP", default_transformer_cls_names_to_wrap
    ).split(",")
    transformer_cls_to_wrap = {PrefixEncoder, PromptEncoder, PromptEmbedding}
    for layer_class in transformer_cls_names_to_wrap:
        transformer_cls = get_module_class_from_name(model, layer_class)
        if transformer_cls is None:
            raise Exception("Could not find the transformer layer class to wrap in the model.")
        else:
            transformer_cls_to_wrap.add(transformer_cls)

    def lambda_policy_fn(module):
        if (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        ):
            return True
        return False

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=transformer_cls_to_wrap,
    )
    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
    return auto_wrap_policy
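A more version-tolerant way to write that import (just a sketch of a workaround, not the upstream fix: it falls back to the old staticmethod location when the new module-level function is not available):

    try:
        # newer accelerate: the helper is a module-level function
        from accelerate.utils.dataclasses import get_module_class_from_name
    except ImportError:
        # older accelerate: it is still a staticmethod of the plugin class
        from accelerate import FullyShardedDataParallelPlugin

        get_module_class_from_name = FullyShardedDataParallelPlugin.get_module_class_from_name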
BenjaminBossan commented
Thanks for bringing this to our attention. The PR to fix this is already on its way: #1694