[QST] Skinny fp8 gemm on H100
Closed this issue · 13 comments
I'm trying to get the best perf for a 'skinny' gemm on fp8 matrices on H100. My gemm MxNxK is 16x14336x4096, 32x14336x4096, or 64x14336x4096. This should obviously be heavily memory-bound.
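The memory-bound claim can be sanity-checked with a quick arithmetic-intensity estimate. This is a sketch: the H100 SXM peak numbers (~1979 TFLOP/s dense fp8, ~3.35 TB/s HBM3) and the assumption of 1-byte fp8 inputs with a 4-byte f32 output are my assumptions, not from the profiler:

```python
# Arithmetic intensity (FLOPs/byte) for each skinny GEMM shape,
# assuming fp8 (1-byte) A and B operands and an f32 (4-byte) output D.
N, K = 14336, 4096
# Assumed H100 SXM peaks; the ratio is the roofline ridge point.
ridge = 1979e12 / 3.35e12  # ~590 FLOPs/byte

for M in (16, 32, 64):
    flops = 2 * M * N * K
    bytes_moved = M * K * 1 + N * K * 1 + M * N * 4  # A + B + D
    intensity = flops / bytes_moved
    bound = "memory" if intensity < ridge else "compute"
    print(f"M={M}: {intensity:.1f} FLOPs/byte ({bound}-bound)")
```

Even at M=64 the intensity (~120 FLOPs/byte) sits well below the assumed ridge point, so all three shapes land on the bandwidth side of the roofline.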
I'm running cutlass profiler (I vary --m=16/32/64):
cmake .. -DCUTLASS_NVCC_ARCHS=90a -DCUTLASS_LIBRARY_KERNELS=cutlass3x_sm90_tensorop_s64x128x32gemm_e4m3_ -DCUTLASS_LIBRARY_IGNORE_KERNELS=bf16,f16
./tools/profiler/cutlass_profiler --output=/tmp/cutlass.out --operation=Gemm --m=16 --n=14336 --k=4096
but the best memory b/w I get is ~2.2TB/sec, which nsight compute rates at ~68% of peak memory b/w.
I've also tried writing a kernel by hand following the fp8 example and playing with different tiles to match https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shape , but I'm always peaking at ~70% of memory b/w.
Is this the best number I can get? Can I get closer to 80% or more?
For the input shape of 64x14336x4096, you should not be bandwidth bound. M = 64 is a size that is natively supported by the Hopper tensor core, and your K dimension is large enough for the problem shape to be compute bound on the roofline. For any problem shape with M smaller than 64, you are unfortunately going to waste a lot of FLOPs, since the smallest shape supported by the Hopper fp8 tensor core is 64 along the M dimension.
Can you share which kernel nets you the best performance? How many of the candidate fp8 kernels in the profiler are you trying? You should be able to get pretty solid performance for at least the M=64 problem shape.
These seem to be the best:
cutlass3x_sm90_tensorop_s64x128x32gemm_e4m3_e4m3_f32_f32_e5m2_64x128x128_1x2x1_0_tnn_align16_warpspecialized_pingpong_fp8_fastaccum_epi_nosmem
cutlass3x_sm90_tensorop_s64x128x32gemm_e4m3_e4m3_f32_f32_e5m2_64x128x128_1x2x1_0_tnn_align16_warpspecialized_pingpong_fp8_fastaccum_epi_tma
Here's the full list: https://gist.github.com/divchenko/a32d50e5771190291a55efec9aaa814d
Regarding flops, I see quite a low number even at m=64: 286 TFLOP/sec, which leaves ~7x room to grow. The nsight compute profiler corroborates this finding.
It's fine to waste compute as long as I can get better memory b/w.
Can you try a larger cluster size of 1x4x1 for the two kernel recipes that are performing best? And can you also try running the same problem with cuBLAS and see what perf you get there?
A tiny bit better:
2241.57 GiB/s
2238.46 GiB/s
hmm... seems like a margin-of-error delta to me. Is this for the larger cluster size or the cuBLAS version? And which number corresponds to which kernel?
Larger cluster size. I will try cublas later today.
First and second kernels respectively. But yeah, it seems like margin of error.
Tried cublas; it's a bit worse, ~65% of memory b/w. Here's the kernel it chooses: sm90_xmma_gemm_e4m3e4m3_e4m3f32_f32_tn_n_tilesize128x64x128_warpgroupsize2x1x1_algo2_execute_segment_k_off_kernel__5x_cublas
It's basically based on their example https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLASLt/LtFp8Matmul/sample_cublasLt_LtFp8Matmul.cu
Firstly, I believe 64x14336x4096 is expected to be memory bound, since the flops/byte ratio is too low. Also, based on the comment above, I just tried running this specific test on the latest CUTLASS 3.2.1 + CTK 12.2.2:
./cutlass_profiler --kernels=cutlass3x_sm90_tensorop_s64x128x32gemm_e4m3_e5m2_f32_f32_e4m3_64x128x128_1x2x1_0_tnn_align16_warpspecialized_pingpong_fp8_fastaccum_epi_tma --m=64 --n=14336 --k=4096 --profiling-iterations=10000
And I see :
Bytes: 62652416 bytes
FLOPs: 7518027776 flops
FLOPs/Byte: 119
Runtime: 0.0231348 ms
Memory: 2539.16 GiB/s
Math: 324966 GFLOP/s
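The reported byte count can be reconstructed from the problem shape. My assumption (which reproduces the profiler's number exactly) is that it counts fp8 A (64x4096) + fp8 B (14336x4096) + an f32 output (64x14336); the math throughput likewise follows from the reported FLOPs and runtime:

```python
# Cross-check the profiler's reported numbers for m=64, n=14336, k=4096.
M, N, K = 64, 14336, 4096

# Assumption: fp8 (1-byte) A and B plus an f32 (4-byte) output.
bytes_moved = M * K + N * K + M * N * 4
print(bytes_moved)              # matches the reported "Bytes: 62652416"

runtime_s = 0.0231348e-3        # reported runtime, ms -> s
flops = 7518027776              # reported FLOPs (slightly above 2*M*N*K)
print(flops / runtime_s / 1e9)  # ~324966 GFLOP/s, matching "Math"
```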
Which is a bit better than what you see. Some things to try:
- Upgrade to the latest, i.e. CTK 12.2.2
- Try altering this line to :
CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_256B;
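For context, `tma_l2Promotion` is the L2-promotion hint that is eventually passed to the CUDA driver's `cuTensorMapEncodeTiled` call when the TMA descriptor is built. The fragment below is a sketch of where that argument lands; everything except the enum values (`tensor_map`, `rank`, the dim/stride/box arrays) is a placeholder variable I've introduced, not runnable setup code:

```cpp
#include <cuda.h>

// Fragment only: shows how the L2 promotion hint reaches the driver API.
// tensor_map, rank, global_address, global_dim, global_strides, box_dim,
// and elem_strides are hypothetical placeholders for the real descriptor setup.
CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_256B;

CUresult res = cuTensorMapEncodeTiled(
    &tensor_map,                    // CUtensorMap* to fill in
    CU_TENSOR_MAP_DATA_TYPE_UINT8,  // fp8 elements travel as 8-bit data
    rank, global_address,
    global_dim, global_strides,
    box_dim, elem_strides,
    CU_TENSOR_MAP_INTERLEAVE_NONE,
    CU_TENSOR_MAP_SWIZZLE_128B,
    tma_l2Promotion,                // <- the setting being changed
    CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
```

Promoting TMA loads to 256B L2 sectors can help streaming access patterns like the large B operand here.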
@divchenko is your issue resolved?
resolved