Triton implementation of GPT/LLAMA models. The objective of this project is to understand how much performance can be squeezed out by implementing the full GPT block in a single Triton kernel.
Performance
The Triton implementation is faster and more memory efficient than the HuggingFace Transformers implementation.
```bash
python3 bench.py
```
Latency
precision | HuggingFace GPT | Triton GPT |
---|---|---|
fp32 | 1800 ms | - |
tf32 | 631.35 ms | 462.63 ms |
mixed precision (fp16) | 510.80 ms | 273 ms |
fp16 | 301.92 ms | - |
Time taken to process a batch of shape 512 × 300 (batch size × sequence length) on a single A100 40 GB.
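The latency numbers above come from `bench.py`. Below is a minimal sketch of how such a forward-pass latency measurement can be taken (CUDA events with warm-up iterations); the helper name and batch shape are illustrative, not the actual benchmark script:

```python
import torch

def measure_latency_ms(model, batch, n_warmup=5, n_iters=10):
    """Average forward-pass latency in milliseconds (hypothetical helper)."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    with torch.no_grad():
        for _ in range(n_warmup):      # warm up kernels / autotuning caches
            model(batch)
        torch.cuda.synchronize()
        start.record()
        for _ in range(n_iters):
            model(batch)
        end.record()
        torch.cuda.synchronize()       # wait for all queued kernels to finish
    return start.elapsed_time(end) / n_iters

# e.g. a 512 x 300 batch of token ids, as in the table above:
# batch = torch.randint(0, 50257, (512, 300), device="cuda")
# print(measure_latency_ms(model, batch))
```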
Max Batch Size
implementation | max batch size |
---|---|
HuggingFace GPT | 1024 |
Triton GPT | 2048 |
Only batch sizes that are powers of 2 were considered. Both runs used seqlen=300 with mixed precision enabled.
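These numbers can be found by doubling the batch size until the forward pass runs out of memory. A hedged sketch of that search (the helper and its arguments are illustrative, not the script used for the table above):

```python
import torch

def max_power_of_2_batch(model, seqlen=300, vocab_size=50257):
    """Double the batch size until CUDA runs out of memory and return the
    largest size that fit. Illustrative only."""
    batch_size, largest_ok = 1, 0
    while True:
        try:
            tokens = torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda")
            with torch.no_grad(), torch.autocast("cuda", dtype=torch.float16):
                model(tokens)
            largest_ok = batch_size
            batch_size *= 2
        except torch.cuda.OutOfMemoryError:
            return largest_ok
        finally:
            torch.cuda.empty_cache()
```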
MFU
```python
from gpt import compute_mfu

# forward-pass MFU

# HuggingFace GPT (fp16)
compute_mfu(2 * 124 * 10**6 * 512*512 / 0.302, gpu="h100")
# 21.76%

# HuggingFace GPT (mixed precision)
compute_mfu(2 * 124 * 10**6 * 512*512 / 0.510, gpu="h100")
# 12.88%

# Triton GPT (mixed precision)
compute_mfu(2 * 124 * 10**6 * 512*512 / 0.273, gpu="h100")
# 24.07%
```
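MFU here is achieved FLOP/s divided by the GPU's peak FLOP/s. A sketch that reproduces the percentages above, assuming a dense FP16/BF16 tensor-core peak of ~989 TFLOP/s for the H100 (the actual `gpt.compute_mfu` may be implemented differently):

```python
# Hypothetical re-implementation for illustration only.
PEAK_FLOPS = {
    "h100": 989e12,  # H100 SXM, dense FP16/BF16 tensor-core peak
    "a100": 312e12,  # A100, dense FP16/BF16 tensor-core peak
}

def compute_mfu(achieved_flops_per_sec: float, gpu: str = "h100") -> float:
    """Model FLOPs utilization (%): achieved FLOP/s relative to the GPU's peak."""
    return 100 * achieved_flops_per_sec / PEAK_FLOPS[gpu]

# forward-pass FLOPs are approximated as 2 * n_params * n_tokens, so e.g.
# compute_mfu(2 * 124e6 * 512 * 512 / 0.273, gpu="h100")  # -> ~24%
```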
Supported Features
- fused implementation of several components of the GPT block, e.g. `dropout(wte(x) + wpe(x))`, `dropout(wx + b)`, `gelu(wx + b)` (see the bias + GELU sketch after this list)
- flash attention v1 algorithm
- GPT-2 implementation in Triton
- support for loading pre-trained weights of HuggingFace GPT-2
- KV cache & sampling support for the inference loop
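As an illustration of the kind of fusion listed above, here is a minimal Triton sketch of a fused bias-add + GELU (tanh approximation) applied row-wise. It only covers the elementwise part of `gelu(wx + b)`; the repository's kernels may fuse more (e.g. the matmul itself), and all names here are illustrative rather than the repo's actual API:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def bias_gelu_kernel(x_ptr, b_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    # one program instance handles one row of a contiguous (n_rows, n_cols) tensor
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    x = tl.load(x_ptr + row * n_cols + cols, mask=mask, other=0.0).to(tl.float32)
    b = tl.load(b_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    z = x + b
    # GPT-2's tanh approximation of GELU; tanh(u) is written as 2*sigmoid(2u) - 1
    u = 0.7978845608028654 * (z + 0.044715 * z * z * z)
    gelu = 0.5 * z * (1.0 + (2.0 * tl.sigmoid(2.0 * u) - 1.0))
    tl.store(out_ptr + row * n_cols + cols, gelu.to(out_ptr.dtype.element_ty), mask=mask)

def bias_gelu(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Fused gelu(x + bias) over the last dimension (illustrative helper)."""
    x = x.contiguous()
    out = torch.empty_like(x)
    n_cols = x.shape[-1]
    n_rows = x.numel() // n_cols
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    bias_gelu_kernel[(n_rows,)](x, bias, out, n_cols, BLOCK_SIZE=BLOCK_SIZE)
    return out
```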
Roadmap
- implement back-propagation of the GPT block in Triton (i.e. working out the backward-pass math)
- implement paged attention from the vLLM project in Triton
- implement flash attention v2 & v3
- add kernels for LLAMA-3.1
- implement AdamW in Triton (with FSDP stage-2 support)
Installation
```bash
pip3 install -r requirements.txt

# `numpy<2` is a hard requirement for running on CPU,
# otherwise Triton produces garbage results - likely a bug in Triton
```
Running tests
```bash
# run the tests on CPU (Triton interpreter mode)
TRITON_INTERPRET=1 pytest -sv test.py

# run the tests on GPU
pytest -sv test.py
```
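A typical test in this style checks a fused Triton kernel against a PyTorch reference. A hypothetical example (the import and kernel name refer to the `bias_gelu` sketch above, not necessarily to what `test.py` actually contains):

```python
import pytest
import torch

from kernels_sketch import bias_gelu  # hypothetical module: the fused kernel sketched earlier

@pytest.mark.parametrize("shape", [(4, 300, 768), (2, 128, 1024)])
def test_bias_gelu_matches_torch(shape):
    x = torch.randn(*shape, device="cuda")
    b = torch.randn(shape[-1], device="cuda")
    ref = torch.nn.functional.gelu(x + b, approximate="tanh")
    torch.testing.assert_close(bias_gelu(x, b), ref, rtol=1e-3, atol=1e-3)
```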