[QST] Hopper mixed precision gemm always worse than FP8
divchenko opened this issue · 8 comments
I'm doing a mixed-input matmul (4-bit A x fp16 B) w/ large A and small B. I expect it to beat the fp8 matmul, since it should be memory-bound and reads fewer bytes.
In reality, it always seems to be slower.
Example:
Kernel code is here: https://gist.github.com/divchenko/9b02f40ae109e8dc8549afbde059d32e
It's called from Python:
import torch
import cuscratch
g = 64
m = 3584
n = 16
k = 8192
scale_k = (k + g - 1) // g
s = torch.ones((m, scale_k), dtype=torch.half, device="cuda")
a = torch.ones((m, (k + 1) // 2), dtype=torch.int8, device="cuda")
b = torch.ones((n, k), dtype=torch.half, device="cuda")
d = torch.zeros((n, m), dtype=torch.half, device="cuda")
cuscratch.matmul_mixed(a, b.t(), d.t(), s, k, g)
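For readers unfamiliar with the packed layout, here is a minimal pure-Python sketch of how two signed 4-bit values can share one int8 byte, which is why `a` has `(k + 1) // 2` columns. The helper names `pack_int4` and `unpack_int4` are hypothetical, and the nibble order (first element in the low 4 bits) is an assumption for illustration, not something confirmed in this thread.

```python
def pack_int4(values):
    """Pack signed 4-bit ints (-8..7) two per byte.

    Assumption: the first value of each pair goes in the low nibble.
    """
    values = list(values)
    if len(values) % 2:
        values.append(0)  # pad odd-length input with a zero nibble
    packed = bytearray()
    for lo, hi in zip(values[0::2], values[1::2]):
        packed.append((lo & 0xF) | ((hi & 0xF) << 4))
    return bytes(packed)


def unpack_int4(packed):
    """Inverse of pack_int4: sign-extend each nibble back to a Python int."""
    out = []
    for byte in packed:
        for nibble in (byte & 0xF, byte >> 4):
            out.append(nibble - 16 if nibble >= 8 else nibble)
    return out


if __name__ == "__main__":
    print(pack_int4([1, 1]).hex())
    print(unpack_int4(bytes([1])))
```

Under that assumed ordering, filling the int8 tensor with ones gives alternating 4-bit values of 1 and 0 rather than all ones. That does not matter for timing, only for checking numerical results.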
The best perf I can get is with the stream-K tile scheduler (k is large indeed). But it still only reaches ~20% of memory b/w.
The persistent tile scheduler is way worse for both the TMA and TMACooperative kernel schedules.
The fp8 implementation reaches ~60% of memory b/w and hence is faster, although it reads ~2x more bytes.
Am I missing anything? Thank you!
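A rough sketch of the byte counts behind the "~2x more bytes" estimate, using the shapes from the reproducer. It assumes the fp8 variant reads both A and B as 8-bit and writes an fp16 D (the fp8 kernel is not shown in this thread), and it ignores cache reuse. `gemm_bytes` is a hypothetical helper.

```python
def gemm_bytes(m, n, k, a_bits, b_bits, d_bits=16, group=None, scale_bits=16):
    """Bytes moved by one D = A @ B pass, ignoring any cache reuse."""
    a_bytes = m * k * a_bits // 8
    b_bytes = n * k * b_bits // 8
    d_bytes = m * n * d_bits // 8
    scale_bytes = 0
    if group is not None:
        scale_k = (k + group - 1) // group  # same formula as the reproducer
        scale_bytes = m * scale_k * scale_bits // 8
    return a_bytes + b_bytes + d_bytes + scale_bytes


m, n, k, g = 3584, 16, 8192, 64
mixed = gemm_bytes(m, n, k, a_bits=4, b_bits=16, group=g)
fp8 = gemm_bytes(m, n, k, a_bits=8, b_bits=8)  # assumed fp8 A and B, fp16 D

# Time is proportional to bytes / achieved bandwidth fraction.
slowdown = (mixed / 0.20) / (fp8 / 0.60)

if __name__ == "__main__":
    print(f"mixed bytes: {mixed}, fp8 bytes: {fp8}, ratio: {fp8 / mixed:.2f}")
    print(f"mixed-input time relative to fp8: {slowdown:.2f}x")
```

With these assumptions the fp8 kernel moves about 1.85x the bytes, but at 20% versus 60% of peak bandwidth the mixed-input kernel still comes out roughly 1.6x slower, consistent with what is reported above.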
Could you share more info on which exact C++ kernel is being picked in both cases? You may have to pick a custom tile size instead of what the builder provides by default. The default ones are more optimized for compute-bound cases.
@IonThruster full code is here. I've played w/ tiles. This is the best config.
#include <ATen/ATen.h>
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/activation.h>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <pybind11/operators.h>
#include <torch/extension.h>
#include <cute/tensor.hpp>
#include <cutlass/epilogue/collective/collective_builder.hpp>
#include <cutlass/epilogue/collective/default_epilogue.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
#include <sstream>    // std::stringstream used by CUTLASS_CHECK
#include <stdexcept>  // std::runtime_error used by CUTLASS_CHECK
namespace cuscratch {
namespace {
#define CUTLASS_CHECK(status)                                      \
  {                                                                \
    cutlass::Status error = status;                                \
    if (error != cutlass::Status::kSuccess) {                      \
      std::stringstream ss;                                        \
      ss << "Got cutlass error: " << cutlassGetStatusString(error) \
         << " at: " << __LINE__ << std::endl;                      \
      throw std::runtime_error(ss.str());                          \
    }                                                              \
  }
void matmul_mixed(torch::Tensor tensor_a, torch::Tensor tensor_b,
                  torch::Tensor tensor_d, torch::Tensor tensor_scale, int64_t k,
                  int64_t group_size) {
  using MmaType = cutlass::half_t;
  using QuantType = cutlass::int4b_t;
  constexpr int TileShapeK = (128 * 8) / cutlass::sizeof_bits<MmaType>::value;

  // A matrix configuration
  using ElementA = QuantType;
  using LayoutA = cutlass::layout::RowMajor;
  constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

  // B matrix configuration
  using ElementB = MmaType;
  using LayoutB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

  using ElementZero = cutlass::half_t;
  using ElementScale = cutlass::half_t;
  using LayoutScale = cutlass::layout::RowMajor;

  // C/D matrix configuration
  using ElementD = cutlass::half_t;
  using LayoutD = cutlass::layout::ColumnMajor;
  constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
  using ElementC = void;
  using LayoutC = LayoutD;  // Layout type for C and D matrix operands
  constexpr int AlignmentC = AlignmentD;

  // Core kernel configurations
  using ElementAccumulator = float;  // Element type for internal accumulation
  using ElementCompute = float;      // Element type for epilogue computation
  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;
  using TileShape = cute::Shape<cute::_128, cute::_16, cute::Int<TileShapeK>>;
  using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;
  using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, EpilogueTileType, ElementAccumulator,
          ElementAccumulator, ElementC, LayoutC, AlignmentC, ElementD, LayoutD,
          AlignmentD, EpilogueSchedule>::CollectiveOp;

  using StageCount = cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>;

  // The Scale information must get paired with the operand A that will be
  // scaled.
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass, cute::tuple<ElementA, ElementScale>, LayoutA,
          AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator,
          TileShape, ClusterShape, StageCount, KernelSchedule>::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>,  // Indicates ProblemShape
      CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::StreamKScheduler>;
  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  /// Initialization
  typename GemmKernel::StrideA stride_a{};
  cute::get<0>(stride_a) = static_cast<int>(k);
  typename GemmKernel::StrideB stride_b{};
  cute::get<0>(stride_b) = static_cast<int>(tensor_b.stride(1));
  typename Gemm::GemmKernel::StrideC stride_c{};
  typename Gemm::GemmKernel::StrideD stride_d{};
  cute::get<1>(stride_d) = static_cast<int>(tensor_d.stride(1));
  // Scale and Zero share a stride since the layout and shapes must be the same.
  using StrideS = typename CollectiveMainloop::StrideScale;
  StrideS stride_s;

  // Data
  auto data_a = reinterpret_cast<const cutlass::int4b_t *>(tensor_a.data_ptr());
  auto data_b = reinterpret_cast<const cutlass::half_t *>(tensor_b.data_ptr());
  auto data_c = nullptr;
  auto data_d = reinterpret_cast<const cutlass::half_t *>(tensor_d.data_ptr());
  auto data_scale =
      reinterpret_cast<const cutlass::half_t *>(tensor_scale.data_ptr());

  typename Gemm::Arguments args;
  args.mode = cutlass::gemm::GemmUniversalMode::kGemm;
  args.problem_shape = {static_cast<int>(tensor_a.size(0)),
                        static_cast<int>(tensor_b.size(1)), static_cast<int>(k),
                        1};
  args.mainloop = {data_a,
                   stride_a,
                   data_b,
                   stride_b,
                   data_scale,
                   stride_s,
                   static_cast<int>(group_size),
                   nullptr,
                   4 /*mma_promotion_interval*/};
  args.epilogue = {{1, 0} /*alpha, beta*/, data_c, stride_c, data_d, stride_d};

  Gemm gemm;
  auto ws_size = static_cast<int64_t>(Gemm::get_workspace_size(args));
  auto ws_tensor = at::empty({ws_size}, at::TensorOptions()
                                            .dtype(at::ScalarType::Byte)
                                            .device(tensor_d.device())
                                            .requires_grad(false));
  CUTLASS_CHECK(gemm.can_implement(args));
  CUTLASS_CHECK(gemm.initialize(args, ws_tensor.data_ptr()));
  CUTLASS_CHECK(gemm.run());
}
} // namespace
} // namespace cuscratch
PYBIND11_MODULE(cuscratch, m) { m.def("matmul_mixed", cuscratch::matmul_mixed); }
@IonThruster for fp8 version, you can just look at my old post #1139
@divchenko This behavior is expected with the current implementation. I have not done a deep dive into the performance, but I have a theory that may explain the behavior you observe.
If we take a compute-bound case, we typically have an MMA tile of MxNxK = 64x256x32, which means A's tile is 64x32 and B's tile is 256x32. We must convert 64x32 elements of A from INT4 to FP8, but in the compute-bound case we can hide that latency behind loading the large B matrix from smem and behind the big tensor core instruction. This is because the Hopper TCs are asynchronous, so while we do the MMA for stage k, we can be converting the data for stage k+1.
In the memory-bound case, we have much smaller tiles. In your example, it is 64x16x32. That means A's tile is still 64x32, but B's tile is way smaller at 16x32. The amount of A data we must convert is exactly the same as before, but we can no longer hide this latency behind a big tensor core instruction. I think the extra exposed latency is causing the slowdown in the memory-bound case.
My theory is that the conversion cost is exposed in the memory bound case.
DISCLAIMER: I don't have data supporting what I've said above. It could be completely wrong, but it is just a hunch :)
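To put rough numbers on the theory above, the sketch below counts, per MMA tile, how many A elements need conversion and how many tensor core FLOPs are available to overlap with that conversion. It uses only the two tile shapes quoted above; `tile_work` is a hypothetical helper, and FLOPs per converted element is only a crude proxy for how much latency can be hidden, not a measured quantity.

```python
def tile_work(tile_m, tile_n, tile_k):
    """Return (A elements to convert, MMA FLOPs) for one MxNxK tile."""
    converts = tile_m * tile_k                # every A element is upconverted
    mma_flops = 2 * tile_m * tile_n * tile_k  # one multiply + one add per MAC
    return converts, mma_flops


compute_bound = tile_work(64, 256, 32)
memory_bound = tile_work(64, 16, 32)

if __name__ == "__main__":
    for name, (converts, flops) in [("64x256x32", compute_bound),
                                    ("64x16x32", memory_bound)]:
        print(f"{name}: {converts} converts, {flops} FLOPs, "
              f"{flops // converts} FLOPs per converted element")
```

Both tiles convert the same 2048 elements of A, but the small-N tile has 16x fewer MMA FLOPs to overlap with that conversion (32 versus 512 FLOPs per converted element).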
Thanks @rawnhenry. The memory-bound case for fp8 (where I have 64x16x256 tiles) actually works quite well, reaching close to 60% of memory b/w. It's the mixed precision case w/ tile 128x16x64 (the k tile is restricted to be at most 64 == scaling group size) that doesn't work well.
Two options I see:
- If I don't use the stream-k tile scheduler, then my occupancy is the same as in the fp8 case (~50% of grid, but that seems enough to saturate HBM). However, it looks like the latency is not hidden well because of the small-ish K tile (64 as compared to 256), as you mentioned.
- If I use the stream-k tile scheduler, then occupancy is full but, again, the tiles are quite small and the conversions and k tile streaming are likely not hidden well.
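A small sketch comparing the mainloop shape of the two configurations mentioned above (128x16x64 for the mixed-input kernel, 64x16x256 for fp8) at k = 8192. `mainloop_profile` is a hypothetical helper, and it only counts the bytes of A loaded per K iteration, ignoring B and scales.

```python
def mainloop_profile(tile_m, tile_k, k, a_bits):
    """Return (K iterations per output tile, bytes of A loaded per iteration)."""
    k_iters = (k + tile_k - 1) // tile_k
    a_bytes_per_iter = tile_m * tile_k * a_bits // 8
    return k_iters, a_bytes_per_iter


k = 8192
mixed_profile = mainloop_profile(tile_m=128, tile_k=64, k=k, a_bits=4)
fp8_profile = mainloop_profile(tile_m=64, tile_k=256, k=k, a_bits=8)

if __name__ == "__main__":
    print("mixed 128x16x64: %d iterations, %d bytes of A each" % mixed_profile)
    print("fp8   64x16x256: %d iterations, %d bytes of A each" % fp8_profile)
```

The mixed-input kernel runs 4x as many mainloop iterations per output tile, and each one brings in 4x fewer bytes of A, so any fixed per-iteration cost such as conversion or pipeline synchronization is amortized over much less data.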