[Feature Request] GEMM benchmarks and FP8 Support
jwfromm opened this issue · 8 comments
I really like the simplicity of TK and think it could be broadly applicable to kernel authoring beyond attention. Has there been any benchmarking done of pure GEMM operations? If so, an example would be very helpful for my understanding of how to use TK in a broader way.
One of Hopper's coolest new features is native support for FP8. However, there are very few kernels that support it outside of cublas. When any sort of fusion or customization is needed, the tooling ecosystem is sorely lacking. Adding FP8 support to TK could be quite useful and pair nicely with GEMM.
ugh, the annoying thing about fp8 is that the transpose instructions don't work for 8-bit types -- I don't know why NVIDIA went only halfway on adding 8-bit hardware support. So FP8 transposes are going to be really slow: they'll rely on unpacking, doing 4 shfl_sync's per 8x16 core matrix, and then repacking. (Unless I'm missing something -- I read the PTX pretty carefully looking for some canonical way NVIDIA would want you to do it. Maybe there's some trick to do a b16 movmatrix first and then just 2 shfl_sync's, but it will be bad either way.) That means writing fast FP8 kernels will involve more hassle (and you will actually need to change your DRAM layouts) to get full performance, unless you happen to be completely bandwidth bound.
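To make that cost concrete, here's a rough sketch of the unpack / shuffle / repack pattern for one 8x16 fp8 core matrix. The register layout assumed below (lane `t` holds row `t/4`, byte columns `4*(t%4)` through `4*(t%4)+3`) is purely illustrative, not ThunderKittens' actual core-matrix layout:

```cuda
// Illustrative only: transpose an 8x16 tile of fp8 bytes held across one warp,
// assuming lane t holds row t/4, byte columns 4*(t%4)..4*(t%4)+3 in `packed`.
// The transposed 16x8 tile comes back with lane t holding row t/2,
// columns 4*(t%2)..4*(t%2)+3.
__device__ unsigned transpose_fp8_8x16(unsigned packed) {
    const int lane = threadIdx.x & 31;
    const int byte_sel = (lane >> 1) & 3;   // which byte we need from each source register
    unsigned out = 0;
    #pragma unroll
    for (int j = 0; j < 4; ++j) {           // 4 shfl_sync's per 8x16 core matrix
        int src_lane = 16 * (lane & 1) + 4 * j + (lane >> 3);
        unsigned v = __shfl_sync(0xffffffffu, packed, src_lane);
        out |= ((v >> (8 * byte_sel)) & 0xffu) << (8 * j);
    }
    return out;                             // repacked bytes of the transposed tile
}
```

That's four shuffles plus the byte extraction/repacking ALU work per tile, versus a single `movmatrix` for b16 types, which is where the overhead comes from.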
also, a question for those interested in fp8: do you want full H100 support (wgmma, tma, etc.), or are you just planning to do edge inference on 4090s and whatnot? If people want this, I at one point wrote a doc of exactly what needs to be done. I don't really want to do it myself, though.
also we'll put some gemms in the examples, good feedback.
I provided a simple GEMM implementation, but a more optimized one requires support for ldmatrix and software pipelining, which I haven't implemented yet.
@benjaminfspector I would be interested for l40s and 4090 support. Can you please share the doc?
@ethxnp https://github.com/HazyResearch/ThunderKittens/blob/fp8/fp8-todo.md
(This is on a very old branch, but I updated the todo just now so that it is correct for the current main/ branch. I would definitely fork off of main, not this branch.)
@benjaminfspector Thanks for the excellent thoughts and tips. At least in my case, I am primarily interested in H100 with all the related features. For what it's worth, we've found that CUTLASS does a decent job of implementing FP8 with WGMMA + TMA. I'm not fluent enough with the details to say how they handled some of the challenges you mentioned, but it could be a useful reference for figuring out NVIDIA's blessed way of running fp8.
That's definitely the right call, to reverse engineer however Cutlass handles it; I'm sure they do something sensible, and frankly that was our only hope of getting the WGMMA swizzling modes working to begin with. I did at one point take a look through their code for this. It looked to me like they weren't handling the transpose within CUTE -- see https://github.com/NVIDIA/cutlass/blob/main/include/cute/arch/mma_sm90.hpp#L401 (it insists on major K layout for both operands) -- so I don't think I'm missing anything on the PTX side. But perhaps they have a more clever way of handling it somewhere up the stack?
Regarding fp8 `mma` for sm89: it's handled similarly to the int8 `mma.m16n8k32`. That is to say, A must be row-major and B must be column-major.
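For reference, the sm89 fp8 MMA shows up in PTX as `mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32` (there is an e5m2 variant as well). A minimal inline-asm wrapper, compiled for sm_89 on a recent CUDA toolkit, looks roughly like this; the fragment sizes follow the usual m16n8k32 pattern (4x b32 for A, 2x b32 for B, 4x f32 accumulators):

```cuda
#include <cstdint>

// Sketch of the sm_89 fp8 MMA (e4m3 inputs, f32 accumulate).
// a: 4 x 32-bit regs, 4 fp8s each (this lane's slice of the 16x32, row-major A tile)
// b: 2 x 32-bit regs, 4 fp8s each (this lane's slice of the 32x8, column-major B tile)
// d/c: 4 x f32, this lane's slice of the 16x8 accumulator tile
__device__ void mma_m16n8k32_e4m3(float d[4], const uint32_t a[4],
                                  const uint32_t b[2], const float c[4]) {
    asm volatile(
        "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]),
          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}
```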
I suspect that the reason for this is that `ldmatrix` is used to load from shared to registers, except now the minimum tile shape for `ldmatrix.x1` is 8x16 and 16x32 for `ldmatrix.x4`, i.e., 8 rows of 16 fp8s and 16 rows of 32 fp8s, respectively. Note that since the type is now less than 16 bits, `ldmatrix.trans` will not work as in the b16 case (the transposed load assumes a b16-sized type), so one now uses `ldmatrix` to load both the A and B matrices, rather than `ldmatrix` for A and `ldmatrix.trans` for B.
This works if B is column-major: e.g., to compute a 16x8 output tile, one would issue an `ldmatrix.x4` to load A (16x32), then 2 `ldmatrix.x1` along the K axis to load B (8x32). (Alternatively, one could do a single `.x4` load for B to compute a 16x16 tile.)
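And here's roughly what those loads look like as inline PTX, assuming unswizzled shared-memory tiles (A stored row-major as 16 rows of 32 bytes, B stored K-major as 8 columns of 32 contiguous bytes). The per-thread address math and the fragment-to-operand ordering below are my assumptions and worth double-checking against the PTX ISA docs / Cutlass before relying on them:

```cuda
#include <cstdint>

// Sketch: load the A (16x32 fp8) and B (32x8 fp8) fragments for one
// m16n8k32 fp8 mma, using plain ldmatrix (no .trans) for both operands.
__device__ void load_fp8_fragments(uint32_t a[4], uint32_t b[2],
                                   const uint8_t* A_smem,   // 16x32, row-major
                                   const uint8_t* B_smem) { // 32x8, K-major (column-major)
    const int lane = threadIdx.x & 31;

    // A: one ldmatrix.x4 covers the whole 16x32 tile (four 8x16-byte sub-tiles).
    // Assumed addressing: thread t points at row (t % 16), byte column 16 * (t / 16).
    const uint8_t* a_ptr = A_smem + (lane % 16) * 32 + (lane / 16) * 16;
    uint32_t a_addr = static_cast<uint32_t>(__cvta_generic_to_shared(a_ptr));
    asm volatile(
        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
        : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
        : "r"(a_addr));

    // B: two ldmatrix.x1 loads, one per 16-byte chunk along the K axis.
    // Only threads 0-7 supply the addresses that matter for an .x1 load;
    // assumed addressing: thread t points at column (t % 8) of B.
    #pragma unroll
    for (int kk = 0; kk < 2; ++kk) {
        const uint8_t* b_ptr = B_smem + (lane % 8) * 32 + kk * 16;
        uint32_t b_addr = static_cast<uint32_t>(__cvta_generic_to_shared(b_ptr));
        asm volatile(
            "ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
            : "=r"(b[kk])
            : "r"(b_addr));
    }
}
```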
You can see the required layouts for `mma.m16n8k32` with sm89 fp8 here and here. There are only 2 template specializations, both enforcing the aforementioned row-major / col-major layouts. Similarly, for the b8 int `mma`, you can see here that all specializations require the same row / col layout.
Not sure how Cutlass handles the layouts for Hopper with the additional complexities of `tma` and `gmma`.
Happy to work on a simple gemm example for fp8 if there's interest.