Make fused RMSNorm a registered op
lessw2020 opened this issue · 2 comments
lessw2020 commented
Adding this as a tracking issue to unblock #181 from landing:
per @wanchaol:
IMO we should also register the fwd/bwd rmsnorm kernel as a PyTorch custom op, for two reasons (see the sketch below):
- Making it a custom op makes it compatible with PT2; I believe the FusedRMSNorm path currently graph-breaks when torch.compile is turned on.
- It allows other components (e.g. DTensor) to provide sharding rules for the custom op, so that it is compatible with tensor parallelism.
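For reference, here is a minimal sketch of what such a registration could look like, using the `torch.library.custom_op` API available in recent PyTorch releases. The `titan::rmsnorm` name is made up for illustration, and the plain eager math stands in for the actual fused Triton kernel; this is an assumption-laden sketch, not the implementation proposed in #181.

```python
# Sketch: register an RMSNorm fwd/bwd pair as a custom op so torch.compile
# sees one opaque op instead of graph-breaking into the fused kernel.
# Assumes torch.library.custom_op (recent PyTorch); "titan" is a hypothetical namespace.
import torch

@torch.library.custom_op("titan::rmsnorm", mutates_args=())
def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # The real op would launch the fused Triton kernel here.
    rstd = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * rstd * weight

@rmsnorm.register_fake
def _(x, weight, eps):
    # Shape/dtype propagation for torch.compile and meta tensors.
    return torch.empty_like(x)

def _setup_context(ctx, inputs, output):
    x, weight, eps = inputs
    ctx.save_for_backward(x, weight)
    ctx.eps = eps

def _backward(ctx, grad_out):
    x, weight = ctx.saved_tensors
    rstd = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + ctx.eps)
    x_hat = x * rstd
    grad_weight = (grad_out * x_hat).sum(dim=tuple(range(x.dim() - 1)))
    g = grad_out * weight
    grad_x = rstd * (g - x_hat * (g * x_hat).mean(-1, keepdim=True))
    return grad_x, grad_weight, None  # no gradient for eps

torch.library.register_autograd("titan::rmsnorm", _backward, setup_context=_setup_context)
```

Once the op is registered this way, it is a single node from PT2's point of view, and DTensor can attach a sharding strategy to `titan::rmsnorm` for tensor parallelism.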
tianyu-l commented
Update: hit IMA (illegal memory access) issues with both my implementation (#296) and @wconstab's (#303). Working on debugging with @lessw2020.