- demo for GEMM+AllReduce
- predictive search for wave grouping
- multi-node example
- demo for GEMM+ReduceScatter, GEMM+AlltoAll
- more platforms (e.g., hopper GPU)
- end2end example
The main dependency is NCCL, which FlashOverlap uses for communication. It can be downloaded from the official NVIDIA website. The code has been tested with v2.18.3 and v2.19.3.
Another dependency is CUTLASS, which is included as a submodule. The code has been tested with v3.6.0 and v3.9.0 but fails with v3.4.0; we assume CUTLASS >= v3.6.0 works.
The code currently supports only sm_80, sm_86, and sm_89, and the evaluation environments include NVIDIA RTX 3090, RTX 4090, A800, and A100 GPUs. The tested CUDA Toolkit versions are 12.1 and 12.2.
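To check whether a machine falls within these architectures, you can query its compute capability through PyTorch. The sketch below assumes PyTorch is installed. The helper names (is_supported_arch, check_local_gpu) are hypothetical and are not part of FlashOverlap.

```python
# Sketch only: helper names are hypothetical, not part of FlashOverlap.
# Architectures listed as supported: sm_80, sm_86, sm_89.
SUPPORTED_ARCHS = {(8, 0), (8, 6), (8, 9)}


def is_supported_arch(major: int, minor: int) -> bool:
    """Return True if compute capability major.minor is one of the supported sm_XY targets."""
    return (major, minor) in SUPPORTED_ARCHS


def check_local_gpu(device: int = 0) -> bool:
    """Print the compute capability of `device` and report whether it is supported."""
    import torch  # imported lazily so is_supported_arch() works without PyTorch

    if not torch.cuda.is_available():
        print("no CUDA device visible")
        return False
    major, minor = torch.cuda.get_device_capability(device)
    ok = is_supported_arch(major, minor)
    print(f"{torch.cuda.get_device_name(device)}: sm_{major}{minor}, supported={ok}")
    return ok
```

On an A100 (sm_80), check_local_gpu() should report the device as supported. A Hopper GPU (sm_90) is reported as unsupported, which matches the "more platforms" item in the TODO list.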
First, pull the repo:
$ git clone https://github.com/infinigence/FlashOverlap.git
$ cd FlashOverlap
$ git submodule update --init --recursive
Install PyTorch and other required packages through pip or conda:
$ pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
$ pip install numpy==2.1.2 pandas==2.2.3 setuptools==75.8.0
Before compiling, generate the GEMM instances:
$ mkdir ./configs
$ cd ./tool
$ python generate_instances.py
This repo uses cmake (>=3.18) for compiling. Run it from the repo root:
$ cd ..
$ cmake -B build
$ cmake --build build -j
The operators are registered as torch.classes, so the compiled .so must be loaded in any Python code that uses them:
torch.ops.load_library("../build/lib/libst_pybinding.so")
Note that the current implementation requires M and N to be multiples of 128 (M, N % 128 == 0).
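The shape requirement is easy to miss, so a caller could validate M and N before invoking the operators. It could also load the library with an explicit existence check. This is a sketch: check_shape and load_flashoverlap_ops are hypothetical helpers. The default path assumes the script runs from a first-level subdirectory such as test/ or example/, as in the load_library line above.

```python
import os

# Assumption: alignment taken from the "M, N % 128 == 0" requirement above.
ALIGN = 128


def check_shape(m: int, n: int) -> None:
    """Raise ValueError unless M and N are positive multiples of ALIGN."""
    bad = [name for name, v in (("M", m), ("N", n)) if v <= 0 or v % ALIGN != 0]
    if bad:
        raise ValueError(f"{', '.join(bad)} must be a positive multiple of {ALIGN}")


def load_flashoverlap_ops(lib_path: str = "../build/lib/libst_pybinding.so") -> None:
    """Load the compiled bindings, failing early with a clear message if the build is missing."""
    import torch  # imported lazily so check_shape() can be used without PyTorch

    if not os.path.isfile(lib_path):
        raise FileNotFoundError(f"{lib_path} not found; build the repo with cmake first")
    torch.ops.load_library(lib_path)
```

For example, check_shape(4096, 8192) passes, while check_shape(4096, 100) raises ValueError.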
.
├── cmake
│ └── Modules
│ └── FindNCCL.cmake
├── configs // To store GEMM and overlapping configs
├── example
│ └── correctness.py // Check correctness of GEMM+AllReduce+RMSNorm
├── src
│ ├── 3rdparty
│ ├── gemm // CUTLASS GEMM Wrappers
│ │ ├── gemm.cu
│ │ └── gemm.h
│ ├── inc // Instantiate templated GEMMs
│ ├── overlap // Source files for signal+reorder
│ ├── rmsnorm // Source files for reorder+RMSNorm
│ ├── tiling // Tiling definition
│ ├── baseline_impl.cu // Baseline implementation class
│ ├── baseline_impl.h
│ ├── CMakeLists.txt
│ ├── nccl_utils.cu // NCCL id generation function
│ ├── nccl_utils.h
│ ├── overlap_impl.cu // Overlap implementation class
│ ├── overlap_impl.h
│ ├── pybind.cpp
│ └── wait.cuh // Signal kernel
├── test
│ └── test.py
├── tool
│ └── generate_instances.py // Generate templated GEMMs
├── tune
│ ├── bandwidth.py // Bandwidth test for predictive search
│ ├── gen_config.py // Generate GEMM configs based on CUTLASS profiler
│ ├── profile_config.py // Customized profiler
│ └── search.py // Exhaustive search and predictive search
└── CMakeLists.txt
The repo currently supports two ways to generate GEMM configs tuned for performance. This step needs only one GPU.
- Make sure the ./configs dir has been created.
$ cd tune
- Using the CUTLASS Profiler. Follow the CUTLASS Profiler README and write the profiling results to $CSV_PATH/*.csv. Then, generate the .json file in configs:
$ python gen_config.py --m $M --n $N --k $K --path $CSV_PATH
- Using the customized profiler for a specific shape. Profiling finishes within minutes. (This method has not yet been evaluated on RTX 4090 or RTX 3090; results will be updated soon.)
$ python profile_config.py --m $M --n $N --k $K
Tune the wave group size. Note that this program needs multiple GPUs and that the environment variable CUDA_VISIBLE_DEVICES must be set, because we use the spawn method (torch.multiprocessing.spawn) and the rank and world size are determined explicitly.
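The spawn-based launch described above can be sketched as follows. This illustrates the general pattern and is not taken from search.py. The world size is derived from CUDA_VISIBLE_DEVICES, and torch.multiprocessing.spawn passes each process its index as the first argument. visible_world_size, worker, and launch are hypothetical names.

```python
import os
from typing import Mapping, Optional


def visible_world_size(env: Optional[Mapping[str, str]] = None) -> int:
    """Count the GPUs listed in CUDA_VISIBLE_DEVICES, e.g. "0,1" -> 2."""
    env = os.environ if env is None else env
    value = env.get("CUDA_VISIBLE_DEVICES", "")
    devices = [d for d in value.split(",") if d.strip()]
    if not devices:
        raise RuntimeError("set CUDA_VISIBLE_DEVICES, e.g. CUDA_VISIBLE_DEVICES=0,1")
    return len(devices)


def worker(rank: int, world_size: int) -> None:
    """Per-process entry point; spawn passes the process index as `rank`."""
    import torch

    # Device indices are relative to CUDA_VISIBLE_DEVICES, so rank r uses the r-th listed GPU.
    torch.cuda.set_device(rank)
    print(f"rank {rank}/{world_size} on {torch.cuda.get_device_name(rank)}")


def launch() -> None:
    import torch.multiprocessing as mp

    world_size = visible_world_size()
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)
```

Spawn re-imports the module in each child process, so launch() should be called under an `if __name__ == "__main__":` guard. Communicator setup (NCCL id exchange and initialization) is omitted here.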
- The repo provides both exhaustive and predictive search methods; the latter is recommended when MxN > 4096x4096. If the predictive method is chosen, generate the bandwidth curve first. For a given GPU and communication primitive, the bandwidth curve only needs to be generated once.
$ CUDA_VISIBLE_DEVICES=0,1 python bandwidth.py --comm_op all_reduce
- Both search methods share the same script; --predictive_search should be specified if the predictive method is used.
$ CUDA_VISIBLE_DEVICES=0,1 python search.py --m $M --n $N --k $K --comm_op all_reduce --predictive_search True
- The generated solution is written into the corresponding .json file.
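The bandwidth curve relates message size to achieved bandwidth. A common way to express such a curve follows the conventions of NVIDIA's nccl-tests. It reports algorithm bandwidth (bytes / time) and bus bandwidth, which scales algorithm bandwidth by a per-collective factor so results are comparable across GPU counts. It is an assumption that bandwidth.py stores exactly these quantities. bandwidth_curve is a hypothetical helper that post-processes (bytes, seconds) samples.

```python
from typing import Iterable, List, Tuple

# Bus-bandwidth factors for n ranks, as defined in NVIDIA's nccl-tests.
BUS_FACTOR = {
    "all_reduce": lambda n: 2 * (n - 1) / n,
    "reduce_scatter": lambda n: (n - 1) / n,
}


def bandwidth_curve(
    samples: Iterable[Tuple[int, float]], comm_op: str, world_size: int
) -> List[Tuple[int, float, float]]:
    """Map (message_bytes, seconds) samples to (bytes, algbw_GBps, busbw_GBps), sorted by size."""
    if comm_op not in BUS_FACTOR:
        raise ValueError(f"unsupported comm_op: {comm_op}")
    factor = BUS_FACTOR[comm_op](world_size)
    curve = []
    for nbytes, seconds in sorted(samples):
        algbw = nbytes / seconds / 1e9  # algorithm bandwidth in GB/s
        curve.append((nbytes, algbw, algbw * factor))
    return curve
```

With two ranks, the all-reduce factor is 2*(2-1)/2 = 1, so algorithm and bus bandwidth coincide. With four ranks, the factor is 1.5.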
Open the test dir and run the script.
$ cd ./test
$ CUDA_VISIBLE_DEVICES=0,1 python test.py --m $M --n $N --k $K
- Open the example dir.
$ cd ./example
- Evaluate the correctness of GEMM+AllReduce+RMSNorm. The RMSNorm must be included because the tile order is corrected in the fused reorder+RMSNorm kernel.
$ CUDA_VISIBLE_DEVICES=0,1 python correctness.py --m $M --n $N --k $K
- We define the ReorderRMSNorm class in RMSNorm.py and the OverlapRowParallelLayer class in RowParallelLayer.py, which can replace the RMSNorm and RowParallelLayer classes, respectively. They serve as a simple example of usage in end-to-end inference or training.
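The sketch below computes a plain NumPy reference for GEMM+AllReduce+RMSNorm in a row-parallel layer on a single host, which helps build intuition for what correctness.py verifies. Each simulated rank multiplies its K-slice of A and B. The all-reduce is modeled as a sum over ranks, and RMSNorm is applied row by row. The sketch does not model the tile reordering or the overlap itself. The function names and the eps default are illustrative assumptions.

```python
import numpy as np


def rmsnorm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Normalize each row by its root-mean-square, then apply a per-column scale."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight


def row_parallel_reference(a_shards, b_shards, weight, eps: float = 1e-6) -> np.ndarray:
    """Simulate GEMM -> AllReduce(sum) -> RMSNorm for a row-parallel layer.

    Rank r holds A[:, K_r] (M x K_r) and B[K_r, :] (K_r x N) and computes a
    partial M x N product; the all-reduce sums the partials across ranks.
    """
    partials = [a @ b for a, b in zip(a_shards, b_shards)]  # per-rank GEMM
    reduced = np.sum(partials, axis=0)  # result every rank holds after AllReduce(sum)
    return rmsnorm(reduced, weight, eps)
```

Split A along columns and B along rows into equal shards (for example with np.split), then compare the output against rmsnorm(A @ B, weight). This checks that the sharded result matches the unsharded one up to floating-point error.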
@misc{hong2025flashoverlap,
title={FlashOverlap: A Lightweight Design for Efficiently Overlapping Communication and Computation},
author={Ke Hong and Xiuhong Li and Minxu Liu and Qiuli Mao and Tianqi Wu and Zixiao Huang and Lufang Chen and Zhong Wang and Yichong Zhang and Zhenhua Zhu and Guohao Dai and Yu Wang},
year={2025},
eprint={2504.19519},
archivePrefix={arXiv},
primaryClass={cs.DC}
}