Sparse Transformer Inference
This repo provides a PyTorch extension that speeds up transformer inference with fixed structured sparsity.
The end-to-end speedup and memory profiling can be obtained with end_to_end.py.
- To profile the execution time of the sparse transformer, launch
  python3 end_to_end.py --model sparse
  with Nsight Systems.
- To profile the execution time of the dense transformer, launch
  python3 end_to_end.py --model dense
  with Nsight Systems.
- To profile the memory of the sparse transformer, launch
  python3 end_to_end.py --model sparse --mem
  with Nsight Systems.
- To profile the memory of the dense transformer, launch
  python3 end_to_end.py --model dense --mem
  with Nsight Systems.
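The four profiling runs above can be scripted. The sketch below only builds the command lines and does not execute them. It assumes the Nsight Systems CLI (`nsys`) is on the PATH and that tracing CUDA and NVTX is what you want. The helper name `nsys_command` and the report names are hypothetical and not part of this repo.

```python
import shlex


def nsys_command(model, mem=False, script="end_to_end.py"):
    """Build an `nsys profile` command line for one profiling run.

    `model` is "sparse" or "dense"; `mem=True` adds the --mem flag.
    The report name passed to -o is an arbitrary choice for this sketch.
    """
    if model not in ("sparse", "dense"):
        raise ValueError("model must be 'sparse' or 'dense'")
    run = ["python3", script, "--model", model]
    if mem:
        run.append("--mem")
    report = f"{model}_{'mem' if mem else 'time'}"
    # --trace=cuda,nvtx records CUDA activity and the NVTX ranges.
    return ["nsys", "profile", "--trace=cuda,nvtx", "-o", report] + run


if __name__ == "__main__":
    # Print the four command lines so they can be inspected or pasted.
    for model in ("sparse", "dense"):
        for mem in (False, True):
            print(shlex.join(nsys_command(model, mem)))
```

Each printed line can be pasted into a shell, or the returned list can be passed to `subprocess.run` once the commands look right for your setup.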
Dependencies
We generate the sparse mask with scipy.sparse. The PyTorch version is 1.8.1+cu111. The memory profiling is based on pytorch_memlab, and we annotate our program with nvtx.
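As an illustration of building a structured mask with scipy.sparse, here is a minimal sketch. The block pattern, the function name `block_sparse_mask`, and its parameters are assumptions made for this example. The mask layout actually used by this repo is not described here and may differ.

```python
import numpy as np
import scipy.sparse as sp


def block_sparse_mask(seq_len, block, density, seed=0):
    """Return a CSR 0/1 mask of shape (seq_len, seq_len) built from dense blocks.

    Assumed pattern: each block x block tile is either kept whole or dropped,
    and diagonal tiles are always kept.
    """
    if seq_len % block != 0:
        raise ValueError("seq_len must be a multiple of block")
    n_blocks = seq_len // block
    rng = np.random.default_rng(seed)
    # Decide per tile whether it is kept; force the diagonal tiles on.
    keep = rng.random((n_blocks, n_blocks)) < density
    np.fill_diagonal(keep, True)
    # The Kronecker product expands every kept tile into a block of ones.
    tile = sp.csr_matrix(np.ones((block, block), dtype=np.int8))
    return sp.kron(sp.csr_matrix(keep.astype(np.int8)), tile, format="csr")


if __name__ == "__main__":
    mask = block_sparse_mask(seq_len=64, block=8, density=0.25)
    print(mask.shape, mask.nnz)
```

Calling `mask.toarray()` gives a dense 0/1 array when one is needed, for example to compare against a dense attention mask.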
To build the custom kernels, please use src/install.sh. As our kernels target the tensor core architecture of the V100 GPU, currently only sm70 is supported.
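Before building, it can help to confirm that the local GPU reports compute capability 7.0. The following is a small sketch that assumes PyTorch is installed. The helper names `is_supported` and `current_capability` are hypothetical and not part of this repo.

```python
SUPPORTED_CAPABILITY = (7, 0)  # sm70, the V100 compute capability


def is_supported(capability):
    """Return True when a (major, minor) compute capability equals sm70."""
    return tuple(capability) == SUPPORTED_CAPABILITY


def current_capability():
    """Return the (major, minor) capability of the current GPU, or None.

    None means PyTorch is not installed or no CUDA device is visible.
    """
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability()


if __name__ == "__main__":
    cap = current_capability()
    if cap is None:
        print("No CUDA device visible to PyTorch; cannot check compute capability.")
    elif is_supported(cap):
        print("Compute capability sm70 detected.")
    else:
        print(f"Compute capability sm{cap[0]}{cap[1]} detected; only sm70 is supported.")
```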