This implementation of LongNet is brought to you by Agora, we're an all-new open source AI research organization with 1,500+ AI researchers all striving to advance Humanity!
Join us and help contribute to LongNet and or recieve FAST support in the Agora discord!
This is an open source implementation for the paper LongNet: Scaling Transformers to 1,000,000,000 Tokens by Jiayu Ding, Shuming Ma, Li Dong, Xingxing Zhang, Shaohan Huang, Wenhui Wang, Furu Wei. The LongNet is a Transformer variant designed to scale sequence length up to more than 1 billion tokens without sacrificing performance on shorter sequences.
Scaling sequence length has become a critical bottleneck in the era of large language models. However, existing methods struggle with either computational complexity or model expressivity, rendering the maximum sequence length restricted. In this paper, they introduce LongNet, a Transformer variant that can scale sequence length to more than 1 billion tokens, without sacrificing the performance on shorter sequences. Specifically, they propose dilated attention, which expands the attentive field exponentially as the distance grows.
LongNet has significant advantages:
- It has a linear computation complexity and a logarithm dependency between tokens.
- It can be served as a distributed trainer for extremely long sequences.
- Its dilated attention is a drop-in replacement for standard attention, which can be seamlessly integrated with the existing Transformer-based optimization.
Experiment results demonstrate that LongNet yields strong performance on both long-sequence modeling and general language tasks. Their work opens up new possibilities for modeling very long sequences, e.g., treating a whole corpus or even the entire Internet as a sequence.
Here's the updated usage and installation section with two methods: git clone or pip install LongNet:
You can install LongNet using one of the following methods:
- Clone the LongNet repository from GitHub:
git clone https://github.com/kyegomez/LongNet.git
- Navigate to the cloned directory:
cd LongNet
- Install the required dependencies:
pip install -r requirements.txt
- Note that pip install does not work as the
flash-attn
library cannot be compiled since it has custom CUDA Kernels and they need to be built manually.
- Install LongNet directly from PyPI using pip:
pip install LongNet
Please note that LongNet requires a compatible Python version (tested with Python 3.7).
Once you have installed LongNet, you can use the DilatedAttention
class as follows:
import torch
import torch.nn as nn
from LongNet import DilatedAttention
# Replace this with your correct GPU device
device = "cuda:0"
dtype = torch.float16
# Create an instance of DilatedAttention
d_model = 512
num_heads = 8
dilation_rate = 2
segment_size = 64
dropout = 0.2 # Specify the dropout rate
attention = DilatedAttention(
d_model=d_model,
num_heads=num_heads,
dilation_rate=dilation_rate,
segment_size=segment_size,
dropout=dropout,
).to(device, dtype=dtype)
# Create some dummy input data
batch_size = 16
seq_len = 128
input_dim = d_model
inputs = torch.randn(batch_size, seq_len, input_dim, device=device, dtype=dtype)
# Forward pass
outputs = attention(inputs)
# Print the output shape
print(outputs.shape) # Expected: [batch_size, seq_len, d_model]
- We're still working on the model configuation as closely in the paper as possible. There are 2 methods, one is
accelerate
and the otherfrom LongNet import Train
-
Git clone installation
-
Init your parameters
accelerate config
-
Then
accelerate launch LongNet/training.py
- Pip install method
from LongNet import Train
Train()
In the example above, we create an instance of the DilatedAttention
class with the specified hyperparameters. We then generate some dummy input data and pass it through the attention mechanism to obtain the outputs. Finally, we print the shape of the output tensor.
1. Initialize the input (Q, K, V) and split them into segments {(Qei, Kei, Vei)} with equal segment length w.
2. Sparsify each segment along the sequence dimension by selecting the rows with an interval r.
3. Feed the sparsified segments into the attention in parallel.
4. Scatter and concatenate the output O from the attention.
5. Implement a mixture of dilated attentions with different segment sizes and dilation rates {ri, wi}.
6. For multi-head dilated attention, differ the computation among different heads by sparsifying different parts of the query-key-value pairs.
7. Concatenate the outputs of different heads into a final output.
class DilatedAttention(nn.Module):
def __init__(self, d_model, num_heads, dilation_rate, segment_size, dropout=0.0, causal=False, use_xpos=False, use_rel_pos_bias=False ):
...
-
d_model
(int): The dimensionality of the model. This should match the dimension of the input to the layer. -
num_heads
(int): The number of attention heads to use in theFlashMHA
attention mechanism. -
dilation_rate
(int): The dilation rate to use when processing the input sequence. Larger values will result in fewer, but wider, attention computations. -
segment_size
(int): The size of the segments into which the input sequence is divided before dilating and computing attention. -
dropout
(float, optional): The dropout rate to apply to the attention outputs. Default is 0.0. -
causal
(bool, optional): If True, a causal mask will be applied to the attention outputs, preventing any given position from attending to future positions. Default is False. -
use_xpos
(optional): If set to True, xpos is used for positional encoding. Default: False -
use_rel_pos_bias
(optional): If set to True, relative position bias is used in the attention mechanism. Default: False
First, you need to create an instance of the DilatedAttention
class. Here is how you do it:
dilated_attn = DilatedAttention(d_model=512, num_heads=8, dilation_rate=2, segment_size=64, dropout=0.1, causal=True, use_xpos=False, use_rel_pos_bias=False)
In this example, we're creating a DilatedAttention
layer with a model dimensionality of 512, 8 attention heads, a dilation rate of 2, a segment size of 64, a dropout rate of 0.1, and causal masking enabled.
To perform a forward pass through the layer, simply call the instance as if it were a function, passing in your input tensor:
import torch
# Assume x is your input tensor with shape (batch_size, sequence_length, d_model)
x = torch.rand(16, 1000, 512).to(device)
output = dilated_attn(x)
In this example, the input tensor x
has a batch size of 16, a sequence length of 1000, and a model dimensionality of 512. The output tensor will have the same shape as the input tensor.
You can integrate the DilatedAttention
layer into a larger model just like any other PyTorch layer. For example, here's how you might use it as part of a simple transformer-like model:
class SimpleTransformer(nn.Module):
def __init__(self, d_model, num_heads, dilation_rate, segment_size, dropout):
super().__init__()
self.dilated_attn = DilatedAttention(d_model, num_heads, dilation_rate, segment_size, dropout, causal=True, use_xpos=False, use_rel_pos_bias=False)
self.fc = nn.Linear(d_model, 10) # Assume we're doing a 10-class classification task
def forward(self, x):
x = self.dilated_attn(x)
x = self.fc(x[:, 0]) # Use the first position output as the "CLS" token
return x
model = SimpleTransformer(d_model=512, num_heads=8, dilation_rate=2, segment_size=64, dropout=0.1)
In this example, we first pass the input tensor through the DilatedAttention
layer, then we pass the output of the first position through a fully-connected layer to perform a classification task.
DilatedAttention
is a neural network architecture that incorporates attention mechanisms, specifically the multi-head attention, in a dilated manner. The main idea behind this architecture is to leverage the efficient attention calculation capabilities of the FlashMHA
method, which is part of the flash_attn
module, while also providing the ability to handle longer sequences with reduced computation through dilation.
The class DilatedAttention
has the following primary components:
-
FlashMHA attention: A fast and efficient multi-head attention mechanism implemented using the
FlashMHA
method. This is the main attention computation method used in the architecture. -
Dilation: Dilating the input sequences allows the model to handle longer sequences with fewer computations, making the architecture more scalable and efficient.
-
Causal masking (optional): If the
causal
argument is set toTrue
, a causal mask is applied to the attention outputs, ensuring that each output position only depends on earlier positions in the sequence. This feature is particularly useful when dealing with sequential data where future dependencies should not be considered. -
Dropout: A dropout layer that can be configured to add regularization to the model and prevent overfitting.
The DilatedAttention
model works in the following steps:
-
Input Reshape: Reshapes the input into smaller segments based on the provided
segment_size
and then dilates it by selecting everydilation_rate
segment. -
Attention Computation: Uses
FlashMHA
to compute the attention over the dilated segments. -
Causal Masking: If
causal
is set toTrue
, a causal mask is applied to the attention output. This ensures that the output at each position in the sequence does not depend on any future positions. -
Dropout: Applies dropout to the attention outputs as a means of regularization.
-
Output Reshape: Reshapes the output to match the original sequence length, concatenating the dilated segments.
The DilatedAttention
model achieves efficiency and scalability in several ways:
-
Efficient attention calculation: The use of
FlashMHA
enables efficient and fast attention computation. -
Dilation: Dilation allows the model to handle longer sequences with reduced computation, effectively making the model more scalable.
-
Causal masking: By ensuring that each output position only depends on earlier positions in the sequence, the model becomes suitable for tasks involving sequential data.
-
Parallelization: Take advantage of the parallel processing capabilities of modern GPUs for the dilation and reshaping steps.
-
Memory optimization: Efficient memory usage could be achieved through gradient checkpointing or activation pruning.
-
Pre-computation: If some portions of the input data remain constant across multiple operations, pre-compute those portions and store the results for reuse.
-
Batch normalization: Incorporating batch normalization layers could help to speed up the learning process and improve generalization.
-
Pruning and Quantization: Pruning unnecessary connections and quantizing the model parameters can help in reducing the model's memory footprint and speed up computation without sacrificing much accuracy.
Share LongNet with your friends and colleagues who might find it useful. Simply click on the links below to share on various platforms:
Thank you for sharing!
-
Recreate the sparsification mechanism
-
Recreate the gathering mechanism
-
Implement FlashAttention2.0
-
Implement Distributed Setup
-
create the all-gather operation in the backward that becomes a reduce-scatter operation
@inproceedings{ding2023longnet,
title={LongNet: Scaling Transformers to 1,000,000,000 Tokens},
author={Ding, Jiayu and Ma, Shuming and Dong, Li and Zhang, Xingxing and Huang, Shaohan and Wang, Wenhui and Wei, Furu},
booktitle={Proceedings of the 10th International Conference on Learning Representations},
year={2023}
}