Unable to run FSDP2 with low-bit optimizers like 8-bit Adam
nighting0le01 commented
Feature request

When training with FSDP and a low-bit optimizer such as bitsandbytes 8-bit Adam, calling accelerator.save_state crashes while Accelerate gathers the FSDP optimizer state dict:
Traceback (most recent call last):
File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/ubuntu/projects/scripts/train_ranker.py", line 404, in <module>
train(accelerator, args)
File "/home/ubuntu/projects/scripts/train_ranker.py", line 357, in train
save_training_artifacts(
File "/home/ubuntu/projects/scripts/train_ranker.py", line 74, in save_training_artifacts
accelerator.save_state(save_dir)
File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/accelerate/accelerator.py", line 2958, in save_state
save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 168, in save_fsdp_optimizer
optim_state = FSDP.optim_state_dict(model, optimizer)
File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1840, in optim_state_dict
return FullyShardedDataParallel._optim_state_dict_impl(
File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1263, in _optim_state_dict_impl
return _optim_state_dict(
File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1971, in _optim_state_dict
fsdp_osd_state = convert_fn(
File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1794, in _convert_state_with_orig_params
_gather_all_orig_param_state(
File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1688, in _gather_all_orig_param_state
output_states = _allgather_orig_param_states(
File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1518, in _allgather_orig_param_states
dtype, state_buffers = _convert_all_state_info(
File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1377, in _convert_all_state_info
assert dtype == info.dtype
AssertionError
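For context, the assertion that fires (assert dtype == info.dtype in torch/distributed/fsdp/_optim_utils.py) suggests that FSDP's optimizer-state gathering expects all sharded state tensors for a parameter to share a single dtype, which does not hold for 8-bit optimizers that mix quantized uint8 state with fp32 auxiliary buffers. Below is a minimal sketch that should reproduce the crash, assuming bitsandbytes' Adam8bit and a standard Accelerate FSDP launch; the model shape and checkpoint path are placeholders:

```python
# Minimal reproduction sketch -- assumes a CUDA machine, bitsandbytes
# installed, and the script launched through an FSDP-enabled Accelerate
# config (e.g. `accelerate launch --use_fsdp repro.py`).
import torch
import torch.nn as nn
import bitsandbytes as bnb
from accelerate import Accelerator

accelerator = Accelerator()  # FSDP plugin is picked up from the launch config

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
# 8-bit Adam: its moment state is stored as quantized uint8 tensors, with
# fp32 absmax/qmap buffers kept alongside them.
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=1e-4)

model, optimizer = accelerator.prepare(model, optimizer)

# One dummy step so the optimizer state is actually materialized; the 8-bit
# buffers are only allocated on the first call to step().
x = torch.randn(8, 1024, device=accelerator.device)
loss = model(x).sum()
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()

# Crashes inside FSDP.optim_state_dict with the AssertionError (dtype
# mismatch) shown in the traceback above.
accelerator.save_state("checkpoint")
```

The same failure should occur with other bitsandbytes low-bit variants such as AdamW8bit, since they share the quantized state layout.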
Motivation
To be able to use both FSDP1 and FSDP2 with low-bit optimizers such as 8-bit Adam, including saving and restoring the optimizer state when checkpointing.
Your contribution
Please let me know how I can help.