HazyResearch/ThunderKittens

[Feature Request] GEMM benchmarks and FP8 Support

jwfromm opened this issue · 8 comments

I really like the simplicity of TK and think it could be broadly applicable to kernel authoring beyond attention. Has there been any benchmarking done of pure GEMM operations? If so, an example would be very helpful for understanding how to use TK more broadly.

One of Hopper's coolest new features is native support for FP8. However, very few kernels support it outside of cuBLAS, and when any sort of fusion or customization is needed, the tooling ecosystem is sorely lacking. Adding FP8 support to TK could be quite useful and would pair nicely with GEMM.

Ugh, the annoying thing about fp8 is that the transpose instructions don't work for 8-bit types -- IDK why NVIDIA went only halfway on adding 8-bit hardware support. So FP8 transposes are going to be really slow: they rely on unpacking, doing 4 shfl_sync's per 8x16 core matrix, and then repacking. (Unless I'm missing something, but I read the PTX pretty carefully looking for some canonical way NVIDIA would want you to do it. Also, maybe there's some trick to do a b16 movmatrix first and then just do 2 shfl_sync's? But it will be bad either way.) So writing fast FP8 kernels is going to involve more hassle (and you will actually need to change your DRAM layouts) to get full performance, unless you happen to be completely bandwidth-bound.
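For anyone wondering what the unpack/shuffle/repack path looks like, here is a minimal sketch of a naive in-register transpose of an 8x8 byte tile. This is not ThunderKittens code and not the 4-shuffle scheme mentioned above (this naive version costs 16 shuffles per tile); the function name, the two-register packing convention (lowest byte = element 0), and the 8-lane grouping are all assumptions made for the illustration.

```cuda
#include <cstdint>

// Naive warp-level transpose of an 8x8 tile of bytes (e.g. fp8 values).
// Each group of 8 lanes owns one tile; lane r holds row r packed into two
// 32-bit registers (lo = bytes 0-3, hi = bytes 4-7). After the call, lane r
// holds column r of the original tile. Purely illustrative: a tuned version
// would keep values packed and use far fewer shuffles.
__device__ void transpose_8x8_bytes(uint32_t &lo, uint32_t &hi) {
    const int      lane  = threadIdx.x & 31;
    const int      r     = lane & 7;                 // row index within the 8-lane group
    const unsigned group = lane >> 3;
    const unsigned mask  = 0xFFu << (group * 8);     // lanes participating in this group

    uint32_t out_lo = 0, out_hi = 0;
    #pragma unroll
    for (int c = 0; c < 8; ++c) {
        // Fetch row c's packed words from lane c of this group (width=8 keeps
        // the shuffle inside the group), then pick out byte r of that row.
        uint32_t src_lo = __shfl_sync(mask, lo, c, 8);
        uint32_t src_hi = __shfl_sync(mask, hi, c, 8);
        uint32_t src    = (r < 4) ? src_lo : src_hi;
        uint32_t byte   = (src >> ((r & 3) * 8)) & 0xFFu;
        if (c < 4) out_lo |= byte << (c * 8);
        else       out_hi |= byte << ((c - 4) * 8);
    }
    lo = out_lo;
    hi = out_hi;
}
```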

Also, a question for those interested in fp8: do you want full H100 support (WGMMA, TMA, etc.), or are you just planning to do edge inference on 4090s and whatnot? If people want this, I at one point wrote a doc of exactly what needs to be done. I don't really want to do it myself, though.

Also, we'll put some GEMMs in the examples; good feedback.

I provided a simple GEMM implementation, but a more optimized one requires support for ldmatrix and pipelining, which I haven't implemented yet.
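For context, the kind of baseline being discussed looks roughly like the shared-memory tiled GEMM below. This is a generic CUDA sketch, not the contributed implementation and not ThunderKittens' API; it deliberately omits the ldmatrix loads, tensor-core MMAs, and software pipelining that a tuned version needs.

```cuda
#include <cuda_runtime.h>

// Baseline tiled GEMM: C = A * B, row-major, FP32 accumulate.
// TILE x TILE thread block; each thread computes one element of C.
// A tuned kernel would add register tiling, ldmatrix/tensor-core MMAs,
// and a multi-stage pipeline overlapping copies with compute.
constexpr int TILE = 32;

__global__ void sgemm_naive_tiled(const float* A, const float* B, float* C,
                                  int M, int N, int K) {
    __shared__ float As[TILE][TILE];
    __shared__ float Bs[TILE][TILE];

    int row = blockIdx.y * TILE + threadIdx.y;   // row of C this thread owns
    int col = blockIdx.x * TILE + threadIdx.x;   // col of C this thread owns
    float acc = 0.f;

    for (int k0 = 0; k0 < K; k0 += TILE) {
        // Stage one TILE x TILE block of A and B into shared memory.
        As[threadIdx.y][threadIdx.x] =
            (row < M && k0 + threadIdx.x < K) ? A[row * K + k0 + threadIdx.x] : 0.f;
        Bs[threadIdx.y][threadIdx.x] =
            (k0 + threadIdx.y < K && col < N) ? B[(k0 + threadIdx.y) * N + col] : 0.f;
        __syncthreads();

        #pragma unroll
        for (int kk = 0; kk < TILE; ++kk)
            acc += As[threadIdx.y][kk] * Bs[kk][threadIdx.x];
        __syncthreads();
    }
    if (row < M && col < N) C[row * N + col] = acc;
}
```

Launched as `sgemm_naive_tiled<<<dim3((N+TILE-1)/TILE, (M+TILE-1)/TILE), dim3(TILE, TILE)>>>(A, B, C, M, N, K)`, this gets nowhere near peak throughput, which is exactly why the ldmatrix and pipelining support matters.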

@benjaminfspector I would be interested in L40S and 4090 support. Can you please share the doc?

@ethxnp https://github.com/HazyResearch/ThunderKittens/blob/fp8/fp8-todo.md

(This is on a very old branch, but I updated the todo just now so that it is correct for the current main branch. I would definitely fork off of main, not this branch.)

@benjaminfspector Thanks for the excellent thoughts and tips. At least in my case I am primarily interested in H100 with all the related features. For what it's worth, we've found that CUTLASS does a decent job of implementing FP8 with WGMMA + TMA. I'm not fluent enough with the details to say how they handled some of the challenges you mentioned, but it could be a useful reference for figuring out NVIDIA's blessed way of running fp8.

That's definitely the right call, to reverse-engineer however CUTLASS handles it; I'm sure they do something sensible, and frankly that was our only hope of getting the WGMMA swizzling modes working to begin with. I did at one point take a look through their code for this. It looked to me like they weren't handling the transpose within CUTE -- see https://github.com/NVIDIA/cutlass/blob/main/include/cute/arch/mma_sm90.hpp#L401 (it insists on a major-K layout for both operands) -- so I don't think I'm missing anything on the PTX side. But perhaps they have a more clever way of handling it somewhere up the stack?

@jwfromm @benjaminfspector

Regarding fp8 mma for sm89, it's handled similarly to the int8 mma.m16n8k32. That is to say, A must be row-major and B must be column-major.

I suspect the reason for this is that ldmatrix is used to load from shared memory to registers, except that now the minimum tile shape is 8x16 for ldmatrix.x1 and 16x32 for ldmatrix.x4, i.e., 8 rows of 16 fp8s and 16 rows of 32 fp8s, respectively.

Note that since the type is now less than 16 bits, ldmatrix.trans does not work as it does in the b16 case (the transposed load assumes a b16-sized element), so one uses plain ldmatrix to load both A and B rather than ldmatrix for A and ldmatrix.trans for B.

This works if B is column-major: e.g., to compute a 16x8 output tile, one would issue an ldmatrix.x4 to load A (16x32), then 2 ldmatrix.x1 loads along the K axis to load B (8x32). (Alternatively, one could do a single x4 load for B to compute a 16x16 output tile.)

You can see the required layouts for the sm89 fp8 mma.m16n8k32 here and here. There are only 2 template specializations, both enforcing the aforementioned row-major / col-major layouts. Similarly, for the b8 int mma, you can see here that all specializations require the same row / col layout.
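To make the load pattern above concrete, here is a hedged sketch of one per-warp inner step on sm_89: plain ldmatrix (no .trans) for both operands, followed by the e4m3 m16n8k32 mma with f32 accumulation. The function name and the address parameters are made up for the illustration, the shared-memory addressing and swizzling are assumed to be done elsewhere, and the exact fragment/byte ordering should be checked against the PTX docs and the CUTLASS sm89 specializations before trusting it.

```cuda
#include <cstdint>

// One m16n8k32 fp8 (e4m3) MMA step on sm_89, f32 accumulate.
// a_row_addr, b_row_addr0, b_row_addr1 are 32-bit shared-memory addresses
// (e.g. from __cvta_generic_to_shared), one per lane, pointing at the
// 16-byte row chunks ldmatrix expects; computing them (and any swizzle)
// is the part this sketch deliberately leaves out.
__device__ void fp8_mma_m16n8k32(uint32_t a_row_addr,
                                 uint32_t b_row_addr0, uint32_t b_row_addr1,
                                 float d[4]) {
#if __CUDA_ARCH__ >= 890
    uint32_t a[4], b[2];

    // A: 16x32 fp8 tile = four 8x16-byte core matrices -> one ldmatrix.x4, no transpose.
    asm volatile(
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
        : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
        : "r"(a_row_addr));

    // B: 8x32 fp8 tile (column-major), loaded as two 8x16-byte core matrices
    // along K -> two ldmatrix.x1 loads, again without .trans.
    asm volatile(
        "ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
        : "=r"(b[0]) : "r"(b_row_addr0));
    asm volatile(
        "ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
        : "=r"(b[1]) : "r"(b_row_addr1));

    // D = A * B + D, with the row.col (A row-major, B column-major) layout
    // that the fp8 mma requires.
    asm volatile(
        "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%0,%1,%2,%3};\n"
        : "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]));
#endif
}
```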

Not sure how CUTLASS handles the layouts for Hopper, with the additional complexities of TMA and GMMA.

Happy to work on a simple GEMM example for fp8 if there's interest.