xrsrke/pipegoose

Kernel Fusion using torch.jit

xrsrke opened this issue · 3 comments

xrsrke commented

Fuse some popular functions, and automatically replace modules in an existing 🤗 transformers model with their corresponding fused modules.

APIs

from pipegoose.nn import fusion

# and other parallelism ...
model = TensorParallel(model, parallel_context).parallelize()
model.fuse()

# or selective kernel fusion
model.fuse([fusion.LayerNorm, fusion.Attention])

TODOs

  • Fuse bias addition and GeLU [link] (see the sketch after this list)
  • Fuse bias addition and dropout [link]
  • Use torch.fx to detect modules that can be fused and replace them with the fused version.
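Below is a minimal sketch of the kind of fused module the first TODO describes, using torch.jit.script so the pointwise ops can be compiled into a single kernel. The name FusedBiasGeLU and the tanh-approximated GeLU are illustrative assumptions, not existing pipegoose code:

import torch
from torch import nn


@torch.jit.script
def fused_bias_gelu(x: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # bias add + tanh-approximated GeLU in one scripted graph, so the
    # elementwise ops can be fused into a single kernel
    y = x + bias
    return 0.5 * y * (1.0 + torch.tanh(0.79788456 * y * (1.0 + 0.044715 * y * y)))


class FusedBiasGeLU(nn.Module):
    # drop-in replacement for a bias add followed by nn.GELU
    def __init__(self, bias: torch.Tensor):
        super().__init__()
        self.bias = nn.Parameter(bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return fused_bias_gelu(x, self.bias)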

Reading (could be ignored)

sami-bg commented

Can you assign this to me? I'd like to give this a shot.

xrsrke commented

@sami-bg Check out torch.fx. We could use it to detect modules in a transformers model that can be fused and replace them with the fused version: model.transformers.blocks[0].dropout = fused_dropout

But we shouldn't do this manually. Check out this tutorial for how to use torch.fx: https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html
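For reference, a rough sketch of the torch.fx-based detection and replacement step, along the lines of that tutorial. The helper name replace_fusible_modules and the fused_cls argument are hypothetical, used only for illustration, and some 🤗 transformers models may need a dedicated tracer rather than plain fx.symbolic_trace:

from torch import fx, nn


def replace_fusible_modules(model: nn.Module, target_type: type, fused_cls) -> fx.GraphModule:
    # trace the model into a graph, then swap every call_module node whose
    # submodule is an instance of target_type for a fused replacement
    traced = fx.symbolic_trace(model)
    modules = dict(traced.named_modules())

    for node in traced.graph.nodes:
        if node.op == "call_module" and isinstance(modules[node.target], target_type):
            parent_name, _, attr_name = node.target.rpartition(".")
            parent = modules[parent_name] if parent_name else traced
            # the graph keeps referring to the same qualified name,
            # only the module object behind it changes
            setattr(parent, attr_name, fused_cls(modules[node.target]))

    traced.recompile()
    return traced

A call like replace_fusible_modules(model, nn.Dropout, FusedDropout) would then cover the model.transformers.blocks[0].dropout example above without touching each block by hand (FusedDropout being a hypothetical fused module class).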