[RFC] `segment` Matrix Multiplication
rusty1s opened this issue · 2 comments
🚀 The feature, motivation and pitch
In heterogeneous graphs, we often want to apply a linear
transformation across node types, using a different weight matrix for each node type, i.e.:
x = ... # a node feature matrix of shape [num_nodes, in_channels]
weight = ... # a weight tensor of shape [num_node_types, in_channels, out_channels]
ptr = torch.tensor([0, ..., x.size(0)]) # boundaries of node types in x, len(ptr) == num_node_types + 1
out = segment_mm(x, ptr, weight, bias=None)
The computation is performed as follows (ideally fully parallelized):
out[0:10] = x[0:10] @ weight[0]
out[10:15] = x[10:15] @ weight[1]
out[15:30] = x[15:30] @ weight[2]
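For clarity, a minimal, non-fused reference sketch of these semantics could look as follows; `segment_mm_reference` and its bias handling are illustrative assumptions about the proposed interface, not an existing API:
import torch

def segment_mm_reference(x, ptr, weight, bias=None):
    # Reference semantics only: rows ptr[i]:ptr[i+1] of `x` are multiplied
    # with weight[i]; the proposed kernel would do this in a single launch.
    out = torch.empty(x.size(0), weight.size(-1), dtype=x.dtype, device=x.device)
    for i in range(ptr.numel() - 1):
        start, end = int(ptr[i]), int(ptr[i + 1])
        out[start:end] = x[start:end] @ weight[i]
        if bias is not None:  # assumed shape: [num_node_types, out_channels]
            out[start:end] += bias[i]
    return out

# Example matching the segment layout above:
x = torch.randn(30, 16)
weight = torch.randn(3, 16, 32)
ptr = torch.tensor([0, 10, 15, 30])
out = segment_mm_reference(x, ptr, weight)
assert out.size() == (30, 32)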
The underlying implementation should ideally be significantly faster than manually iterating over the node types and applying the matrix multiplications sequentially:
for i, (start, end) in enumerate(zip(ptr[:-1], ptr[1:])):
    out[start:end] = x[start:end] @ weight[i]
and should (hopefully) be similar in performance to a single batched matrix multiplication when all "segments" are equally sized:
x = x.view(num_node_types, num_nodes_per_node_type, in_channels)
x @ weight
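As a quick, purely illustrative sanity check (shapes chosen arbitrarily), the batched formulation and the per-segment loop should agree whenever all segments have the same size:
import torch

num_node_types, num_nodes_per_node_type = 4, 8
in_channels, out_channels = 16, 32

x = torch.randn(num_node_types * num_nodes_per_node_type, in_channels)
weight = torch.randn(num_node_types, in_channels, out_channels)
ptr = torch.arange(0, x.size(0) + 1, num_nodes_per_node_type)

# Sequential per-segment baseline:
expected = torch.cat([x[ptr[i]:ptr[i + 1]] @ weight[i] for i in range(num_node_types)])

# Batched formulation for equally-sized segments:
batched = x.view(num_node_types, num_nodes_per_node_type, in_channels) @ weight

assert torch.allclose(expected, batched.view(-1, out_channels), atol=1e-4)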
cc @pyg-team/nvidia-team
After discussing internally, I think I have the perfect solution for this.
https://github.com/NVIDIA/cutlass/blob/master/examples/24_gemm_grouped/gemm_grouped.cu
It allows exactly the kind of computation proposed here. Quoting the example's description:
This example profiles the performance of a 'grouped' GEMM kernel. This is similar to batched GEMM in that multiple, independent GEMMs are computed by one grid launch. It differs in that each 'group' may compute a unique problem size. Problem sizes and pointers to matrices are both stored in device Global Memory and loaded by the kernel.
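If I read the example correctly, mapping the proposal onto a grouped GEMM mostly requires turning `ptr` into per-group problem sizes, since only the m-dimension differs between groups while n and k are shared. A rough host-side sketch of that bookkeeping (purely illustrative, not CUTLASS' actual host API):
import torch

def grouped_gemm_problem_sizes(ptr, weight):
    # Segment i becomes one GEMM problem of size (m_i, n, k), where
    # m_i = ptr[i + 1] - ptr[i] varies per group and n, k are shared
    # across all groups (n = out_channels, k = in_channels).
    k, n = weight.size(1), weight.size(2)
    return [(int(ptr[i + 1] - ptr[i]), n, k) for i in range(ptr.numel() - 1)]

weight = torch.randn(3, 16, 32)
ptr = torch.tensor([0, 10, 15, 30])
print(grouped_gemm_problem_sizes(ptr, weight))  # [(10, 32, 16), (5, 32, 16), (15, 32, 16)]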