FlashAttention
This repository provides the official implementation of FlashAttention from the following paper.
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
Paper: https://arxiv.org/abs/2205.14135
IEEE Spectrum article about our submission to the MLPerf 2.0 benchmark using FlashAttention.
Usage
We've been very happy to see FlashAttention being widely adopted in such a short time after its release. This page contains a partial list of places where FlashAttention is being used.
Full model code and training script
We have released the full GPT model implementation. We also provide optimized implementations of other layers (e.g., MLP, LayerNorm, cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x compared to the baseline implementation from Hugging Face, reaching up to 189 TFLOPs/sec per A100, equivalent to 60.6% model FLOPs utilization (we don't need any activation checkpointing).
We also include a training script to train GPT-2 on OpenWebText and GPT-3 on The Pile.
Triton implementation of FlashAttention
Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton: https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
As Triton is a higher-level language than CUDA, it might be easier to understand and experiment with. The notation in the Triton implementation is also closer to what's used in our paper.
We also have an experimental implementation in Triton that supports attention bias (e.g. ALiBi): https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py
Installation and features
Requirements:
- CUDA 11.4 and above.
- PyTorch 1.12 and above.
We recommend the PyTorch container from NVIDIA, which has all the required tools to install FlashAttention.
To install:
pip install flash-attn
Alternatively you can compile from source:
python setup.py install
Interface: flash_attn/flash_attention.py
To run the benchmark against PyTorch standard attention:
PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
FlashAttention currently supports:
- Turing, Ampere, Ada, or Hopper GPUs (e.g., H100, A100, RTX 3090, T4, RTX 2080).
- fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
- Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ..., 128). Head dim > 64 backward requires A100 or H100.
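The support list above can be turned into a quick pre-flight check. The sketch below is illustrative only: `is_supported` is a hypothetical helper, not part of the flash_attn package, and the mapping from architecture names to compute capabilities (sm75, sm80, sm86, sm89, sm90) is our assumption.

```python
from typing import Tuple


def is_supported(
    capability: Tuple[int, int],
    dtype: str,
    head_dim: int,
    needs_backward: bool = True,
) -> bool:
    """Hypothetical check mirroring the support list (not a flash_attn API).

    `capability` is a (major, minor) pair, e.g. the value returned by
    torch.cuda.get_device_capability(). `dtype` is "fp16" or "bf16".
    """
    major, minor = capability
    sm = major * 10 + minor

    # Assumed mapping: Turing (sm75), Ampere (sm80, sm86), Ada (sm89), Hopper (sm90).
    if sm not in (75, 80, 86, 89, 90):
        return False

    # Only half-precision dtypes are listed; bf16 needs Ampere or newer.
    if dtype not in ("fp16", "bf16"):
        return False
    if dtype == "bf16" and sm < 80:
        return False

    # Head dimension must be a multiple of 8 and at most 128.
    if head_dim % 8 != 0 or not 8 <= head_dim <= 128:
        return False

    # Backward pass with head dim > 64 requires A100 (sm80) or H100 (sm90).
    if needs_backward and head_dim > 64 and sm not in (80, 90):
        return False

    return True
```

For example, under these assumptions an RTX 3090 (sm86) with fp16 and head dimension 128 passes the check for the forward pass only, while an A100 passes for both forward and backward.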
Our tentative roadmap:
- [Jun 2022] Make package pip-installable [Done, thanks to lucidrains].
- [Jun 2022] Support SM86 GPUs (e.g., RTX 3080, 3090) [Done].
- [Jun 2022] Support SM75 GPUs (e.g. T4) [Done].
- [Jun 2022] Support bf16 [Done].
- [Jul 2022] Implement cross-attention [Done].
- [Jul 2022] Support head dimension 128 [Done].
- [Aug 2022] Fuse rotary embedding [Done].
- [Mar 2023] Support SM90 GPUs (H100) [Done].
How to use FlashAttention
Here's a simple example:
import torch
from flash_attn.flash_attention import FlashMHA
# Replace this with your correct GPU device
device = "cuda:0"
# Create attention layer. This is similar to torch.nn.MultiheadAttention,
# and it includes the input and output linear layers
flash_mha = FlashMHA(
    embed_dim=128,  # total channels (= num_heads * head_dim)
    num_heads=8,  # number of heads
    device=device,
    dtype=torch.float16,
)
# Run forward pass with dummy data
x = torch.randn(
    (64, 256, 128),  # (batch, seqlen, embed_dim)
    device=device,
    dtype=torch.float16,
)
output = flash_mha(x)[0]
Alternatively, you can import the inner attention layer only (so that the input and output linear layers are not included):
from flash_attn.flash_attention import FlashAttention
# Create the nn.Module
flash_attention = FlashAttention()
Or, if you need more fine-grained control, you can import one of the lower-level functions (this is more similar to the torch.nn.functional style):
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
# or
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_split_func
# etc.
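As their names suggest, the unpadded functions work on batches with padding removed: tokens from all sequences are concatenated, and sequence boundaries are described by cumulative sequence lengths. The sketch below shows only how such offsets could be derived in plain Python. `cumulative_seqlens` is a hypothetical helper, and the exact argument names, tensor dtypes, and shapes the functions expect are assumptions to check against flash_attn/flash_attn_interface.py.

```python
from itertools import accumulate
from typing import List


def cumulative_seqlens(seqlens: List[int]) -> List[int]:
    """Return offsets [0, l0, l0 + l1, ...] marking where each sequence starts.

    Hypothetical helper: sequence i occupies rows
    offsets[i]:offsets[i + 1] of the concatenated token dimension.
    """
    return [0] + list(accumulate(seqlens))


seqlens = [3, 5, 2]                        # per-sequence lengths in one batch
cu_seqlens = cumulative_seqlens(seqlens)   # one entry more than the batch size
max_seqlen = max(seqlens)                  # longest sequence in the batch
total_tokens = cu_seqlens[-1]              # rows in the concatenated tensor
```

In practice these offsets would be converted to a tensor on the GPU before being passed to the kernel.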
There are also separate Python files with various FlashAttention extensions:
# Import the triton implementation (torch.nn.functional version only)
from flash_attn.flash_attn_triton import flash_attn_func
# Import block sparse attention (nn.Module version)
from flash_attn.flash_blocksparse_attention import FlashBlocksparseMHA, FlashBlocksparseAttention
# Import block sparse attention (torch.nn.functional version)
from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
Speedup and Memory Savings
We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
We currently have benchmarks for these GPUs: A100, RTX 3090, and T4.
A100
We display FlashAttention speedup using these parameters (similar to BERT-base):
- Batch size 8
- Head dimension 64
- 12 attention heads
Our graphs show sequence lengths between 128 and 4096 (the point at which standard attention runs out of memory on an A100), but FlashAttention can scale up to sequence length 64K.
Speedup
We generally see 2-4X speedup at sequence lengths between 128 and 4K, and we see more speedup when using dropout and masking, since we fuse the kernels. At sequence lengths that are popular with language models like 512 and 1K, we see speedups up to 4X when using dropout and masking.
Memory
We show memory savings in this graph (note that the memory footprint is the same whether or not you use dropout or masking). Memory savings are proportional to sequence length, since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length. We see 10X memory savings at sequence length 2K, and 20X at 4K. As a result, FlashAttention can scale to much longer sequence lengths.
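The quadratic-versus-linear argument can be checked with back-of-the-envelope arithmetic. The sketch below is an illustration under simplifying assumptions: it counts only the attention score matrix (batch x heads x seqlen x seqlen) for standard attention and only the attention output (batch x heads x seqlen x head_dim) as a stand-in for a linear footprint. Both function names are hypothetical, and the ratio it prints shows the trend, not the measured 10X and 20X figures.

```python
def scores_bytes(batch: int, heads: int, seqlen: int, bytes_per_element: int = 2) -> int:
    """Bytes needed to materialize the full seqlen x seqlen score matrix (fp16 by default)."""
    return batch * heads * seqlen * seqlen * bytes_per_element


def output_bytes(batch: int, heads: int, seqlen: int, head_dim: int,
                 bytes_per_element: int = 2) -> int:
    """Bytes for a tensor that grows linearly with seqlen (here, the attention output)."""
    return batch * heads * seqlen * head_dim * bytes_per_element


# Benchmark-like settings from this section: batch 8, 12 heads, head dim 64.
for n in (2048, 4096):
    quadratic = scores_bytes(8, 12, n)
    linear = output_bytes(8, 12, n, 64)
    print(f"seqlen={n}: scores={quadratic / 2**30:.2f} GiB, ratio={quadratic // linear}")
```

Doubling the sequence length quadruples the quadratic term but only doubles the linear one, so the ratio between them doubles, which matches the doubling of memory savings from 2K to 4K.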
Head Dimension 128
We show speedup with head dimension 128. Here we show batch size 16 with 12 heads. Speedup is less than with the smaller head sizes, since we have to make the block size smaller in the tiling. But speedup is still significant, especially with a causal mask.
RTX 3090
For the RTX 3090, we use batch size 12 with 12 attention heads. Memory savings are the same as on an A100, so we'll only show speedup here.
We see slightly higher speedups (between 2.5x and 4.5x) on the RTX 3090, since memory bandwidth on the GDDR6X is lower than A100 HBM (~900 GB/s vs. ~1.5 TB/s).
T4
We again use batch size 12 with 12 attention heads.
The T4 has less SRAM (64 KB) than the newer GPUs, so we see less speedup (we need to make the block sizes smaller, so we end up doing more reads and writes). This matches the IO complexity analysis from section 3.2 of our paper.
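To make the dependence on SRAM size concrete, the sketch below evaluates the asymptotic HBM access counts from the paper's IO analysis: on the order of N^2 d^2 / M for FlashAttention and N d + N^2 for standard attention, where N is the sequence length, d the head dimension, and M the SRAM size. Constants are dropped, M is measured in elements rather than bytes, and the function names and sample values are illustrative assumptions, so only the relative trend is meaningful.

```python
def flash_hbm_accesses(seqlen: int, head_dim: int, sram_size: int) -> float:
    """Asymptotic HBM accesses for FlashAttention: N^2 * d^2 / M (constants dropped)."""
    return seqlen**2 * head_dim**2 / sram_size


def standard_hbm_accesses(seqlen: int, head_dim: int) -> int:
    """Asymptotic HBM accesses for standard attention: N * d + N^2 (constants dropped)."""
    return seqlen * head_dim + seqlen**2


# Halving the available SRAM doubles the estimated HBM traffic.
large_sram = flash_hbm_accesses(seqlen=1024, head_dim=64, sram_size=32768)
small_sram = flash_hbm_accesses(seqlen=1024, head_dim=64, sram_size=16384)
print(small_sram / large_sram)
```

Because the estimate scales with 1/M, a GPU with less SRAM per streaming multiprocessor performs proportionally more HBM reads and writes for the same problem size.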
T4 GPUs are commonly used for inference, so we also measure speedup on the forward pass only (note that these are not directly comparable to the graphs above):
We see speedups between 2.5x and 4.5x on the forward pass.
Tests
We test that FlashAttention produces the same output and gradient as a reference implementation, up to some numerical tolerance. In particular, we check that the maximum numerical error of FlashAttention is at most twice the numerical error of a baseline implementation in PyTorch (for different head dimensions, input dtype, sequence length, causal / non-causal).
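The tolerance criterion can be sketched as follows. This is a simplified illustration, not the code in tests/test_flash_attn.py: it assumes both errors are measured against a higher-precision reference output, uses flat sequences of floats in place of tensors, and the helper names are hypothetical.

```python
from typing import Sequence


def max_abs_error(values: Sequence[float], reference: Sequence[float]) -> float:
    """Largest elementwise absolute difference from the reference."""
    return max(abs(v - r) for v, r in zip(values, reference))


def within_tolerance(
    flash_out: Sequence[float],
    baseline_out: Sequence[float],
    reference_out: Sequence[float],
    factor: float = 2.0,
) -> bool:
    """True if the FlashAttention error is at most `factor` times the baseline error."""
    flash_err = max_abs_error(flash_out, reference_out)
    baseline_err = max_abs_error(baseline_out, reference_out)
    return flash_err <= factor * baseline_err


reference = [0.0, 1.0, 2.0]   # stand-in for a high-precision result
baseline = [0.5, 1.0, 2.0]    # stand-in for a half-precision PyTorch result
print(within_tolerance([1.0, 1.0, 2.0], baseline, reference))
print(within_tolerance([1.5, 1.0, 2.0], baseline, reference))
```

Comparing against the baseline's own error, rather than a fixed threshold, keeps the check meaningful across dtypes and sequence lengths where the achievable precision differs.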
To run the tests:
pytest -q -s tests/test_flash_attn.py
When you encounter issues
This alpha release of FlashAttention contains code written for a research project to validate ideas on speeding up attention. We have tested it on several models (BERT, GPT2, ViT). However, there might still be bugs in the implementation that we hope to iron out in the next few months.
If you encounter any bugs, please open a GitHub issue describing the problem.
Acknowledgments
Our implementation uses Apex's FMHA code as a starting point.
We thank Young-Jun Ko for the in-depth explanation of his FMHA implementation and for his thoughtful answers to our questions about CUDA.
Citation
If you use this codebase, or otherwise find our work valuable, please cite:
@inproceedings{dao2022flashattention,
title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle={Advances in Neural Information Processing Systems},
year={2022}
}