bitsandbytes-foundation/bitsandbytes

Crash running FSDP on BF16-prequantized models

dmitrii-palisaderesearch opened this issue · 4 comments

System Info

$ python --version
Python 3.10.12

# pip install accelerate transformers bitsandbytes datasets trl peft setuptools
# using latest PyPI versions 
$ pip list
Package                  Version
------------------------ -----------
accelerate               0.33.0
aiohappyeyeballs         2.3.4
aiohttp                  3.10.1
aiosignal                1.3.1
async-timeout            4.0.3
attrs                    24.2.0
bitsandbytes             0.43.3
certifi                  2024.7.4
charset-normalizer       3.3.2
datasets                 2.20.0
dill                     0.3.8
docstring-parser         0.16
filelock                 3.15.4
frozenlist               1.4.1
fsspec                   2024.5.0
huggingface-hub          0.24.5
idna                     3.7
jinja2                   3.1.4
markdown-it-py           3.0.0
markupsafe               2.1.5
mdurl                    0.1.2
mpmath                   1.3.0
multidict                6.0.5
multiprocess             0.70.16
networkx                 3.3
numpy                    1.26.4
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12        9.1.0.70
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu12     12.1.0.106
nvidia-nccl-cu12         2.20.5
nvidia-nvjitlink-cu12    12.6.20
nvidia-nvtx-cu12         12.1.105
packaging                24.1
pandas                   2.2.2
peft                     0.12.0
psutil                   6.0.0
pyarrow                  17.0.0
pyarrow-hotfix           0.6
pygments                 2.18.0
python-dateutil          2.9.0.post0
pytz                     2024.1
pyyaml                   6.0.2
regex                    2024.7.24
requests                 2.32.3
rich                     13.7.1
safetensors              0.4.4
setuptools               72.1.0
shtab                    1.7.1
six                      1.16.0
sympy                    1.13.1
tokenizers               0.19.1
torch                    2.4.0
tqdm                     4.66.5
transformers             4.44.0
triton                   3.0.0
trl                      0.9.6
typing-extensions        4.12.2
tyro                     0.8.5
tzdata                   2024.1
urllib3                  2.2.2
xxhash                   3.4.1
yarl                     1.9.4

Reproduction

A data-parallel (DP) run completes fine:

$ accelerate launch --config-file dp.yml main.py
# OK!
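
For context, the training script presumably follows the usual QLoRA-over-a-prequantized-checkpoint recipe. Here is a minimal hypothetical sketch (model name, dataset, and hyperparameters are placeholders; the actual main.py is in the gist linked below):

# Hypothetical sketch of main.py; placeholders throughout, real files in the gist.
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig, TrainingArguments)
from trl import SFTTrainer

model_id = "meta-llama/Meta-Llama-3-8B"  # placeholder

# Step 1 (done once): quantize to NF4 with the packed 4-bit weights stored
# as BF16, so FSDP has uniform-dtype tensors to shard, then save.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,
)
AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=bnb_config
).save_pretrained("prequantized")

# Step 2: load the *prequantized* checkpoint and fine-tune with LoRA adapters.
# This is the path that crashes under FSDP but works under DP.
model = AutoModelForCausalLM.from_pretrained(
    "prequantized", torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=TrainingArguments(output_dir="out", bf16=True),
    train_dataset=load_dataset("imdb", split="train[:1%]"),  # placeholder
    dataset_text_field="text",
    peft_config=LoraConfig(task_type="CAUSAL_LM"),
)
trainer.train()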

An FSDP run crashes:

$ accelerate launch --config-file fsdp.yml main.py
FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/main.py", line 70, in <module>
[rank0]:     trainer.train()
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 451, in train
[rank0]:     output = super().train(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 1948, in train
[rank0]:     return inner_training_loop(
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 2289, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 3328, in training_step
[rank0]:     loss = self.compute_loss(model, inputs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/trainer.py", line 3373, in compute_loss
[rank0]:     outputs = model(**inputs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 819, in forward
[rank0]:     return model_forward(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 807, in __call__
[rank0]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
[rank0]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 819, in forward
[rank0]:     return model_forward(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/accelerate/utils/operations.py", line 807, in __call__
[rank0]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/peft/peft_model.py", line 1577, in forward
[rank0]:     return self.base_model(
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 188, in forward
[rank0]:     return self.model.forward(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
[rank0]:     outputs = self.model(
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1001, in forward
[rank0]:     layer_outputs = decoder_layer(
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 863, in forward
[rank0]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 734, in forward
[rank0]:     hidden_states, self_attn_weights, present_key_value = self.self_attn(
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 617, in forward
[rank0]:     query_states = self.q_proj(hidden_states)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/peft/tuners/lora/bnb.py", line 467, in forward
[rank0]:     result = self.base_layer(x, *args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/ubuntu/repro-bnb-quant_state/.venv/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 477, in forward
[rank0]:     out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
[rank0]: AttributeError: 'Tensor' object has no attribute 'quant_state'

You can find the repro files in this gist.
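
For reference, the accelerate FSDP config presumably resembles the following (a hypothetical sketch, not the actual fsdp.yml; for bnb 4-bit models, fsdp_use_orig_params is expected to be false and fsdp_cpu_ram_efficient_loading true):

# Hypothetical fsdp.yml sketch; the real config is in the gist.
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
mixed_precision: bf16
num_machines: 1
num_processes: 2
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false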

Expected behavior

Following #1295, running FSDP on models prequantized with bitsandbytes to NF4 (with the quantized weights stored in BF16) should work.
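
The AttributeError at the bottom of the trace suggests that by the time forward runs, the Linear4bit weight is a plain torch.Tensor rather than a bitsandbytes Params4bit, i.e. the quant_state attached at quantization time was lost somewhere in the FSDP wrapping. A quick way to confirm (a sketch, assuming access to the prepared model object inside the training script):

# Check whether the 4-bit weights kept their quant_state after FSDP wrapping,
# or were replaced by plain Tensors. `model` is the prepared/wrapped model.
import bitsandbytes as bnb

for name, module in model.named_modules():
    if isinstance(module, bnb.nn.Linear4bit):
        w = module.weight
        print(name, type(w).__name__, hasattr(w, "quant_state"))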

Thanks @dmitrii-palisaderesearch for raising this and providing detailed error logs and repro instructions. We (the bitsandbytes team) have very limited bandwidth at the moment, so I can't guarantee a prompt response. Please keep us updated if anything changes.

Mentioning @matthewdouglas, as he's the one who has been working on FSDP and prequantized weights recently.

Hi @dmitrii-palisaderesearch, thank you for reporting!

This issue exists on the transformers side. We weren't able to keep the required changes in ahead of the v4.44 release, but they should be merged soon. The PR to track is huggingface/transformers#32276.
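
In the meantime, one possible workaround (my assumption, not something the transformers team has confirmed) is to quantize on the fly from the original BF16 checkpoint instead of loading the prequantized one, since the on-the-fly path already supports FSDP:

# Possible interim workaround (assumption): quantize at load time rather than
# loading a prequantized checkpoint.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",  # original BF16 checkpoint (placeholder)
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_storage=torch.bfloat16,
    ),
    torch_dtype=torch.bfloat16,
)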

Yes, I was tracking that, but it was then reverted in huggingface/transformers#32477, which confused me. Thanks, I'll keep an eye on transformers.

Since the PR has been merged on the transformers side, I'm going to go ahead and close this.