Explorations into Ring Attention, from Liu et al. at Berkeley AI.
It splits the data across the sequence dimension (instead of the batch dimension) and applies a ring reduce to the processing of the tiles of the attention matrix, flash attention style.
I believe some form of this is being used to reach the 1-10 million token context of the latest Gemini; the other possibility would be unpublished improvements on top of RMT.
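To make the mechanism concrete, here is a rough single-process sketch (illustrative code with made-up names, not the code in this repository); causal masking is omitted. Each rank keeps its query chunk resident, the key / value chunks are rotated one hop per pass, and partial results are folded in with a numerically stable log-sum-exp accumulation, so the full attention matrix is never materialized on any one device.

```python
import torch

# toy, single-process emulation of ring attention - the "ring" is just a python
# list of key / value chunks, and each inner iteration stands in for one ring pass

def ring_attention_reference(q_chunks, k_chunks, v_chunks, scale):
    num_ranks = len(q_chunks)
    outputs = []

    for rank, q in enumerate(q_chunks):
        acc = torch.zeros_like(q)
        row_max = torch.full((*q.shape[:-1], 1), float('-inf'))
        row_sum = torch.zeros((*q.shape[:-1], 1))

        for hop in range(num_ranks):
            # on a real ring, this chunk would arrive from the neighboring rank
            src = (rank - hop) % num_ranks
            k, v = k_chunks[src], v_chunks[src]

            sim = (q @ k.transpose(-2, -1)) * scale

            # online softmax: fold the new tile into the running statistics
            new_max = torch.maximum(row_max, sim.amax(dim = -1, keepdim = True))
            exp_sim = (sim - new_max).exp()
            correction = (row_max - new_max).exp()

            acc = acc * correction + exp_sim @ v
            row_sum = row_sum * correction + exp_sim.sum(dim = -1, keepdim = True)
            row_max = new_max

        outputs.append(acc / row_sum)  # normalize only once, after the last pass

    return torch.cat(outputs, dim = -2)
```

Up to floating point error, this matches ordinary softmax attention run on the concatenated sequence.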
In addition, the repository also contains the logic for Striped Attention, a follow-up paper that permutes the sequence for better workload balancing in autoregressive transformers.
- A16Z Open Source AI Grant Program for the generous sponsorship, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research
```bash
$ pip install ring-attention-pytorch
```
```python
import torch
from ring_attention_pytorch import RingAttention

attn = RingAttention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    causal = True,            # autoregressive (causal) attention
    auto_shard_seq = True,    # automatically shard the incoming sequence across the ring
    ring_attn = True,         # use ring attention
    ring_seq_size = 512
)

tokens = torch.randn(1, 1024, 512)
attended = attn(tokens)

assert attended.shape == tokens.shape
```
To run a basic distributed sanity check, comparing ring attention against regular attention for equivalent outputs and gradients:

```bash
$ python assert.py
```
- make it work with a derived causal mask based on rank and chunk sizes (see the sketch after this list)
- modify flash attention to output intermediates and figure out backwards with recompute and ring passes
- functions for splitting the sequence evenly among ranks, either within the attention function or in the external ring transformer wrapper
- basic test case with two processes and check for equivalent output and gradients
- testing
  - make sure key padding mask works
  - make sure causal mask works
  - rotary embeddings, with proper key/value offset depending on ring rank (sketched after this list)
- striped attention
  - add the permutation logic before and after the transformer (sketched after this list)
  - add causal masking logic - account for sub-bucketing by flash attention
- fix issue with ring attention when flash buckets > 1
- move flash attention back to key / value column traversal on the outer loop and save on ring communication
  - backwards
  - forwards
- fix rotary positions for striped ring attention when flash buckets > 1
- allow for variable ring passes per layer, for local -> global attention in the ring transformer as one goes up the layers
- when doing ring passes, alternate between designated send and receive buffers (sketched after this list)
- instead of max ring passes, allow specifying the lookback in terms of sequence length, and derive the number of flash attention buckets + ring passes from that
- ability to have ring size < world size, sharding the batch and sequence, and doing the ring reduce with the correct set of ranks
- add flash attention kernel version in the presence of cuda
  - for forwards, use a modified Triton flash attention forwards that outputs row sums, maxes, and the exponentiated weighted sum
  - for backwards, use Tri's flash attention kernels, accumulating dq, dk, dv across rings
  - refactor to have naive ring + flash attention work with (batch, seq, head, dim)
  - handle key padding mask for forwards by translating the mask to a bias (sketched after this list)
  - figure out how Tri handles the key padding mask for backwards
  - scale the output of flash attention forwards on the last ring pass reduce
  - verify backwards working on an a100 runpod
  - dk, dv need to be float32, while kv needs to be float16; see if both can be cast to int before being stacked and ring passed all in one go, then reinterpreted back to float32 and float16
  - prevent an unnecessary `tl.load` on the first ring pass
  - cuda backwards pass must have the same dq, dk, dv as naive
- fix naive flash attention backwards
- validate that cuda causal and striped ring attention work
- find a machine with 8 GPUs and test with a quarter million tokens first
- think about how to craft a special `Dataset` that shards across the sequence length (taking into account labels for the cross entropy loss) for ring transformer training
- add ring attention to Tri's flash attention implementation. find some cuda ring reduce impl
- use `batch_isend_irecv` in the presence of a key padding mask needing ring exchange (not a big priority)
- figure out how to pytest distributed pytorch
- use the sdp context manager to validate when it is possible to use `ring_flash_attn_cuda`, otherwise assert out
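A few of the items above, sketched as minimal hypothetical helpers (names and signatures are illustrative, not the repository's actual API). First, the causal mask derived from ring rank and chunk sizes: every rank knows which slice of the global sequence its queries and its currently held keys cover, so the mask can be built locally.

```python
import torch

def derived_causal_mask(q_rank, k_rank, chunk_size, device = None):
    # global positions covered by the local query chunk and by the key / value
    # chunk currently held at this point of the ring rotation
    q_pos = q_rank * chunk_size + torch.arange(chunk_size, device = device)
    k_pos = k_rank * chunk_size + torch.arange(chunk_size, device = device)

    # True where a query may attend to a key; off-diagonal chunks come out either
    # all True (keys entirely in the past) or all False (entirely in the future)
    return q_pos[:, None] >= k_pos[None, :]
```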
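For rotary embeddings, the only ring-specific detail is that positions start at the rank's global offset rather than at zero (again a hypothetical helper):

```python
import torch

def rotary_with_ring_offset(t, rank, chunk_size, theta = 10000.):
    # t: (batch, seq, heads, dim) - a local chunk whose first token sits at
    # global position rank * chunk_size
    batch, seq, heads, dim = t.shape
    offset = rank * chunk_size

    pos = torch.arange(offset, offset + seq, device = t.device)
    freqs = theta ** -(torch.arange(0, dim, 2, device = t.device) / dim)
    angles = pos[:, None] * freqs[None, :]              # (seq, dim / 2)

    cos = angles.cos()[None, :, None, :]
    sin = angles.sin()[None, :, None, :]
    t1, t2 = t[..., 0::2], t[..., 1::2]

    # rotate each (even, odd) feature pair by its position-dependent angle
    out = torch.empty_like(t)
    out[..., 0::2] = t1 * cos - t2 * sin
    out[..., 1::2] = t1 * sin + t2 * cos
    return out
```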
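The striped permutation is an interleave: after permuting, each contiguous sequence shard corresponds to every world_size-th token, which evens out the causal workload across the ring. A sketch, with hypothetical helpers:

```python
import torch

def stripe_permute(tokens, world_size):
    # reorder (batch, seq, dim) tokens so that, once the sequence is sharded
    # contiguously, rank r holds tokens r, r + world_size, r + 2 * world_size, ...
    b, n, d = tokens.shape
    assert n % world_size == 0
    return tokens.reshape(b, n // world_size, world_size, d).transpose(1, 2).reshape(b, n, d)

def stripe_unpermute(tokens, world_size):
    # inverse of stripe_permute, applied after the ring transformer
    b, n, d = tokens.shape
    return tokens.reshape(b, world_size, n // world_size, d).transpose(1, 2).reshape(b, n, d)
```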
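For the ring pass itself, one way to honor the "designated send and receive buffers" item is to post both sides of the exchange with `torch.distributed.batch_isend_irecv` against two preallocated buffers and swap them every hop (a hypothetical helper, not the repository's function):

```python
import torch
import torch.distributed as dist

def ring_pass(send_buffer, recv_buffer):
    # rotate the key / value chunk one hop around the ring, writing into a
    # dedicated receive buffer so the in-flight send buffer is never clobbered
    rank, world_size = dist.get_rank(), dist.get_world_size()
    send_to = (rank + 1) % world_size
    recv_from = (rank - 1) % world_size

    ops = [
        dist.P2POp(dist.isend, send_buffer, send_to),
        dist.P2POp(dist.irecv, recv_buffer, recv_from),
    ]

    for req in dist.batch_isend_irecv(ops):
        req.wait()

    # swap roles: what was just received becomes the next hop's send buffer
    return recv_buffer, send_buffer
```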
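And the key padding mask can be handed to a fused kernel as an additive bias rather than a boolean mask (hypothetical helper):

```python
import torch

def key_padding_mask_to_bias(mask, dtype = torch.float32):
    # mask: (batch, seq), True for tokens that may be attended to
    bias = torch.zeros(mask.shape, dtype = dtype, device = mask.device)
    bias.masked_fill_(~mask, torch.finfo(dtype).min)   # large negative in place of -inf
    return bias                                        # added to attention logits pre-softmax
```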
```bibtex
@article{Liu2023RingAW,
    title   = {Ring Attention with Blockwise Transformers for Near-Infinite Context},
    author  = {Hao Liu and Matei Zaharia and Pieter Abbeel},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2310.01889},
    url     = {https://api.semanticscholar.org/CorpusID:263608461}
}
```
```bibtex
@article{Brandon2023StripedAF,
    title   = {Striped Attention: Faster Ring Attention for Causal Transformers},
    author  = {William Brandon and Aniruddha Nrusimha and Kevin Qian and Zachary Ankner and Tian Jin and Zhiye Song and Jonathan Ragan-Kelley},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2311.09431},
    url     = {https://api.semanticscholar.org/CorpusID:265220849}
}
```
```bibtex
@article{Dao2022FlashAttentionFA,
    title   = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
    author  = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R{\'e}},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.14135}
}
```
```bibtex
@article{dao2023flashattention2,
    title   = {Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
    author  = {Dao, Tri},
    year    = {2023}
}
```
```bibtex
@article{Tillet2019TritonAI,
    title   = {Triton: an intermediate language and compiler for tiled neural network computations},
    author  = {Philippe Tillet and H. Kung and D. Cox},
    journal = {Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
    year    = {2019}
}
```