Torch-Blocksparse

Block-sparse operations for PyTorch

Current State

The following functions are supported:

Matrix Multiplication: SPARSE = op(DENSE) x op(DENSE)
Matrix Multiplication: DENSE = op(SPARSE) x op(DENSE)
Matrix Multiplication: DENSE = op(DENSE) x op(SPARSE)
Softmax: SPARSE = Softmax(SPARSE)

where op() is identity or transposition.
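As a rough illustration, and assuming the SparseMatMul interface shown in the Usage section below (the 'dsd' and 'dds' mode strings are inferred by analogy with 'sdd' and with the DDS layout mentioned under Performance), the three products map onto the constructor as follows:

import torch
import torch_blocksparse

# sketch only: 16x16 blocks, random sparsity pattern for the sparse operand/result
block = 16
layout = torch.randint(0, 2, (2, 4, 4))
# SPARSE = op(DENSE) x op(DENSE)
sdd = torch_blocksparse.SparseMatMul(layout, block, 'sdd', trans_a=False, trans_b=False)
# DENSE = op(SPARSE) x op(DENSE)
dsd = torch_blocksparse.SparseMatMul(layout, block, 'dsd', trans_a=False, trans_b=False)
# DENSE = op(DENSE) x op(SPARSE)
dds = torch_blocksparse.SparseMatMul(layout, block, 'dds', trans_a=False, trans_b=False)
# op() is chosen per operand with the trans_a / trans_b flags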

The following modules are supported:

Sparse MultiHead Attention (https://arxiv.org/abs/1904.10509); see the illustrative sketch after the Usage example below

Inputs may be FP32 or FP16; FP16 inputs use tensor cores.
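
For example, half-precision inputs would simply be created in torch.float16 (a minimal sketch, assuming the SparseMatMul interface shown in the Usage section below; the shapes are arbitrary):

import torch
import torch_blocksparse

block = 16
layout = torch.randint(0, 2, (1, 4, 4))                                     # (H, M//block, N//block) with M = N = 64
a = torch.rand((1, 1, 64, 32), dtype=torch.float16, device='cuda')          # FP16 inputs run on tensor cores
b = torch.rand((1, 1, 32, 64), dtype=torch.float16, device='cuda')
dot = torch_blocksparse.SparseMatMul(layout, block, 'sdd', trans_a=False, trans_b=False)
c = dot(a, b)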

Installation

Torch-Blocksparse depends on CUDA 10.1 and on the Triton language and compiler, which in turn requires LLVM 8:

sudo apt-get install llvm-8-dev;
pip install -e "git+https://github.com/ptillet/triton.git#egg=triton&subdirectory=python"

You can then install the package in development mode:

python setup.py develop

and run the tests:

python tests/test.py

The first run will take some time, as all the necessary CUDA kernels are JIT-compiled and cached in $HOME/.triton/cache.
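
A quick sanity check that the package and its CUDA backend are visible (this one-liner is only a suggestion, not part of the official instructions):

python -c "import torch, torch_blocksparse; print(torch.cuda.is_available())"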

Usage

import torch
import torch_blocksparse

# Z: non-sparse batch dimension
# H: sparse batch dimension
# M: row dimension
# N: column dimension
Z, H, M, N, K = 4, 2, 256, 512, 384
# a has shape (Z, H, K, M): it is transposed inside the product (trans_a=True below)
a = torch.rand((Z, H, K, M), dtype=torch.float32).cuda()
b = torch.rand((Z, H, K, N), dtype=torch.float32).cuda()
# create sparsity layout: nonzero entries mark which (block x block) tiles of the M x N result are kept
block = 16
layout = torch.randint(0, 2, (H, M//block, N//block))
# create object for Sparse = trans(Dense) x Dense (sdd)
# this incurs some one-time overhead, as it pre-computes the look-up tables
# needed internally by the GPU kernels
dot = torch_blocksparse.SparseMatMul(layout, block, 'sdd', trans_a=True, trans_b=False)
c = dot(a, b)
# create object for Sparse = softmax(Sparse)
softmax = torch_blocksparse.SparseSoftmax(layout, block)
d = softmax(c)
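
The Sparse MultiHead Attention module listed under Current State builds on these primitives. The sketch below is purely illustrative, not the module's actual implementation: it computes block-sparse attention with an 'sdd' product for the scores, the sparse softmax for normalization, and a 'dsd' product for the output (the 'dsd' mode string, and feeding the compressed output of one operation into the next, are assumptions):

import torch
import torch_blocksparse

Z, H, seq, head_dim, block = 4, 2, 256, 64, 16
# which (query block, key block) pairs are allowed to attend to each other
layout = torch.randint(0, 2, (H, seq//block, seq//block))
q = torch.rand((Z, H, seq, head_dim), dtype=torch.float32).cuda()
k = torch.rand((Z, H, seq, head_dim), dtype=torch.float32).cuda()
v = torch.rand((Z, H, seq, head_dim), dtype=torch.float32).cuda()
# sparse scores = Q x trans(K), computed only for the blocks kept by the layout
qk = torch_blocksparse.SparseMatMul(layout, block, 'sdd', trans_a=False, trans_b=True)
scores = qk(q, k) / (head_dim ** 0.5)
# normalize the kept scores
softmax = torch_blocksparse.SparseSoftmax(layout, block)
probs = softmax(scores)
# dense output = sparse probabilities x dense values
pv = torch_blocksparse.SparseMatMul(layout, block, 'dsd', trans_a=False, trans_b=False)
out = pv(probs, v)          # (Z, H, seq, head_dim)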

Performance

This package has been benchmarked against OpenAI blocksparse for the DDS layout (dense = dense x sparse) with square, non-transposed inputs.

The file tests/test.py includes simple benchmarking code.
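
For a very rough standalone measurement, the objects from the Usage example can be timed with CUDA events; the sketch below is only a suggestion and is not the methodology behind the comparison above:

import torch
import torch_blocksparse

Z, H, M, N, K, block = 4, 2, 256, 512, 384, 16
layout = torch.randint(0, 2, (H, M//block, N//block))
a = torch.rand((Z, H, M, K), dtype=torch.float32).cuda()
b = torch.rand((Z, H, K, N), dtype=torch.float32).cuda()
dot = torch_blocksparse.SparseMatMul(layout, block, 'sdd', trans_a=False, trans_b=False)
dot(a, b)  # warm-up: triggers JIT compilation and caching
start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(100):
    dot(a, b)
end.record()
torch.cuda.synchronize()
print('sdd matmul: %.3f ms/iteration' % (start.elapsed_time(end) / 100))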