pytorch/ao

high throughput inference

msaroufim opened this issue · 3 comments

Was chatting with @Chillee about our plans in AO today, and he mentioned we should be focusing on a few concrete problems, like:

  1. Demonstrate compelling perf for fp8 gemm at a variety of batch sizes.
  2. Demonstrate compelling perf for weight only int8 gemm at a variety of batch sizes.
  3. Demonstrate compelling perf for weight only intX gemm at low batch sizes.
  4. Demonstrate compelling perf for weight intX, activation fp8 at a variety of batch sizes.
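For context on why batch size matters in the list above, here is a minimal sketch of GEMM arithmetic intensity (FLOPs per byte moved) as a function of batch size and weight width. It assumes a single `(bs x K) @ (K x N)` matmul with no cache reuse. The function name, the 4096 layer size, and the byte widths are illustrative assumptions, not anything from AO or gpt-fast.

```python
def gemm_arithmetic_intensity(m, k, n, act_bytes=2.0, weight_bytes=2.0, out_bytes=2.0):
    """FLOPs per byte for an (m x k) @ (k x n) GEMM, ignoring caches.

    act_bytes / weight_bytes / out_bytes are bytes per element
    (2.0 for bf16, 1.0 for int8 or fp8, 0.5 for int4).
    """
    flops = 2 * m * k * n  # one multiply and one add per (m, k, n) triple
    bytes_moved = m * k * act_bytes + k * n * weight_bytes + m * n * out_bytes
    return flops / bytes_moved


if __name__ == "__main__":
    k = n = 4096  # hypothetical linear layer size, chosen only for illustration
    print(f"{'bs':>6} {'bf16':>10} {'w-int8':>10} {'w-int4':>10}")
    for bs in (1, 4, 16, 64, 256, 1024):
        bf16 = gemm_arithmetic_intensity(bs, k, n)
        w8 = gemm_arithmetic_intensity(bs, k, n, weight_bytes=1.0)
        w4 = gemm_arithmetic_intensity(bs, k, n, weight_bytes=0.5)
        print(f"{bs:>6} {bf16:>10.2f} {w8:>10.2f} {w4:>10.2f}")
```

At small batch sizes the weight term dominates the bytes moved, so shrinking the weights raises intensity almost proportionally. At large batch sizes the activation and output terms take over and weight-only quantization helps much less.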

We could, as a baseline, extend gpt-fast to work with bs=n without doing any KV cache management work and measure perf there. I'm copying the feedback as is, and I'm open to discussing more and adding details as time progresses.

EDIT: gpt-fast already has a batched generation branch by Horace https://github.com/pytorch-labs/gpt-fast/tree/batched_generation

@HDCharles on the int8 work
@vkuzo on fp8
@vayuda and @jerryzh168 on intx

@msaroufim

Would be interesting to bench against something like QoQ, which implements W4A8KV4 (int8 GEMM) using a nested quantization scheme and neat kernel-level optimizations.

> Demonstrate compelling perf for fp8 gemm at a variety of batch sizes.

Note that I'm putting up a PR soon for a quick roofline estimator for float8 gemm plus training-specific overhead, to see for which M, K, N float8 is faster than bfloat16. It would be easily extendable to inference at a later time.
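To make the idea concrete, here is a rough sketch of what a roofline-style comparison could look like. This is not the estimator from the PR mentioned above. The `Device` fields, the function names, the placeholder hardware numbers, and the single `fp8_overhead_s` term (standing in for scaling and casting cost) are all assumptions for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Device:
    peak_flops_bf16: float  # FLOP/s
    peak_flops_fp8: float   # FLOP/s
    mem_bandwidth: float    # bytes/s


def gemm_time_s(m, k, n, peak_flops, mem_bandwidth, in_bytes, out_bytes=2.0, overhead_s=0.0):
    """Roofline time for (m x k) @ (k x n): the slower of compute and memory, plus overhead."""
    flops = 2 * m * k * n
    bytes_moved = (m * k + k * n) * in_bytes + m * n * out_bytes
    return max(flops / peak_flops, bytes_moved / mem_bandwidth) + overhead_s


def fp8_speedup(m, k, n, dev, fp8_overhead_s=0.0):
    """Estimated bf16 time divided by fp8 time; values above 1.0 favor float8."""
    t_bf16 = gemm_time_s(m, k, n, dev.peak_flops_bf16, dev.mem_bandwidth, in_bytes=2.0)
    t_fp8 = gemm_time_s(m, k, n, dev.peak_flops_fp8, dev.mem_bandwidth, in_bytes=1.0,
                        overhead_s=fp8_overhead_s)
    return t_bf16 / t_fp8


if __name__ == "__main__":
    # Placeholder numbers, not measurements of any real GPU.
    dev = Device(peak_flops_bf16=1e15, peak_flops_fp8=2e15, mem_bandwidth=3e12)
    for m in (1, 64, 1024, 8192):
        s = fp8_speedup(m, 8192, 8192, dev, fp8_overhead_s=5e-6)
        print(f"M={m:>5} K=N=8192 estimated speedup: {s:.2f}x")
```

A fixed overhead term is the simplest possible model. A more realistic version would scale the overhead with tensor size, which is what decides where the crossover in M, K, N falls.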

> Demonstrate compelling perf for weight intX, activation fp8 at a variety of batch sizes.

While this is technically possible, I'm not sure I understand the value. I would be interested to learn more.