open-mmlab/mmengine

FSDPStrategy: how to set mixed_precision and other PyTorch FSDP parameters


📚 The doc issue

The model is on the CPU before it is passed to FSDP. In plain PyTorch I wrap it like this:

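import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# t5_auto_wrap_policy and mp_policy are defined elsewhere in the script.
model = FSDP(model,
             auto_wrap_policy=t5_auto_wrap_policy,
             mixed_precision=mp_policy,
             # sharding_strategy=sharding_strategy,
             device_id=torch.cuda.current_device())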
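The snippet above uses `mp_policy` and `t5_auto_wrap_policy` without defining them. For reference, here is a minimal sketch of how such objects are commonly built with PyTorch's public FSDP API (`MixedPrecision`, `transformer_auto_wrap_policy`). The bfloat16 dtypes are an assumption, and `nn.TransformerEncoderLayer` is a stand-in for the real T5 block class so the example needs no extra dependencies.

```python
import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Keep parameters, gradient reduction and buffers in bfloat16 (assumed choice).
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Give every instance of the listed layer class its own FSDP unit.
# nn.TransformerEncoderLayer is a stand-in for the actual T5 block class.
t5_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer},
)
```

The policy is called by FSDP with `(module, recurse, nonwrapped_numel)`; it returns `True` for the listed layer class and `False` for other leaf modules.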

Suggest a potential alternative/fix

Please give more examples of FSDPStrategy configuration, e.g. how to set mixed_precision alongside options such as:
model_wrapper_cfg=dict(type='MMFullyShardedDataParallel', cpu_offload=True, use_orig_params=False)
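With the `Runner`'s `model_wrapper_cfg`, the remaining FSDP arguments would presumably go into the same dict. The sketch below assumes that `MMFullyShardedDataParallel` converts `cpu_offload=True` into `CPUOffload(offload_params=True)`, converts a `mixed_precision` dict into `torch.distributed.fsdp.MixedPrecision`, and forwards other keys to PyTorch's FSDP. All three assumptions should be checked against the installed mmengine version.

```python
import torch

# Hypothetical Runner config: every key besides `type` is assumed to be
# passed on to mmengine's MMFullyShardedDataParallel wrapper.
model_wrapper_cfg = dict(
    type='MMFullyShardedDataParallel',
    cpu_offload=True,  # assumed to become CPUOffload(offload_params=True)
    use_orig_params=False,
    # Assumed to be converted into torch.distributed.fsdp.MixedPrecision.
    mixed_precision=dict(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
)

# The dict would then be handed to the Runner, e.g.
# runner = Runner(..., model_wrapper_cfg=model_wrapper_cfg)
```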

Specify FSDPStrategy and configure its parameters:

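from functools import partial

from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap each submodule with at least 10M not-yet-wrapped parameters as its own FSDP unit.
auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=int(1e7))
strategy = dict(
    type='FSDPStrategy',
    model_wrapper=dict(auto_wrap_policy=auto_wrap_policy))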
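Following the same pattern, the other arguments from the plain-PyTorch snippet (`mixed_precision`, `sharding_strategy`) would presumably sit next to `auto_wrap_policy` inside `model_wrapper`. This is a sketch under the assumption that FSDPStrategy forwards these keys unchanged to the FSDP wrapper. `device_id` is left out on the further assumption that the strategy sets it from the local rank; that should be verified for the mmengine version in use.

```python
from functools import partial

import torch
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

strategy = dict(
    type='FSDPStrategy',
    model_wrapper=dict(
        # Wrap each submodule with at least 10M not-yet-wrapped parameters.
        auto_wrap_policy=partial(size_based_auto_wrap_policy,
                                 min_num_params=int(1e7)),
        # Assumed to be forwarded unchanged to FullyShardedDataParallel.
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
    ),
)

# Usage (not executed here), e.g.
# runner = FlexibleRunner(..., strategy=strategy)
```

Passing ready-made `MixedPrecision` and `ShardingStrategy` objects avoids relying on any string or dict conversion inside mmengine.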