
Grouped GEMM for MoE

A PyTorch Toolbox for Grouped GEMM in MoE Model Training

Steps for Using

pip install

pip install --verbose git+https://github.com/fanshiqing/grouped_gemm@main

Build from Source

git submodule update --init --recursive
mkdir build
cd build
cmake ..
make -j
cd ..

# GroupedGEMM ops test
python grouped_gemm/ops_test.py

# topK permute & unpermute ops test
python grouped_gemm/permute_test.py

# sinkhorn kernel test
python grouped_gemm/sinkhorn_test.py

Support Matrix

permute & unpermute

GPU Arch    FP32    FP16    BF16    FP8
SM 70       Y       Y       .       Y
SM 75       Y       Y       .       Y
SM 80       Y       Y       Y       Y
SM 86       Y       Y       Y       Y
SM 89       Y       Y       Y       Y
SM 90       Y       Y       Y       Y

Ops Usage

permute

grouped_gemm.ops.permute(
  input_act: torch.Tensor,
  indices: torch.Tensor,
  num_out_tokens: int = 0,
  max_token_num: int = 0) -> tuple

The output is a tuple of two torch.Tensors, permuted_act and row_id_map.

  • permuted_act is the permutation of the original tensor input_act with its first dimension permuted according to indices.
  • row_id_map is the mapping table for the row indices of the input activations before and after grouped_gemm.ops.permute, which is used for the following unpermute op.

Parameters

  • input_act (torch.Tensor)
     shape = [tokens_num, hidden_size]
     The input activations; each row (token) is routed to topK experts.

  • indices (torch.Tensor)
     shape = [tokens_num, topK_num]
     The topK expert indices for each row (token) of activations. The int32 type is recommended.

  • num_out_tokens (int)
     The number of output tokens (rows); used for the token-drop feature.

  • max_token_num (int)
     The maximum number of tokens (rows) used for workspace pre-allocation.

unpermute

grouped_gemm.ops.unpermute(
  input_act: torch.Tensor,
  row_id_map: torch.Tensor,
  probs: torch.Tensor) -> torch.Tensor

The mirror operator of grouped_gemm.ops.permute.

Parameters

  • input_act (torch.Tensor)
     shape = [tokens_num * topK_num, hidden_size]
     The permuted activations produced by grouped_gemm.ops.permute.

  • row_id_map (torch.Tensor)
     shape = [tokens_num * topK_num]
     The mapping table for the row indices of the activations before and after grouped_gemm.ops.permute. The second output tensor of grouped_gemm.ops.permute.

  • probs (torch.Tensor)
     shape = [tokens_num, topK_num]
     The weights used to sum the copies of each original token coming back from its different experts.
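
Correspondingly, here is a minimal pure-PyTorch sketch of the unpermute semantics, under the same assumptions as the permute sketch above (unpermute_reference is a hypothetical name; the k-major row_id_map layout is inferred from the example below): gather each token's topK expert outputs via row_id_map and reduce them with probs as weights.

import torch

def unpermute_reference(permuted_act, row_id_map, probs):
    # Hypothetical reference (assumption): gather the topK permuted rows of each
    # original token and sum them, weighted by probs.
    tokens_num, topK = probs.shape
    gathered = permuted_act[row_id_map.long()].reshape(topK, tokens_num, -1)
    return (gathered * probs.T.unsqueeze(-1)).sum(dim=0)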

Example

import torch
from grouped_gemm import permute, unpermute

# 4 tokens of hidden_size 4; each token is routed to its top-2 experts (expert ids 0-2)
indices = torch.tensor([[1, 2], [0, 1], [0, 2], [1, 2]], dtype=torch.int32, device='cuda')
input_act = torch.tensor([[0,0,0,0], [1,1,1,1], [2,2,2,2], [3,3,3,3]], dtype=torch.float32, device='cuda')
# Equal weights, so unpermute simply sums each token's expert copies
probs = torch.ones_like(indices, dtype=torch.float32)
permuted_inputs, row_id_map = permute(input_act, indices)
unpermute_outputs = unpermute(permuted_inputs, row_id_map, probs)

print(row_id_map)
print(input_act)
print(permuted_inputs)
print(unpermute_outputs)

# Output
# tensor([2, 0, 1, 4, 5, 3, 6, 7], device='cuda:0', dtype=torch.int32)
# tensor([[0., 0., 0., 0.],
#         [1., 1., 1., 1.],
#         [2., 2., 2., 2.],
#         [3., 3., 3., 3.]], device='cuda:0')
# tensor([[1., 1., 1., 1.],
#         [2., 2., 2., 2.],
#         [0., 0., 0., 0.],
#         [1., 1., 1., 1.],
#         [3., 3., 3., 3.],
#         [0., 0., 0., 0.],
#         [2., 2., 2., 2.],
#         [3., 3., 3., 3.]], device='cuda:0')
# tensor([[0., 0., 0., 0.],
#         [2., 2., 2., 2.],
#         [4., 4., 4., 4.],
#         [6., 6., 6., 6.]], device='cuda:0')
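
Since probs is all ones, every output row is just the sum of that token's two expert copies, which is why the unpermuted result is exactly twice the original input_act in this example.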