high throughput inference
msaroufim opened this issue · 3 comments
Was chatting with @Chillee about our plans in AO today, and he mentioned we should be focusing on a few concrete problems, such as:
- Demonstrate compelling perf for fp8 gemm at a variety of batch sizes.
- Demonstrate compelling perf for weight only int8 gemm at a variety of batch sizes.
- Demonstrate compelling perf for weight only intX gemm at low batch sizes.
- Demonstrate compelling perf for weight intX, activation fp8 at a variety of batch sizes.
As a baseline, we could extend gpt-fast to work with bs=n without doing any KV cache management work, and measure perf there. Copying the feedback as is; open to discussing more and adding details as time progresses.
EDIT: gpt-fast already has a batched generation branch by Horace https://github.com/pytorch-labs/gpt-fast/tree/batched_generation
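For reference, the baseline measurement could be sketched as a simple sweep over batch sizes. This is only a rough sketch, not code from gpt-fast: `sweep_batch_sizes`, `generate`, and `fake_generate` are hypothetical names, and `generate(bs)` is assumed to run one generation call at batch size `bs` and return the total number of tokens it produced.

```python
import time
from typing import Callable, Dict, Iterable


def sweep_batch_sizes(
    generate: Callable[[int], int],
    batch_sizes: Iterable[int],
    warmup: int = 1,
    iters: int = 3,
) -> Dict[int, float]:
    """Return measured tokens/second for each batch size.

    `generate` is a hypothetical callable (an assumption, not a gpt-fast API).
    For GPU work it should synchronize before returning, otherwise the timer
    only measures kernel launch time.
    """
    results: Dict[int, float] = {}
    for bs in batch_sizes:
        # Warm-up calls are excluded from timing (compilation, caches, etc.).
        for _ in range(warmup):
            generate(bs)
        tokens = 0
        start = time.perf_counter()
        for _ in range(iters):
            tokens += generate(bs)
        elapsed = time.perf_counter() - start
        results[bs] = tokens / elapsed
    return results


def fake_generate(bs: int, new_tokens: int = 16) -> int:
    # Stand-in workload so the sketch runs anywhere; replace with a real model call.
    time.sleep(0.001)
    return bs * new_tokens


if __name__ == "__main__":
    for bs, tps in sweep_batch_sizes(fake_generate, [1, 2, 4, 8]).items():
        print(f"bs={bs}: {tps:.1f} tok/s")
```

Swapping `fake_generate` for a real batched generation call would give a tokens/second curve per quantization scheme, which is the comparison the list above is asking for.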
@HDCharles on the int8 work
@vkuzo on fp8
@vayuda and @jerryzh168 on intx
Would be interesting to bench against something like QoQ, which implements W4A8KV4 (int8 GEMM) using a nested quantization scheme and neat kernel-level optimizations.
> Demonstrate compelling perf for fp8 gemm at a variety of batch sizes.
Note that I'm putting up a PR soon for a quick roofline estimator for float8 gemm plus the overhead specific to training, to see for which M, K, N float8 is faster than bfloat16. It would be easily extendable to inference at a later time.
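To make the idea concrete, here is a minimal roofline sketch. It is not the estimator from the PR mentioned above, just an illustration under stated assumptions: the `Device` numbers are placeholder peak figures roughly in line with H100 SXM datasheet values, the weight is assumed to be already stored in float8, the output stays in bfloat16, and the only overhead modeled is casting the activation (read bf16, write fp8). All function and class names are hypothetical.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Device:
    # Placeholder peak numbers (assumptions, not measurements).
    bf16_flops: float = 989e12   # dense bf16 FLOP/s
    fp8_flops: float = 1979e12   # dense fp8 FLOP/s
    mem_bw: float = 3.35e12      # memory bandwidth in bytes/s


def gemm_time(M, K, N, in_bytes, out_bytes, peak_flops, mem_bw):
    """Roofline time for an (M, K) @ (K, N) gemm: max of compute and memory time."""
    flops = 2 * M * K * N
    bytes_moved = in_bytes * (M * K + K * N) + out_bytes * M * N
    return max(flops / peak_flops, bytes_moved / mem_bw)


def cast_overhead(M, K, mem_bw):
    """Assumed overhead: read the bf16 activation (2 bytes), write it as fp8 (1 byte)."""
    return 3 * M * K / mem_bw


def fp8_speedup(M, K, N, dev=Device()):
    """Estimated bf16 time divided by fp8 time; values above 1 favor float8."""
    t_bf16 = gemm_time(M, K, N, 2, 2, dev.bf16_flops, dev.mem_bw)
    t_fp8 = gemm_time(M, K, N, 1, 2, dev.fp8_flops, dev.mem_bw)
    return t_bf16 / (t_fp8 + cast_overhead(M, K, dev.mem_bw))


if __name__ == "__main__":
    for shape in [(1, 8192, 8192), (8192, 8192, 8192), (8192, 64, 64)]:
        print(shape, round(fp8_speedup(*shape), 2))
```

Under these assumptions, large square shapes favor float8, while a tall and skinny shape such as M=8192, K=N=64 comes out slower than bfloat16 because the cast overhead dominates. Real kernels will deviate from the roofline, so this only indicates which shapes are worth benchmarking.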
> Demonstrate compelling perf for weight intX, activation fp8 at a variety of batch sizes.
While this is technically possible, I'm not sure I understand the value. Would be interested to learn more.