Experimental project: an implementation of the forward multihead attention pass in a single CUDA kernel, without additional VRAM usage for score calculation.
The attention mechanism is widely used in contemporary large language models (BERT, GPT, LLaMA, etc.) as well as in some computer vision networks (ViT and similar). It is the main part of the transformer architecture, in which almost half of the layers are attention layers, so optimizing it is an opportunity to make large transformers faster (and to reduce heat and CO2 emissions). Usually we calculate multihead attention as follows:

$$Q = X W_q, \qquad K = X W_k, \qquad V = X W_v,$$

$$S_h = \mathrm{softmax}\left(\frac{Q_h K_h^T}{\sqrt{H}}\right), \qquad Y_h = S_h V_h,$$

$$Z = Y W_o.$$
Here:

- $L$ is the sequence length (or set size), $F$ is the number of features, $R$ is the number of heads in multihead attention, $H$ is the number of features in a single head [2]; $H \times R = F$
- $X \in \mathbb{R}^{L \times F}$ — the input matrix
- $W_q, W_k, W_v, W_o \in \mathbb{R}^{F \times F}$ — the query, key, value, and output weight matrices
- $Q_h, K_h, V_h, Y_h \in \mathbb{R}^{L \times H}$ — submatrices of $Q, K, V, Y$, the so-called heads
- $S_h \in \mathbb{R}^{L \times L}$ — the attention scores for a single head
- $Y \in \mathbb{R}^{L \times F}$ — the concatenated heads matrix
- $Z \in \mathbb{R}^{L \times F}$ — the output matrix
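For reference, the same computation written naively in PyTorch for a single sequence might look like the sketch below (the function is hypothetical and only illustrates the formulas above):

```python
import math
import torch

def naive_multihead_attention(x, w_q, w_k, w_v, w_o, num_heads):
    """Reference multihead attention for one sequence; materializes the full score tensor."""
    L, F = x.shape
    H = F // num_heads
    q, k, v = x @ w_q, x @ w_k, x @ w_v                                 # Q, K, V of shape (L, F)
    # split into heads: (R, L, H)
    q, k, v = (t.view(L, num_heads, H).transpose(0, 1) for t in (q, k, v))
    s = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(H), dim=-1)   # scores S_h, shape (R, L, L)
    y = (s @ v).transpose(0, 1).reshape(L, F)                           # heads Y_h concatenated into Y
    return y @ w_o                                                      # Z = Y W_o
```

The $R \times L \times L$ tensor `s` is exactly the intermediate result this project avoids storing.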
Time complexity of this operation is $O(L \cdot F^2 + L^2 \cdot F)$: the linear projections cost $O(L \cdot F^2)$, and computing the scores and the heads costs $O(L^2 \cdot F)$.
Memory complexity is $O(L \cdot F + R \cdot L^2)$, where the $R \cdot L^2$ term comes from the score matrices $S_h$.
Here we can already see the problem: despite the fact that the feature size $F$ is usually comparable to (or even larger than) the sequence length $L$, the scores take $R \cdot L^2$ values, which can easily exceed the $L \cdot F$ values of the input itself.
For big values of $L$ and $R$, which are typical for modern LLMs [1], storing the scores requires a significant amount of extra VRAM on top of the inputs, outputs, and weights.
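To put a number on it, take the benchmark configuration used below (batch of 4, $L = 2048$, $R = 40$) with half-precision scores; the score tensor alone occupies

$$4 \cdot 40 \cdot 2048 \cdot 2048 \cdot 2 \ \text{bytes} = 1280 \ \text{MiB}.$$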
So the idea is to make a fused kernel for the intermediate operation

$$Y_h = \mathrm{softmax}\left(\frac{Q_h K_h^T}{\sqrt{H}}\right) V_h,$$

which produces the heads $Y_h$ directly from $Q_h$, $K_h$, $V_h$ without ever materializing the score matrices $S_h$ in VRAM.
The most tricky part here is to calculate the softmax partially, without storing all the scores. We can describe the whole $i$-th output row of a single head as

$$y_i = \frac{\sum_{j=1}^{L} e^{s_{ij}} v_j}{\sum_{j=1}^{L} e^{s_{ij}}}, \qquad s_{ij} = \frac{q_i \cdot k_j}{\sqrt{H}},$$

here $q_i$, $k_j$, $v_j$, $y_i$ are the rows of $Q_h$, $K_h$, $V_h$, $Y_h$ respectively.
So we can calculate attention in a stream-like manner (without storing any intermediate results along the axis $j$): it is enough to keep a running numerator and denominator. To keep this numerically stable, we represent the denominator of a chunk of scores as a pair $(m, d)$ whose value is $d \, e^{m}$:

$$m = \max_j s_{ij}, \qquad d = \sum_j e^{s_{ij} - m}.$$

We can also represent the numerator in that manner, if we select a vector accumulator

$$n = \sum_j e^{s_{ij} - m} v_j,$$

which up to indices resembles the numerator in the formula above.
Suppose we have 2 sets of numbers (two chunks of scores) with representations $(m_1, d_1, n_1)$ and $(m_2, d_2, n_2)$; then the representation of their union is

$$m = \max(m_1, m_2), \qquad d = d_1 e^{m_1 - m} + d_2 e^{m_2 - m}, \qquad n = n_1 e^{m_1 - m} + n_2 e^{m_2 - m}.$$

Using this idea we can calculate numerically stable representations of the numerator and denominator of the softmax fraction on small chunks of data (which fit into fast on-chip shared memory) and then aggregate them into the full formula, without storing intermediate results. Also, we can avoid calculating $e^{m}$ at all: the numerator and the denominator carry the same factor, so it cancels out in the final fraction $y_i = n / d$.
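To illustrate the accumulation rule in isolation, here is a small PyTorch model of the idea (not the actual CUDA kernel; the function name and the single-row formulation are just for the example):

```python
import math
import torch

def attention_row_streamed(q_i, K, V, chunk_size):
    """Compute one output row y_i = softmax(q_i K^T / sqrt(H)) V chunk by chunk,
    keeping only a running (max, denominator, numerator) representation."""
    H = q_i.shape[0]
    m = torch.tensor(float("-inf"))    # running maximum of the scores seen so far
    d = torch.tensor(0.0)              # running denominator, stored scaled by e^{-m}
    n = torch.zeros(H)                 # running numerator vector, stored scaled by e^{-m}
    for start in range(0, K.shape[0], chunk_size):
        k_c, v_c = K[start:start + chunk_size], V[start:start + chunk_size]
        s = (k_c @ q_i) / math.sqrt(H)                 # scores of this chunk
        m_c = s.max()
        e = torch.exp(s - m_c)                         # stable exponents within the chunk
        d_c, n_c = e.sum(), e @ v_c                    # chunk-local denominator and numerator
        m_new = torch.maximum(m, m_c)                  # merge the two stable representations
        d = d * torch.exp(m - m_new) + d_c * torch.exp(m_c - m_new)
        n = n * torch.exp(m - m_new) + n_c * torch.exp(m_c - m_new)
        m = m_new
    return n / d                                       # the e^m factors cancel in the fraction
```

The result matches `torch.softmax((K @ q_i) / math.sqrt(H), dim=0) @ V` up to floating-point error; the kernel applies the same merging rule to chunks of keys and values held in shared memory.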
The implemented function can be used as a torch extension or can be added to a CUDA/C++ program directly, by including the include/fused_attn.cuh file into the project.
The minimal dependencies are PyTorch and a compatible CUDA Toolkit (for C++ projects only CUDA is needed).
To install it as a torch extension use:
$ python setup.py install --user
Then you can use the extension in the following way:
import torch # should import torch first
from fused_attn import attention_forward
head_dim = 128 # head dim should be a power of 2 and between 16 and 128
chunk_size = 128 # chunk size should also be a power of 2, at least 16 and less than 2*head_dim;
                 # the exception is head_dim=128, for which chunk_size=16 is prohibited (due to the implementation).
                 # I believe the best choice is to use chunk_size == head_dim
q, k, v = ... # tensors should be of shape (batch_size, sequence_len, feature_size);
              # sequence_len should be divisible by chunk_size (if not, pad with zeroes), feature_size by head_dim
m = ... # optional mask of shape (batch_size, sequence_len, sequence_len), filled with zeroes where a
        # query-key pair should not be masked and with -inf where it should be
output = attention_forward(head_dim, chunk_size, q, k, v, m)
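For a quick sanity check against plain PyTorch, something like the snippet below can be used (a sketch only: the half-precision dtype, the CUDA device, and passing an all-zero mask are assumptions on top of the API shown above):

```python
import math
import torch
from fused_attn import attention_forward

batch, seq, feat, head_dim, chunk_size = 2, 256, 512, 64, 64
q, k, v = (torch.randn(batch, seq, feat, device="cuda", dtype=torch.half) for _ in range(3))
m = torch.zeros(batch, seq, seq, device="cuda", dtype=torch.half)  # all-zero mask = nothing masked (dtype assumed)

out = attention_forward(head_dim, chunk_size, q, k, v, m)

# naive reference: split features into heads, softmax(Q K^T / sqrt(H)) V, concatenate heads back
heads = feat // head_dim
qh, kh, vh = (t.view(batch, seq, heads, head_dim).transpose(1, 2) for t in (q, k, v))
scores = torch.softmax(qh @ kh.transpose(-2, -1) / math.sqrt(head_dim), dim=-1)
ref = (scores @ vh).transpose(1, 2).reshape(batch, seq, feat)

print((out - ref).float().abs().max())  # expect only a small half-precision discrepancy
```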
There is also a CMake project which just builds a simple test that checks everything is working. You can build and run it with:
$ mkdir build
$ cd build
$ cmake ..
$ make
$ ./test
The implementation was tested against a naive PyTorch implementation and appears to work correctly (the differences of about 1% should be just half-precision loss). You can run the tests with the following command:
$ python -m unittest
The performance of the algorithm was also measured with a simple benchmark on an NVIDIA GeForce RTX 3050 Laptop GPU, on random tensors with the following parameters:
batch_size=4
sequence_len=2048
feature_size=5120
head_dim=128
chunk_size=128
num_heads=40
The results of the benchmark are in the table below:

Algorithm | Time per batch | Relative time | Additional memory [3] |
---|---|---|---|
Naive attention | 83.4 ms | 1.00 | 1280 MiB |
Fused attention | 58.6 ms | 0.70 | 0 MiB |
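The measurement itself boils down to timing both variants on the same random tensors; a rough sketch of how this can be done with CUDA events (the actual benchmark script may differ) is:

```python
import torch

def time_cuda_ms(fn, warmup=3, iters=20):
    """Average execution time of fn() in milliseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# e.g. time_cuda_ms(lambda: attention_forward(head_dim, chunk_size, q, k, v, m))
```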
This project was done j4f (just for fun), over a few weekends. The idea just came to my mind and I was haunted by an obsessive thought to try to implement it and to measure how much performance it would give. Now that my interest has been satisfied for a while, I am not sure if I will continue this little project, but if somebody (including me) is interested in future improvements, here is a checklist of what could also be done:
- implement fused forward layer
- run benchmarks on different GPU architectures, check the performance gain on high-end GPUs
- create a wrapper for this layer (add input/output linear layers) in pytorch
- implement a fused input (linear QKV layers + optional rotary positional encodings for queries and keys) and a fused output (linear output layer + residual connection + normalization) to increase the performance of the whole attention block
- save the scores tensor in the fused layer for the backward pass. It will be slower, but will give us a trainable attention which is still a little bit faster than the naive approach.
- implement a fused backward layer (without storing scores); it would also be slower, because the score matrix has to be recalculated in the backward pass, but it will save a lot of memory during training (so we could increase the batch size, for example).
- implement the same layers for quantized inputs/weights (8-bit or 4-bit quantization) to speed up the whole network and fit big LLMs into small GPUs at inference.
- build a full-sized transformer using this layer and check the performance gain on a real task.
- research other approaches to make NNs smaller and faster, capable of running on a teapot; I know I am not the only one who wants to talk to their teapot...
Footnotes

1. For example, in Facebook's LLaMA-7B, $L$ (the context length) is usually equal to 2048 and $R$ (the number of heads) is 40, and this is the smallest model in the LLaMA family.
2. Typically the head dim $H$ is one of $64$ or $128$.
3. Memory used within the attention calculation to store the score tensor, assuming half precision; the QKV and output tensors are not counted here.