When Precision Meets Position: ⚠️ BFloat16 Breaks Down RoPE in Long-Context Training 📄 arXiv
RoPE is Broken 🤯 because of ... BFloat16 — Then, How Can We Improve Long-Context Training? 🤔️
👉 We propose AnchorAttention, an improved attention mechanism for long-context training of LLMs.
This repository provides the implementation of AnchorAttention, a plug-and-play attention mechanism designed to significantly improve the long-context training of Large Language Models (LLMs). AnchorAttention addresses numerical issues that arise when using Rotary Positional Embedding (RoPE) with BFloat16 precision, particularly in long-context scenarios.
- Improved Long-Context Performance: Improves the model's long-context performance on RULER and LongBench.
- Efficient Training: Reduces training time by over 50% compared to standard full attention mechanisms.
- Compatibility: Supports several popular models and integrates with FlashAttention2 and FlexAttention engines for optimized computation.
- Semantic Coherence: Maintains the original LLM's capabilities on general tasks while improving performance on long-context tasks.
AnchorAttention treats the first token in the sequence as a shared anchor with a consistent position ID. This approach:
- Reduces unnecessary attention computations by focusing on the anchor token.
- Ensures that all documents within the training context can attend to the anchor, enhancing information flow.
- Mitigates the numerical issues caused by limited precision in BFloat16 format.
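To make this masking pattern concrete, here is a minimal sketch using PyTorch's FlexAttention API (available in `torch` >= 2.5), assuming a CUDA device and an illustrative packing of three documents into one context; the repository's actual FlashAttention2/FlexAttention kernels and shared-position-ID handling may differ in detail:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

SEQ_LEN = 256
# Hypothetical packing: three documents concatenated into one training context.
docs = torch.zeros(SEQ_LEN, dtype=torch.long, device="cuda")
docs[96:192] = 1
docs[192:] = 2

def anchor_mask_mod(b, h, q_idx, kv_idx):
    causal = q_idx >= kv_idx
    same_doc = docs[q_idx] == docs[kv_idx]  # usual intra-document masking
    to_anchor = kv_idx == 0                 # every token also sees the shared anchor
    return causal & (same_doc | to_anchor)

block_mask = create_block_mask(anchor_mask_mod, B=None, H=None,
                               Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, device="cuda")

# (batch, heads, seq, head_dim) in BFloat16, as in long-context training.
q = torch.randn(1, 8, SEQ_LEN, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = flex_attention(q, k, v, block_mask=block_mask)  # shape (1, 8, SEQ_LEN, 64)
```

The block mask keeps attention within each document while the `kv_idx == 0` term keeps the shared anchor visible to every token, matching the description above.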
The code is tested on Python `3.10.0`, PyTorch `2.5.0.dev20240912+cu121` / `2.6.0.dev20241112+cu121`, and CUDA `12.1`.
```bash
conda create -n anchorcontext python=3.10 -y && conda activate anchorcontext
pip install torch==2.6.0.dev20241112+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121
pip install packaging && pip install ninja && pip install flash-attn==2.5.6 --no-build-isolation --no-cache-dir
pip install yunchang==0.2  # directly `pip install yunchang` if you are using flash_attn >= 2.6.0
git clone https://github.com/haonan3/AnchorContext.git
cd AnchorContext
pip install -r requirements.txt
python anchor_context/library_modified_files/softlink_to_library.py
```
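After installation, a quick sanity check of the environment can save time later; a minimal sketch (not part of the repository):

```python
# Quick environment sanity check (illustrative; not part of the repository).
import torch
import flash_attn

print("torch:", torch.__version__)            # expect a 2.5/2.6 cu121 build
print("CUDA available:", torch.cuda.is_available())
print("flash-attn:", flash_attn.__version__)  # expect 2.5.6 or newer

# FlexAttention ships with PyTorch >= 2.5
from torch.nn.attention.flex_attention import flex_attention  # noqa: F401
print("FlexAttention import: OK")
```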
For the FlexAttention experiment presented in the paper, we originally used PyTorch `2.5.0.dev20240912+cu121`. However, since this version is now archived and no longer directly available, we recommend using `2.6.0.dev20241112+cu121` instead.

- Archived Version (`2.5.0.dev20240912+cu121`):
  - Used in our paper for the FlexAttention experiment.
  - Currently archived and not directly accessible.
- Recommended Development Version (`2.6.0.dev20241112+cu121`):
  - Serves as a suitable alternative to the archived version.
  - Ensures compatibility and optimal performance without additional configuration.
- Stable Versions (`2.5.0` & `2.5.1`):
  - Widely used and stable.
  - Require calling `torch._dynamo.reset()` (a short sketch follows this list), which may slightly reduce performance.
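On the stable releases, that reset is a single call; a minimal sketch of how it might be issued (where exactly this codebase invokes it may differ):

```python
import torch

# On stable PyTorch 2.5.0 / 2.5.1, clear TorchDynamo's compilation state before
# (re)compiling the FlexAttention kernels; the nightly builds above do not need this.
torch._dynamo.reset()
```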
Refer to the `demo_llama.sh` script for a training example.
Our code currently supports two types of input data:

- Raw JSON Files: JSON files formatted as `{ "text": "xxxxx" }`. To use raw JSON files as input, include the flag `--data-format raw_json` (an example file is sketched after this list).
- Pre-tokenized Data: Outputs from the FranxYao/Long-Context-Data-Engineering repository. If you are using the pre-tokenized output from this repository, include the flag `--data-format tokenized`.
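For the raw JSON format, here is a minimal sketch that writes such a file, assuming one `{"text": ...}` object per line (JSONL); verify the exact layout against your corpus and the data loader before training:

```python
import json

# Hypothetical documents; replace with your own corpus.
documents = [
    "First long document ...",
    "Second long document ...",
]

# Assumed layout: one {"text": ...} object per line.
with open("my_corpus.jsonl", "w", encoding="utf-8") as f:
    for doc in documents:
        f.write(json.dumps({"text": doc}, ensure_ascii=False) + "\n")
```

The resulting file is then passed to training together with `--data-format raw_json`.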
To facilitate your development process, we provide a debug tool that can be easily integrated into your other projects. Debugging DeepSpeed code using VSCode can be challenging, so we have developed a tool, included in the `util.py` file.
- Enable Debugging: Add the following line to `demo_llama.sh` to enable the debug tool: `export DEBUGPY=1`
- Set Up Debugging in `train.py`: Locate the `setup_debugpy` function call in `train.py` to initialize the debugger.
- Configure VSCode: Ensure your `launch.json` is set up to use the Attach mode. You can use `.vscode/launch.json` as a demo.
- Attach the Debugger in VSCode:
  - Run `bash demo_llama.sh`.
  - Open VSCode and go to the Run and Debug panel.
  - Select the Attach configuration from your `launch.json`.
  - Start the debugger to connect to the specified `endpoint` and `port`. (Once attached, you can set breakpoints, inspect variables, and step through the code as needed.)
Below is the `setup_debugpy` function from `util.py` that sets up the debug server:
```python
import os

import debugpy
from termcolor import colored


def setup_debugpy(accelerator, endpoint="localhost", port=5678, rank=0, force=False):
    # Only activate when DEBUGPY is set (e.g. `export DEBUGPY=1` in demo_llama.sh).
    if "DEBUGPY" not in os.environ:
        print(colored("DEBUGPY not in os.environ", "red"))
        return
    # Environment variables override the default rank, port, and endpoint.
    rank = int(os.getenv("DEBUGPY_RANK", rank))
    port = int(os.getenv("DEBUGPY_PORT", port))
    endpoint = os.getenv("DEBUGPY_ENDPOINT", endpoint)
    # Non-target ranks wait at the barrier so all processes stay in sync.
    if accelerator.process_index != rank:
        accelerator.wait_for_everyone()
        return
    if force:
        print(colored("Force killed debugpy", "red"))
    try:
        # Open the debug server and block until VSCode attaches.
        debugpy.listen((endpoint, port))
        print(colored(f"Waiting for debugger attach on {endpoint}:{port}", "red"))
        debugpy.wait_for_client()
    except Exception:
        print(colored(f"Failed to setup debugpy, {endpoint}:{port} occupied", "red"))
    accelerator.wait_for_everyone()
```
- Port Configuration: Ensure that the specified port (default `5678`) is open and not used by other applications.
- Rank Configuration: The `rank` parameter ensures that the debugger attaches only to the specified process in distributed training scenarios.
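For reference, here is a minimal sketch of how `setup_debugpy` might be wired into a training script, assuming Hugging Face Accelerate and that the helper is imported from `util.py`; the actual call in `train.py` may differ:

```python
from accelerate import Accelerator

from util import setup_debugpy  # the helper shown above (import path assumed)

accelerator = Accelerator()

# Attaches only when DEBUGPY=1 is exported (see demo_llama.sh). The target rank,
# port, and endpoint can be overridden via DEBUGPY_RANK / DEBUGPY_PORT /
# DEBUGPY_ENDPOINT; all other ranks simply wait at the barrier.
setup_debugpy(accelerator, endpoint="localhost", port=5678, rank=0)

# ... build the model, dataloaders, and training loop as usual ...
```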
- Test Multiple Nodes: Test distributed training across multiple nodes to enhance scalability.
- Enable and Test Larger Batch Sizes: Enable and then test batch sizes greater than 1 to improve training efficiency.
We currently do not have a clear timeline for these tasks. Community contributions and collaborations are highly encouraged! If you're interested in contributing, please feel free to open an issue or submit a pull request.
If you need support for models beyond the currently supported LLaMA, Mistral, and Qwen series, please reach out to me directly: Haonan Wang, haonan.wang@u.nus.edu, WeChat (whn769452159).
I am happy to provide infrastructure support for more models!
We have verified the effectiveness of AnchorAttention for long-context extension. However, we believe that AnchorAttention can also be highly effective for LLM pretraining. This is because RoPE and BFloat16 are also utilized in pre-trained models with 4K or 8K window sizes. 🚀
While we can achieve an equivalent computation of AnchorAttention using FlashAttention2, making support for FlexAttention seem less critical, we believe that FlexAttention—supported in PyTorch 2.5.0 and later—will be the future. Not only is it fast, but it also supports advanced attention masking techniques, such as those introduced in Transfusion.
In other words, we think our codebase can be easily adapted to improve text-and-image LLMs or text-and-video LLMs.
This work is built on top of the following papers/repositories:
- Flex-Attention
- Flash-Attention
- EasyContext
- Yunchang
- YaRN
- Long-Context Data Engineering
- Ring-Flash-Attention
- Ring-Attention (Shenggui et al.; Liu et al.)
If you find this work useful, please consider citing it:
```bibtex
@misc{wang2024precisionmeetspositionbfloat16,
  title={When Precision Meets Position: BFloat16 Breaks Down RoPE in Long-Context Training},
  author={Haonan Wang and Qian Liu and Chao Du and Tongyao Zhu and Cunxiao Du and Kenji Kawaguchi and Tianyu Pang},
  year={2024},
  eprint={2411.13476},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2411.13476},
}
```
😝 And don’t forget to give it a ⭐ if you like it!