tlc-pack/relax

[Tracking Issue] Relax graph-level BYOC

masahi opened this issue

I've been working on bringing up the BYOC infra in Relax, building on the work of @sunggg and the pattern-matcher work from @ganler. The ultimate goal is to make relax.vm.build(mod, "cuda") just work without tuning and with reasonable out-of-the-box performance. It would also be the first step toward performant dynamic-shape support.

My branch is linked below; it currently has minimal test cases for offloading a simple subgraph to DNNL and CUTLASS. I'm going to start sending pieces from it today.
https://github.com/tlc-pack/relax/compare/relax...masahi:codegen-cutlass?expand=1

  • Refactor RunCodegen pass to send all BYOC functions to the backend at once (rather than individually)
  • Add pattern-based partitioning pass (similar to MergeComposite in Relay)
  • Add pass to wrap and annotate the partitioned function for offloading (subsumed by #372)
  • Add DNNL backend
  • Add CUTLASS backend
  • Add pass to merge neighboring calls to functions compiled for the same external backend into one function (similar to MergeCompilerRegion in Relay, necessary for TRT)
  • Revisit TensorRT backend (originally added by #164)
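To make the intent of the partitioning and region-merging items above concrete, here is a toy sketch in plain Python. All names here are hypothetical and do not reflect the actual Relax pattern-matcher API: it models a graph as a flat op sequence, greedily groups matched subsequences into backend-claimed regions (the MergeComposite-style step), then fuses neighboring regions that target the same external backend (the MergeCompilerRegions-style step that TRT needs).

```python
# Hypothetical pattern table: pattern name -> (op sequence, claiming backend).
PATTERNS = {
    "conv2d_relu": (("conv2d", "relu"), "dnnl"),
    "dense_add": (("dense", "add"), "cutlass"),
}

def partition(ops):
    """Greedily match op subsequences against PATTERNS.

    Returns a list of regions as (backend, [ops]) pairs, where
    backend is None for ops left to the default compiler.
    """
    regions, i = [], 0
    while i < len(ops):
        for _name, (seq, backend) in PATTERNS.items():
            if tuple(ops[i:i + len(seq)]) == seq:
                regions.append((backend, list(seq)))
                i += len(seq)
                break
        else:
            # No pattern matched: leave this op to the default compiler.
            regions.append((None, [ops[i]]))
            i += 1
    return regions

def merge_regions(regions):
    """Fuse neighboring regions offloaded to the same external backend."""
    merged = []
    for backend, ops in regions:
        if merged and backend is not None and merged[-1][0] == backend:
            merged[-1][1].extend(ops)
        else:
            merged.append((backend, list(ops)))
    return merged
```

For example, partitioning ["conv2d", "relu", "conv2d", "relu", "softmax"] yields two adjacent dnnl regions plus a fallback region for softmax, and merging collapses the two dnnl regions into one external function call.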

Future possibilities (time permitting)

  • Add cuDNN backend (supporting Graph API)
  • Add oneDNN (aka dnnl) v3 graph API backend
  • Advanced fusion, such as fused MHA
  • Take advantage of graph-level passes (constant folding, scale-axis folding, layout transformation, etc.) when they become available
  • Add a mechanism to handle constants (a recurring problem in Relay BYOC) (initial work in #400; not sure if it is complete)
  • Improve each backend (more patterns, e2e evaluation, etc.)
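The constant-handling item above can be illustrated with a toy sketch (hypothetical names, not Relax APIs): external backends often cannot keep weight tensors inline in the partitioned function, so one approach is to lift constants out into named bindings that the runtime supplies at load time. Here an expression is a nested tuple (op, arg, ...), a str is a variable reference, and a list stands in for an inline constant tensor.

```python
def lift_constants(body):
    """Replace inline constant leaves in a nested-tuple expression tree
    with symbolic names, returning the rewritten tree and a const table."""
    const_table = {}

    def rewrite(expr):
        if isinstance(expr, tuple):            # (op, arg0, arg1, ...)
            return (expr[0],) + tuple(rewrite(a) for a in expr[1:])
        if isinstance(expr, list):             # inline constant tensor
            name = f"const_{len(const_table)}"
            const_table[name] = expr
            return name
        return expr                            # variable reference (str)

    return rewrite(body), const_table
```

For example, lifting ("add", ("matmul", "x", [1.0, 2.0]), [0.5]) rewrites the two inline tensors to "const_0" and "const_1" and records them in the table, so the external codegen only sees symbolic names.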

@sunggg @YuchenJin @tqchen @junrushao