gpt-triton

Triton implementation of GPT/LLAMA


This repository implements GPT/LLAMA models in Triton. The objective of the project is to understand how much performance can be squeezed out by implementing the full GPT block in a single Triton kernel.

Performance

The Triton implementation is faster and more memory-efficient than the HuggingFace Transformers implementation.

python3 bench.py

Latency

precision                 HuggingFace GPT   Triton GPT
fp32                      1800 ms           -
tf32                      631.35 ms         462.63 ms
mixed precision (fp16)    510.80 ms         273 ms
fp16                      301.92 ms         -

Time taken to process a batch of shape 512x300 (batch size x seqlen) on 1x A100 40 GB.
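
For reference, a latency number like the ones above can be obtained with a simple timing loop around the forward pass. Below is a minimal, hypothetical sketch (the repo's bench.py may differ in details such as warmup, autocast, and model setup):

# hypothetical timing helper; `model` is any callable taking a (batch, seqlen)
# tensor of token ids - the actual bench.py may differ
import time
import torch

def measure_latency_ms(model, batch_size=512, seqlen=300, vocab_size=50257, warmup=3, iters=10):
    x = torch.randint(0, vocab_size, (batch_size, seqlen), device="cuda")
    with torch.no_grad():
        for _ in range(warmup):              # warm up kernels / autotuner
            model(x)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            model(x)
        torch.cuda.synchronize()             # wait for all queued GPU work
    return (time.perf_counter() - start) / iters * 1000  # ms per forward pass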

Max Batch Size

                  max batch size
HuggingFace GPT   1024
Triton GPT        2048

Only power-of-2 batch sizes were considered. Both runs used seqlen=300 with mixed precision enabled.

MFU

from gpt import compute_mfu
# fwd MFU: achieved FLOP/s (approx. 2 * n_params * n_tokens / latency in seconds) relative to the GPU's peak FLOP/s

# HuggingFace GPT (fp16)
compute_mfu(2 * 124 * 10**6 * 512*512 / 0.302, gpu="h100")
# 21.76%

# HuggingFace GPT (mixed precision)
compute_mfu(2 * 124 * 10**6 * 512*512 / 0.510, gpu="h100")
# 12.88%

# triton (mixed precision)
compute_mfu(2 * 124 * 10**6 * 512*512 / 0.273, gpu="h100")
# 24.07%
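
MFU here is model-FLOPs utilization: achieved FLOP/s divided by the GPU's peak FLOP/s. The helper lives in gpt.py; a minimal sketch of what such a helper could look like, assuming dense fp16/bf16 tensor-core peak numbers (the actual implementation may differ):

# hypothetical sketch of an MFU helper; the real gpt.compute_mfu may differ
PEAK_FLOPS = {
    "a100": 312e12,   # A100 fp16/bf16 dense tensor-core peak, FLOP/s
    "h100": 989e12,   # H100 SXM fp16/bf16 dense tensor-core peak, FLOP/s
}

def compute_mfu(achieved_flops_per_sec, gpu="a100"):
    # MFU (%) = achieved FLOP/s / peak FLOP/s of the GPU
    return 100 * achieved_flops_per_sec / PEAK_FLOPS[gpu.lower()]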

Supported Features

  • fused implementations of several components of the GPT block (e.g. dropout(wte(x) + wpe(x)), dropout(wx + b), gelu(wx + b)); see the sketch after this list
  • flash attention v1 algorithm
  • GPT2 implementation in triton
  • support for loading pre-trained weights of huggingface-gpt2
  • support KV cache & sampling for inference loop
  • implement back-propagation of the GPT block in triton (i.e. solving the math problem)
  • implement paged-attention from the vLLM project in triton
  • implement flash attention v2 & v3
  • add kernels for LLAMA-3.1
  • implement adamw in triton (with FSDP-stage2 support)
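
To illustrate the kind of fusion listed above, here is a minimal, hypothetical Triton kernel that fuses the bias add with GeLU (i.e. gelu(wx + b), given the matmul output wx). It is a sketch for illustration only, not the repo's actual kernel:

import torch
import triton
import triton.language as tl

@triton.jit
def bias_gelu_kernel(x_ptr, b_ptr, out_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    # one program per row: load a row of the matmul output, add bias, apply GeLU
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    x = tl.load(x_ptr + row * n_cols + cols, mask=mask, other=0.0).to(tl.float32)
    b = tl.load(b_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    z = x + b
    # tanh approximation of GeLU, written via sigmoid: 0.5 * (1 + tanh(a)) == sigmoid(2a)
    a = 0.7978845608028654 * (z + 0.044715 * z * z * z)
    out = z * tl.sigmoid(2.0 * a)
    tl.store(out_ptr + row * n_cols + cols, out.to(out_ptr.dtype.element_ty), mask=mask)

def bias_gelu(x, b):
    # x: (n_rows, n_cols) contiguous matmul output, b: (n_cols,) bias
    out = torch.empty_like(x)
    n_rows, n_cols = x.shape
    bias_gelu_kernel[(n_rows,)](x, b, out, n_cols, BLOCK_SIZE=triton.next_power_of_2(n_cols))
    return out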

Installation

pip3 install -r requirements.txt
# `numpy<2` is a hard requirement for running on CPU (interpreter mode);
# with numpy 2.x, triton produces garbage results - likely a bug in triton

Running tests

# run the tests on CPU (Triton interpreter mode)
TRITON_INTERPRET=1 pytest -sv test.py

# run the tests on GPU
pytest -sv test.py
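
Tests of this kind typically compare the Triton kernels against a PyTorch/HuggingFace reference. A hypothetical example of such a check (the actual tests in test.py may differ), assuming a wrapper like the bias_gelu sketched earlier:

# hypothetical test; assumes a Triton wrapper like the `bias_gelu` sketch above
import torch
import torch.nn.functional as F

def test_bias_gelu():
    device = "cuda" if torch.cuda.is_available() else "cpu"   # cpu works with TRITON_INTERPRET=1
    x = torch.randn(8, 768, device=device)
    b = torch.randn(768, device=device)
    ref = F.gelu(x + b, approximate="tanh")    # PyTorch reference
    out = bias_gelu(x, b)                      # Triton kernel under test
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)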