Kernel Fusion using torch.jit
xrsrke opened this issue · 3 comments
xrsrke commented
Fuse some popular functions and automatically replace modules in an existing 🤗 Transformers model with their corresponding fused modules.
APIs
```python
from pipegoose.nn import fusion
# ... and other parallelism imports ...

model = TensorParallel(model, parallel_context).parallelize()
model.fuse()

# or selective kernel fusion
model.fuse([fusion.LayerNorm, fusion.Attention])
```
TODOs
- Fuse bias addition and GELU [link] (see the sketch after this list)
- Fuse bias addition and dropout [link]
- Use torch.fx to detect modules that can be fused and replace them with their fused versions.
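For reference, fused bias + GELU and bias + dropout written with `torch.jit.script` might look roughly like the sketch below, modelled on GPT-NeoX's kernel fusion (linked under Reading). The function names are placeholders, not the final pipegoose API:

```python
import torch
import torch.nn.functional as F

@torch.jit.script
def fused_bias_gelu(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # TorchScript can fuse the bias add and the tanh-approximated GELU
    # into fewer elementwise kernels.
    y = x + bias
    return y * 0.5 * (1.0 + torch.tanh(0.79788456 * y * (1.0 + 0.044715 * y * y)))

@torch.jit.script
def fused_bias_dropout(
    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, p: float, training: bool
) -> torch.Tensor:
    # Bias add + dropout + residual add in a single scripted function.
    return residual + F.dropout(x + bias, p=p, training=training)
```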
Reading (optional)
- OSLO’s kernel fusion [[link]](https://github.com/tunib-ai/oslo/blob/88dcca0441a605b462bf825cb0104bc692f14c57/oslo/fused_kernels_utils.py#L259)
- GPT-NeoX’s kernel fusion [[link]](https://github.com/EleutherAI/gpt-neox/blob/b02d98932f95fe0500c28698b38acb175e92e980/megatron/model/activations.py#L27)
sami-bg commented
Can you assign this to me? I'd like to give this a shot
xrsrke commented
@sami-bg Check out `torch.fx`. We could use it to detect modules in a transformers model that can be fused and replace them with their fused versions, e.g. `model.transformers.blocks[0].dropout = fused_dropout`.
But we wouldn't do that by hand for every module. This tutorial shows how to do it with `torch.fx`: https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html
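To make that concrete, a minimal sketch of such a `torch.fx` pass might look like this; `FusedDropout` is a hypothetical stand-in for the real fused module, and only the trace/replace mechanics mirror the tutorial above:

```python
import torch
import torch.fx as fx
import torch.nn as nn
import torch.nn.functional as F

class FusedDropout(nn.Module):
    """Hypothetical stand-in for a dropout backed by a fused kernel."""
    def __init__(self, p: float):
        super().__init__()
        self.p = p

    def forward(self, x):
        # A real implementation would dispatch to a fused kernel here.
        return F.dropout(x, p=self.p, training=self.training)

def fuse_dropout(model: nn.Module) -> fx.GraphModule:
    traced = fx.symbolic_trace(model)
    modules = dict(traced.named_modules())  # includes "" -> traced itself
    for node in traced.graph.nodes:
        if node.op == "call_module" and isinstance(modules[node.target], nn.Dropout):
            parent_name, _, attr = node.target.rpartition(".")
            # Swap the original nn.Dropout for the fused version in place.
            setattr(modules[parent_name], attr, FusedDropout(p=modules[node.target].p))
    traced.recompile()
    return traced
```

The same pattern should generalize: match any fusible module (or sequence of nodes) in the traced graph, swap in the fused equivalent, then recompile.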