bitsandbytes-foundation/bitsandbytes

Unable to run FSDP2 with low-bit optimizers like 8-bit Adam

nighting0le01 opened this issue

Feature request

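Calling `accelerator.save_state(save_dir)` from the training script fails while saving the FSDP optimizer state:

Traceback (most recent call last):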
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/ubuntu/projects/scripts/train_ranker.py", line 404, in <module>
    train(accelerator, args)
  File "/home/ubuntu/projects/scripts/train_ranker.py", line 357, in train
    save_training_artifacts(
  File "/home/ubuntu/projects/scripts/train_ranker.py", line 74, in save_training_artifacts
    accelerator.save_state(save_dir)
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/accelerate/accelerator.py", line 2958, in save_state
    save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 168, in save_fsdp_optimizer
    optim_state = FSDP.optim_state_dict(model, optimizer)
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1840, in optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_impl(
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1263, in _optim_state_dict_impl
    return _optim_state_dict(
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1971, in _optim_state_dict
    fsdp_osd_state = convert_fn(
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1794, in _convert_state_with_orig_params
    _gather_all_orig_param_state(
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1688, in _gather_all_orig_param_state
    output_states = _allgather_orig_param_states(
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1518, in _allgather_orig_param_states
    dtype, state_buffers = _convert_all_state_info(
  File "/home/ubuntu/miniconda3/envs/python-3.10/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1377, in _convert_all_state_info
    assert dtype == info.dtype
AssertionError
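The assertion is raised inside PyTorch's FSDP1 state-dict path (`FSDP.optim_state_dict`, called by `accelerate`'s `save_fsdp_optimizer`), not inside bitsandbytes itself. One plausible reading, not confirmed in this issue, is that `_convert_all_state_info` expects every non-scalar tensor state gathered for a flattened parameter to share a single dtype. An 8-bit optimizer is assumed here to keep its quantized moments in `uint8` next to `float32` scaling factors, which would break that expectation. The sketch below is a simplified, hypothetical diagnostic (`find_mixed_dtype_states` and the `FakeTensor` stand-in are illustrative names, not part of any library, and the 8-bit state names are an assumed layout). It flags parameters whose optimizer state mixes dtypes; it is not a copy of PyTorch's check.

```python
from dataclasses import dataclass
from typing import Any, Dict, Mapping, Tuple


def find_mixed_dtype_states(
    state: Mapping[Any, Mapping[str, Any]],
) -> Dict[Any, Dict[str, str]]:
    """Return {param_key: {state_name: dtype}} for every parameter whose
    non-scalar, tensor-like states do not all share one dtype.

    Zero-dim values and non-tensor entries (e.g. an int step counter) are
    skipped, mirroring the assumption that only non-scalar states must match.
    """
    mixed = {}
    for key, param_state in state.items():
        dtypes = {
            name: str(value.dtype)
            for name, value in param_state.items()
            if hasattr(value, "dtype")
            and hasattr(value, "shape")
            and len(value.shape) > 0
        }
        if len(set(dtypes.values())) > 1:
            mixed[key] = dtypes
    return mixed


@dataclass
class FakeTensor:
    """Hypothetical stand-in exposing only the attributes the check reads."""
    dtype: str
    shape: Tuple[int, ...]


# Assumed 8-bit Adam layout: quantized moments in uint8, block scales in float32.
eight_bit_state = {
    0: {
        "step": 10,
        "state1": FakeTensor("uint8", (4096,)),
        "state2": FakeTensor("uint8", (4096,)),
        "absmax1": FakeTensor("float32", (2,)),
        "absmax2": FakeTensor("float32", (2,)),
    }
}
# Regular Adam layout: both moments and the step counter in float32.
adam_state = {
    0: {
        "step": FakeTensor("float32", ()),
        "exp_avg": FakeTensor("float32", (4096,)),
        "exp_avg_sq": FakeTensor("float32", (4096,)),
    }
}

print(find_mixed_dtype_states(eight_bit_state))
# → {0: {'state1': 'uint8', 'state2': 'uint8', 'absmax1': 'float32', 'absmax2': 'float32'}}
print(find_mixed_dtype_states(adam_state))
# → {}
```

Because PyTorch tensors expose `.dtype` and `.shape`, passing `optimizer.state_dict()["state"]` in place of the stand-in dictionaries should work the same way; the reported dtypes then read `"torch.uint8"` and `"torch.float32"`.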

Motivation

Being able to use low-bit optimizers, such as 8-bit Adam, with both FSDP1 and FSDP2.

Your contribution

Please let me know.