pyg-team/pyg-lib

[RFC] `segment` Matrix-Multiplication

rusty1s opened this issue · 2 comments

🚀 The feature, motivation and pitch

In heterogeneous graphs, we often want to apply a linear transformation across node types using a different weight matrix per type, i.e.:

x = ...  # a node feature matrix of shape [num_nodes, in_channels]
weight = ...  # a weight tensor of shape [num_node_types, in_channels, out_channels]
ptr = torch.tensor([0, ..., x.size(0)])  # boundaries of node types in x, len(ptr) == num_node_types + 1

out = segment_mm(x, ptr, weight, bias=None)

The computation is performed as follows (ideally fully parallelized):

out[0:10] = x[0:10] @ weight[0]
out[10:15] = x[10:15] @ weight[1]
out[15:30] = x[15:30] @ weight[2]

The underlying implementation should be significantly faster than manually iterating over each node type and applying the matrix multiplications sequentially:

out = x.new_empty(x.size(0), weight.size(-1))
for i, (start, end) in enumerate(zip(ptr[:-1], ptr[1:])):
    out[start:end] = x[start:end] @ weight[i]
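Wrapped up as a self-contained function (including the optional bias from the proposed signature), this sequential baseline might look like the sketch below; segment_mm_reference is an illustrative name, not an existing pyg-lib API:

import torch

def segment_mm_reference(x, ptr, weight, bias=None):
    # x: [num_nodes, in_channels], weight: [num_node_types, in_channels, out_channels]
    # ptr: [num_node_types + 1] boundaries with ptr[0] == 0 and ptr[-1] == num_nodes
    out = x.new_empty(x.size(0), weight.size(-1))
    for i in range(weight.size(0)):
        start, end = int(ptr[i]), int(ptr[i + 1])
        out[start:end] = x[start:end] @ weight[i]
        if bias is not None:  # bias: [num_node_types, out_channels]
            out[start:end] += bias[i]
    return out

x = torch.randn(30, 16)
weight = torch.randn(3, 16, 32)
ptr = torch.tensor([0, 10, 15, 30])
out = segment_mm_reference(x, ptr, weight)  # reproduces the three products above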

And it should (hopefully) be similar in performance to the batched case of equally-sized "segments":

x = x.view(num_node_types, num_nodes_per_node_type, in_channels)
x @ weight
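For equally-sized segments, the batched path above and the segmented semantics should agree exactly; a quick sanity check against the segment_mm_reference sketch from before (all shapes chosen for illustration):

num_node_types, num_nodes_per_node_type = 3, 10
in_channels, out_channels = 16, 32
x = torch.randn(num_node_types * num_nodes_per_node_type, in_channels)
weight = torch.randn(num_node_types, in_channels, out_channels)
ptr = torch.arange(num_node_types + 1) * num_nodes_per_node_type  # [0, 10, 20, 30]

batched = x.view(num_node_types, num_nodes_per_node_type, in_channels) @ weight
segmented = segment_mm_reference(x, ptr, weight)
assert torch.allclose(batched.reshape(-1, out_channels), segmented, atol=1e-5)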

cc @pyg-team/nvidia-team

After discussing internally, I think I have the perfect solution for this 🙂
https://github.com/NVIDIA/cutlass/blob/master/examples/24_gemm_grouped/gemm_grouped.cu
It allows exactly this: multiple, independent GEMMs with different problem sizes computed in a single kernel launch. From the example's description:

This example profiles the performance of a 'grouped' GEMM kernel. This is similar to batched GEMM in that multiple, independent GEMMs are computed by one grid launch. It differs in that each 'group' may compute a unique problem size. Problem sizes and pointers to matrices are both stored in device Global Memory and loaded by the kernel.
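To map the proposal onto that kernel, the host side would derive one GEMM problem per segment from ptr, with M varying per node type while N and K stay fixed. A rough sketch of that bookkeeping (illustrative only; the actual CUTLASS device-side setup is omitted):

def grouped_gemm_problem_sizes(ptr, in_channels, out_channels):
    # One (M, N, K) problem per node type: M is the segment length,
    # while N = out_channels and K = in_channels are shared across groups.
    boundaries = ptr.tolist()
    return [(end - start, out_channels, in_channels)
            for start, end in zip(boundaries[:-1], boundaries[1:])]

# ptr = [0, 10, 15, 30] yields [(10, 32, 16), (5, 32, 16), (15, 32, 16)]:
# three GEMMs of different heights, executed in a single grouped kernel launch.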