Crash running FSDP on BF16-prequantized models
dmitrii-palisaderesearch opened this issue · 4 comments
System Info
$ python --version
Python 3.10.12
# pip install accelerate transformers bitsandbytes datasets trl peft setuptools
# using latest PyPI versions
$ pip list
Package Version
------------------------ -----------
accelerate 0.33.0
aiohappyeyeballs 2.3.4
aiohttp 3.10.1
aiosignal 1.3.1
async-timeout 4.0.3
attrs 24.2.0
bitsandbytes 0.43.3
certifi 2024.7.4
charset-normalizer 3.3.2
datasets 2.20.0
dill 0.3.8
docstring-parser 0.16
filelock 3.15.4
frozenlist 1.4.1
fsspec 2024.5.0
huggingface-hub 0.24.5
idna 3.7
jinja2 3.1.4
markdown-it-py 3.0.0
markupsafe 2.1.5
mdurl 0.1.2
mpmath 1.3.0
multidict 6.0.5
multiprocess 0.70.16
networkx 3.3
numpy 1.26.4
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu12 2.20.5
nvidia-nvjitlink-cu12 12.6.20
nvidia-nvtx-cu12 12.1.105
packaging 24.1
pandas 2.2.2
peft 0.12.0
psutil 6.0.0
pyarrow 17.0.0
pyarrow-hotfix 0.6
pygments 2.18.0
python-dateutil 2.9.0.post0
pytz 2024.1
pyyaml 6.0.2
regex 2024.7.24
requests 2.32.3
rich 13.7.1
safetensors 0.4.4
setuptools 72.1.0
shtab 1.7.1
six 1.16.0
sympy 1.13.1
tokenizers 0.19.1
torch 2.4.0
tqdm 4.66.5
transformers 4.44.0
triton 3.0.0
trl 0.9.6
typing-extensions 4.12.2
tyro 0.8.5
tzdata 2024.1
urllib3 2.2.2
xxhash 3.4.1
yarl 1.9.4
Reproduction
A DP run goes through fine:
accelerate launch --config-file dp.yml main.py
# OK!
An FSDP run crashes:
$ accelerate launch --config-file fsdp.yml main.py
FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.
[rank0]: Traceback (most recent call last):
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/main.py", line 70, in <module>
[rank0]: trainer.train()
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 451, in train
[rank0]: output = super().train(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 1948, in train
[rank0]: return inner_training_loop(
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 2289, in _inner_training_loop
[rank0]: tr_loss_step = self.training_step(model, inputs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 3328, in training_step
[rank0]: loss = self.compute_loss(model, inputs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 3373, in compute_loss
[rank0]: outputs = model(**inputs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 819, in forward
[rank0]: return model_forward(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 807, in __call__
[rank0]: return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]: return func(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
[rank0]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 819, in forward
[rank0]: return model_forward(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 807, in __call__
[rank0]: return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]: return func(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/peft/peft_model.py", line 1577, in forward
[rank0]: return self.base_model(
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 188, in forward
[rank0]: return self.model.forward(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
[rank0]: outputs = self.model(
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1001, in forward
[rank0]: layer_outputs = decoder_layer(
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
[rank0]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 734, in forward
[rank0]: hidden_states, self_attn_weights, present_key_value = self.self_attn(
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 617, in forward
[rank0]: query_states = self.q_proj(hidden_states)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/peft/tuners/lora/bnb.py", line 467, in forward
[rank0]: result = self.base_layer(x, *args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 477, in forward
[rank0]: out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
[rank0]: AttributeError: 'Tensor' object has no attribute 'quant_state'
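My understanding of the failure (a pure-Python analogy, not real bitsandbytes or FSDP code): bnb keeps the quantization metadata as a Python attribute on the weight object (`Params4bit.quant_state`), and any step that rebuilds the weight as a plain object, as FSDP does when it flattens parameters into shards, silently drops that attribute:

```python
# Pure-Python analogy (assumption: not real bitsandbytes/FSDP code).
# bnb stores quantization metadata as a Python attribute on the weight
# object; rebuilding the weight as a plain object -- as FSDP does when
# it flattens parameters into shards -- drops the attribute.

class Weight(list):  # stand-in for a Params4bit tensor subclass
    pass

w = Weight([0.0] * 4)
w.quant_state = {"quant_type": "nf4"}  # metadata lives on the instance

flat = list(w)  # rebuild into a plain container, as FSDP rebuilds tensors
print(hasattr(w, "quant_state"), hasattr(flat, "quant_state"))  # True False
```

The forward pass then reads `self.weight.quant_state` from what is now a plain `Tensor`, hence the `AttributeError`.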
You can find the repro files in this gist.
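For context, an `fsdp.yml` along these lines reproduces the setup. The actual config is in the gist; every key and value below is an assumption on my part, not copied from it:

```yaml
# Hypothetical accelerate FSDP config (sketch, not the gist's file)
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_offload_params: false
  fsdp_use_orig_params: false
mixed_precision: bf16
num_processes: 2
```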
Expected behavior
After #1295, running FSDP with models prequantized by bitsandbytes to NF4 with BF16 quantization storage should work.
Thanks @dmitrii-palisaderesearch for raising this and for the detailed error logs and repro instructions. We (the bitsandbytes team) have very limited bandwidth at the moment, so I can't guarantee a prompt response. Please keep us updated if anything changes.
Mentioning this to @matthewdouglas, as he is the one who has recently been working on FSDP and prequantized weights.
Hi @dmitrii-palisaderesearch, thank you for reporting!
This issue exists on the transformers side. We were not able to land the changes required to support this ahead of the v4.40 release, but they should be merged soon. The PR to track is huggingface/transformers#32276.
Yes, I was tracking that, but then it was reverted in huggingface/transformers#32477 and I was confused. Thanks, I'll keep an eye on transformers.
Since the PR has been merged on the transformers side, I'm going to go ahead and close this.