pytorch/ao

[RFC] Which low bit CUDA kernels should we merge or write?

msaroufim opened this issue · 11 comments

Here is my understanding of the existing state of things and what I think we should be doing to make our lower-bit kernels more performant at both small and larger batch sizes. I'm making this an RFC because I'm curious whether I'm paying attention to the wrong things so if you disagree with any of the below please comment!

First a quick survey of libraries

Survey of existing solutions

Interestingly enough, none of the solutions below distribute their libraries as installable packages; instead they encourage users to copy-paste their code and cite it. It's common to make these libraries header-only to ease integration.

And we thankfully do have the machinery to support CUDA kernels across multiple CUDA versions with minimal headache, thanks to our custom CUDA extension support https://github.com/pytorch/ao/tree/main/torchao/csrc

So it's easy to merge kernels but which ones should we actually merge?

Marlin

This is the kernel of choice in vLLM, arguably the most popular inference engine on the market. It provides fp16 x int4 kernels that work at batch sizes that are still small but larger than what tinygemm and competitors handle, and the kernels don't seem particularly affected by power limits on GPUs, something that has bitten us in the past when running internal performance benchmarks.

There's also a 2:4 sparse variant of the above kernel which we're already working on upstreaming in #621, though I'm not yet sure whether we should merge both kernels or just the sparse one.

Regardless, the IST-DASLab team behind https://github.com/IST-DASLab/marlin consistently does excellent work and is worth following.

tinygemm

tinygemm isn't a full library in core but a single op, torch.ops.aten._weight_int4pack_mm, and it's the fastest thing we've found for int4 weight-only quantization (w4a16) so far. One challenge, though, is that because it's so fast it becomes a hammer and all our performance problems become nails; if we could easily accelerate other dtypes we might not rely on it so much.

CUTLASS

This work leverages the universal GEMM operator in CUTLASS NVIDIA/cutlass#1549 - no bit-packing is needed since CUTLASS provides a cutlass::int4b_t type

There are also some open PRs in CUTLASS for signed and unsigned int4/int8 multiplication with activations in fp16 NVIDIA/cutlass#1413 by @alexsamardzic

Perhaps the main recurring con with CUTLASS is that it's hard to learn, but it's generally one of the best perf targets considering how vertically integrated it is within the NVIDIA stack. And well, maybe it's not hard; maybe it's a bit of a skill issue on my end.

gemlite

This is a more recent project but it offers GEMV acceleration https://mobiusml.github.io/gemlite_blogpost/ by @mobicham

The core idea is well explained in https://github.com/Bruce-Lee-LY/cuda_hgemv#optimization-method, where they walk from a naive implementation to ones that efficiently use shared memory and warp scheduling

GEMV kernels inherently solve a more restricted problem, namely bs=1 inference à la gpt-fast

However, despite being limited to batch size 1, gemlite is quite expressive in that it allows arbitrary weight dtypes. If you look at the function name gemv_A16fWnO16f_int32packing, you can read it as: fp16 activations x n-bit weights stored as 32-bit packed, with mixed fp16 accumulation
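To make the naming concrete, here is a toy sketch of what 32-bit packing of 4-bit weights looks like. The layout below (interleaved, low bits first) is made up for illustration; gemlite's actual memory layout may differ.

```python
import torch

def pack_4bit_to_int32(w: torch.Tensor) -> torch.Tensor:
    # Pack 8 unsigned 4-bit values into each int32 along the last dim.
    # Illustrative layout only: gemlite's real interleaving may differ.
    assert w.shape[-1] % 8 == 0 and int(w.min()) >= 0 and int(w.max()) < 16
    w = w.to(torch.int32)
    packed = torch.zeros(*w.shape[:-1], w.shape[-1] // 8, dtype=torch.int32)
    for i in range(8):
        packed |= w[..., i::8] << (4 * i)
    return packed

def unpack_int32_to_4bit(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of the packing above: extract each 4-bit field and mask it.
    out = torch.empty(*packed.shape[:-1], packed.shape[-1] * 8, dtype=torch.int32)
    for i in range(8):
        out[..., i::8] = (packed >> (4 * i)) & 0xF
    return out
```

A kernel would fuse the unpack into the GEMV load; the point here is just that an (N, K) 4-bit weight matrix travels through memory as an (N, K/8) int32 tensor.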

The code is quite short and confined to very few files, so it's easy to reuse.

BitBLAS

https://github.com/microsoft/BitBLAS

This is the only repo here that ships a pip package, so repackaging it doesn't make as much sense, although we could explore using it as an optional backend in ao for cases where we don't have the right kernel. Its support matrix is probably the most comprehensive of any repo in this list https://github.com/microsoft/BitBLAS#support-matrix

Suggested next steps

Merge the obviously useful kernels

The sort of obvious next steps to match the current state of things are

  1. Merging and packaging Marlin kernels, because we don't support int4 at medium batch sizes and we don't have a good story for fast sparsity
  2. Merging and packaging CUTLASS kernels, because they are very fast GEMM (not purely GEMV) kernels, meaning they will help at larger batch sizes, an area where we don't do super well yet and which has been a recurring ask from outside partners dealing with high-throughput inference

Both of the above let us work with batch sizes larger than 1 and are industry standards where people have been frustrated with the installation experience.

Write the non-obvious kernels

The non-obvious kernels haven't been written yet, so our strategy typically has been to

  1. Cheat by using torch.compile() with clever bitpacking as a baseline
  2. Run end-to-end benchmarks against the best options on the market - though this isn't always possible since many of these kernels don't exist yet
  3. Run speed of light analysis using the new profiler by @jeromeku #690

End-to-end benchmarks are certainly helpful, but since we're talking about kernels we'd also need to run microbenchmarks on various shapes, as @jerryzh168 suggests
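As a concrete starting point for such shape microbenchmarks, here's a sketch using torch.utils.benchmark. The shapes you pass in are up to you; the actual recorded llama2 shapes are not reproduced here.

```python
import torch
import torch.utils.benchmark as benchmark

def bench_linear(M: int, N: int, K: int, dtype=torch.float32) -> float:
    # Time a single linear with input (M, K) and weight (N, K);
    # returns the median runtime in seconds.
    x = torch.randn(M, K, dtype=dtype)
    w = torch.randn(N, K, dtype=dtype)
    timer = benchmark.Timer(
        stmt="torch.nn.functional.linear(x, w)",
        globals={"torch": torch, "x": x, "w": w},
    )
    return timer.blocked_autorange(min_run_time=0.1).median
```

Comparing, say, bench_linear(1, 4096, 4096) against bench_linear(16, 4096, 4096) quickly shows how the performance regime changes with M, which is exactly the dimension where GEMV-only kernels fall over.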

For bs=1, get better performance for dtypes smaller than 4 bits

gemlite is a nice educational library supporting GEMV for a variety of dtypes, so we should leverage it not just for end-to-end performance benchmarks but also for speed-of-light calculations that help us better understand the gaps for bs=1 inference. The idea is to ensure that performance is great for a variety of intX dtypes, as opposed to overfitting to int4 just because we have tinygemm.
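To make "speed of light" concrete: bs=1 decode is memory-bound, so a lower bound on kernel time is simply the bytes of quantized weights divided by memory bandwidth. A minimal sketch, where the default bandwidth is the published A100 80GB SXM HBM spec (substitute your GPU's number), and activations, scales, and zero-points are ignored:

```python
def gemv_speed_of_light_us(N: int, K: int, weight_bits: float,
                           bandwidth_gb_s: float = 2039.0) -> float:
    # Memory-bound lower bound for a bs=1 GEMV: the time it takes to
    # stream the quantized (N, K) weight matrix through HBM once.
    weight_bytes = N * K * weight_bits / 8
    return weight_bytes / (bandwidth_gb_s * 1e9) * 1e6  # microseconds
```

For a 4096 x 4096 layer this gives roughly 4.1 us at int4 and half that at int2; the gap between a measured kernel and this number is the remaining optimization headroom, and it shows directly why sub-4-bit dtypes should keep paying off at bs=1.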

@vayuda has already led some early work here on bitpacking with torch.compile, so we need to start baselining more heavily

For bs=n inference, start writing new kernels since they don't exist

For H100+

The biggest theme here is that instead of relying on fp16 as the activation dtype we can rely on fp8

Some of this work was already mentioned here #663 but we'll add more detail

  1. Demonstrate compelling perf for fp8 GEMM at a variety of batch sizes, work already started by @drisspg and @jainapurva
  2. Demonstrate compelling perf for weight intX, activation fp8 at a variety of batch sizes. In particular, this ask came from the team at Neural Magic directly

For A100

For A100 our options are a bit more obvious: we should show compelling dynamic quantization (quantizing the activations to int8) performance at larger batch sizes. gpt-fast has already been extended to support larger batch sizes https://github.com/pytorch-labs/gpt-fast/tree/batched_generation

For this work we'd focus on int8 dynamic quantization and then work our way down from there.
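For reference, the semantics of int8 dynamic quantization can be sketched in a few lines of plain PyTorch. This is a numerical reference for what the fused kernel should compute, not torchao's actual implementation; a real kernel would accumulate the int8 x int8 products in int32.

```python
import torch

def dynamic_quant_int8(x: torch.Tensor):
    # Per-row symmetric quantization with scales computed from the
    # activations at runtime -- that is what makes it "dynamic".
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return q, scale

def int8_dynamic_linear(x, w_int8, w_scale):
    # Reference math for a w8a8 dynamic linear: integer matmul, then
    # rescale by the per-row activation scale and per-channel weight scale.
    # fp32 is used for the accumulation here purely for clarity.
    xq, x_scale = dynamic_quant_int8(x)
    acc = xq.to(torch.float32) @ w_int8.to(torch.float32).t()
    return acc * x_scale * w_scale.t()
```

Here w_int8 is the (N, K) int8 weight and w_scale its (N, 1) per-channel scale; the output approximates x @ w.t() up to quantization error.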

Related work

  • LUT-GEMM https://arxiv.org/abs/2206.09557 uses lookup tables for the weights instead of having to dequantize them
  • Atom, low-bit quantization for efficient and accurate LLM serving https://arxiv.org/abs/2310.19102 - in particular they implemented their kernels for W8A8 and W4A16, but since we already have tinygemm this is not super relevant
  • FlashInfer (everyone else is citing this work) https://github.com/flashinfer-ai/flashinfer - a kernel library for inference, with bs=1 and bs=n kernels for prefill, decode, and append across different KV cache formats including paged, ragged, and page table, plus compressed and quantized KV caches
  • DeepGEMM: https://arxiv.org/abs/2304.09049 - 2-bit matrix multiplication is represented as a lookup table, and 4 of these values are packed into an 8-bit vector register. Their benchmarks are on x86 against QNNPACK - this is not CUDA-specific
  • https://github.com/google/gemmlowp - an older project that no longer seems maintained, primarily meant to accelerate x86 and ARM
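To illustrate the lookup-table idea behind LUT-GEMM and DeepGEMM, here's a minimal 2-bit LUT dequant sketch. The packing layout (code i of byte j in bits 2i..2i+1, mapping to output position j*4 + i) is made up for illustration and is not either paper's actual layout.

```python
import torch

def lut_dequant_2bit(packed: torch.Tensor, lut: torch.Tensor) -> torch.Tensor:
    # Decode four 2-bit codes per uint8 through a 4-entry lookup table,
    # so the "dequantize" step is just an indexed load, never an fma.
    cols = []
    for i in range(4):
        codes = ((packed >> (2 * i)) & 0b11).long()
        cols.append(lut[codes])
    return torch.stack(cols, dim=-1).reshape(*packed.shape[:-1], -1)
```

The appeal for a kernel is that the 4-entry table fits in registers or shared memory, so weights stream through at 2 bits each and never need a multiply to reconstruct.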

Thanks for the survey summary @msaroufim! This is very helpful for understanding what kernels we might be interested to integrate.

I think one assumption here is that single-kernel performance at a given linear shape (M, N, K) is a proxy for e2e performance in a model. I'm still unclear whether this is true. For example, we know that llama2 int8wo and int4wo give a speedup over bfloat16: https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks. Today I printed the shapes recorded by autoquant and ran a microbenchmark on the single linears with the corresponding (M, N, K) sizes in the llama2 model: #695. But it looks like all the shapes for all the quant methods are slower than bfloat16. If our microbenchmarking is done correctly, this means single-kernel perf on a given shape may not be a good proxy for model-level performance, and we may need to optimize for specific models instead of just kernels.

But I do want to see more data comparing microbenchmarking vs. e2e model-level benchmarking, to understand whether microbenchmark kernel perf results can be a good proxy for model-level perf results. cc @HDCharles wondering if you have data around this point from developing autoquant.

If the assumption is not true, then it means as part of deciding what kernels we want to merge, we'd also need to say what are the models and even runtime environment (execution engines) we are optimizing for.

My feeling is that it's very much a "do both" kind of exercise: microbenchmarks can more quickly and reliably find bad kernels and IMHO should be the main criterion for merging work.

With e2e benchmarks the tricky part is that we might not know whether a kernel is useful on another model we haven't tried out yet. But e2e benchmarks would be the main criterion for deciding whether to blog about something.

  • Last time I checked, Marlin only supports symmetric quantization, while torchao xdtype implements asymmetric quantization (zero-point), so that's actually an issue; it would need zero-point support added. The code is very difficult to follow imo.
  • For medium and larger batch sizes, why not just dequantize() and call torch.matmul, with the dequantize() step implemented in CUDA or via torch.compile? I think jerome already implemented the CUDA dequant step compatible with tinygemm. We can use a trick of double-quantizing the weights so that when they are dequantized() we get int8, not float16; then torch.matmul runs with int8 weights instead of fp16, which is faster.
  • There are some Triton kernels for asymmetric quantization for 4-bit, 3-bit and 2-bit available here. Last time I tried them they were pretty slow. But they could be maybe re-used to implement faster versions.
  • Batched Gemlite with larger batches (>=16) would need re-implementation with wmma to efficiently use tensorcores
  • I read that there are performance issues with CUTLASS when used to implement asymmetric quantization; here's a gptq example but apparently it's pretty slow.
    But they have made many updates since 3.2, so maybe some of the performance issues were resolved. I still couldn't find a good asymmetric quantization example with CUTLASS.
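For reference, the dequantize-then-matmul fallback from the second bullet can be sketched in plain PyTorch with group-wise asymmetric dequant. This is a sketch of the idea, not the CUDA dequant kernel mentioned above; wrapping it in torch.compile may fuse the dequant into the matmul read.

```python
import torch

def dequant_matmul(x, w_q, scale, zero, group_size=128):
    # Group-wise asymmetric dequant followed by a plain matmul.
    # w_q:   (N, K) integer codes (e.g. 0..15 for 4-bit)
    # scale: (N, K // group_size) per-group scales
    # zero:  (N, K // group_size) per-group zero-points
    N, K = w_q.shape
    wq = w_q.reshape(N, K // group_size, group_size).to(x.dtype)
    w = (wq - zero.unsqueeze(-1)) * scale.unsqueeze(-1)
    return x @ w.reshape(N, K).t()
```

It won't match a fused kernel at small M, but at medium and large batch sizes the matmul dominates and the one-time dequant cost amortizes, which is the bullet's point.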

There are also some open PRs in CUTLASS for signed and unsigned int4/int8 multiplication with activations in fp16 NVIDIA/cutlass#1413 by @alexsamardzic

@msaroufim #1413 is S8 x S4 and S4 x S8 support. #1190 is for [B]F16 x S4 and S4 x [B]F16 on Ampere. For Hopper, CUTLASS has mixed-input support, see the README.md here.


Does this work with bf16/fp16 accumulation as well or just fp8?


@mobicham torchao also supports symmetric quant, it should also be easy to support no zero_point use case as well by adding a new layout type for affine quantized tensor I think

@jerryzh168 the quality tends to be worse with symmetric quantization compared to asymmetric. Much of the quality in linear quantization actually comes from the zero-point, not the scaling factor. I actually reported this issue here IST-DASLab/marlin#5 (comment)
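A quick toy experiment illustrates why the zero-point matters on skewed distributions: per-tensor quantization for simplicity here, whereas real schemes quantize per group, so take the exact numbers with a grain of salt.

```python
import torch

def quant_error(w: torch.Tensor, bits: int = 4, asymmetric: bool = True) -> float:
    # Mean abs reconstruction error of per-tensor linear quantization.
    if asymmetric:
        qmax = 2 ** bits - 1                  # codes 0..15 for 4-bit
        scale = (w.max() - w.min()) / qmax
        zero = torch.round(-w.min() / scale)
        q = torch.clamp(torch.round(w / scale) + zero, 0, qmax)
        deq = (q - zero) * scale
    else:
        qmax = 2 ** (bits - 1) - 1            # codes -7..7 for 4-bit
        scale = w.abs().max() / qmax
        q = torch.clamp(torch.round(w / scale), -qmax, qmax)
        deq = q * scale
    return (w - deq).abs().mean().item()

# Skewed, non-zero-mean values: the zero-point lets the scale cover only
# the observed range instead of a symmetric range around zero.
torch.manual_seed(0)
w = torch.randn(4096) * 0.1 + 0.3
```

On a distribution like this, quant_error(w, asymmetric=True) comes out noticeably lower than quant_error(w, asymmetric=False), matching the observation above.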

Hi all. I highly recommend the GEMM in TurboMind, which implements AWQ, GPTQ, and W8A16 (INT8, FP8), and is currently the fastest open-source implementation. At small batch sizes it is several times faster than cuBLAS. It's also faster than the Marlin kernels used in vLLM.

https://github.com/InternLM/lmdeploy/tree/main/src/turbomind/kernels/gemm

f16*u4g128 vs cublasGemmEx f16*f16, both using HMMA + f32 accumulators, on 32 weight matrices from 8 models ranging from 7B to 72B


  • sm90 features are not used yet
  • sm70 tensor core and fp16 have a shared pipeline

cc @lzhangzz @merrymercy

We recently made plans to extract this part into a separate library for easier integration with other projects. I'm also very much looking forward to the performance once it's integrated into SGLang. If you're interested, we can discuss further in depth. Cheers!

@zhyncs is the gemm in turbomind something you'd be interested in contributing to ao? One nice thing about marlin as an example is they have most of their code in a single file and they encourage people to copy paste and package up their kernels


Ok


I'll discuss the integrated technical solution with @lzhangzz, aiming to finish the integration asap. We're really excited about it! cc @lvhan028