csarofeen/pytorch

combined inner outer reduction used in layer norm backward

Opened this issue · 0 comments

🚀 The feature, motivation and pitch

Combine the inner and outer reductions into one kernel.

  1. Do a partial outer reduction while the blocks loop over the outer domain doing the block inner reduction.
  2. Write the result of the partial outer reduction to global memory (gmem).
  3. Sync and reload from gmem.
  4. Remap the parallelization pattern to finalize the outer reduction.

Used in ln_backward (layer norm backward).
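The four steps can be sketched on the CPU as follows. This is a minimal pure-Python illustration, not the kernel or the code generator's output: the function name `combined_inner_outer_reduction`, the `num_blocks` parameter, and the `gmem` list standing in for a global-memory workspace are all hypothetical, and plain sums stand in for the actual layer norm backward expressions. Simulated blocks run one after another, so the grid-wide sync of step 3 appears only as a comment.

```python
def combined_inner_outer_reduction(x, num_blocks):
    """Reduce a 2D input along both axes in one pass over the data.

    x is a list of rows (outer dimension), each a list of numbers
    (inner dimension). Returns (inner, outer): per-row sums and
    per-column sums. All names here are illustrative assumptions.
    """
    n_rows = len(x)
    n_cols = len(x[0])
    inner = [0.0] * n_rows
    # Stand-in for the gmem workspace: one partial outer result per block.
    gmem = [[0.0] * n_cols for _ in range(num_blocks)]

    for block in range(num_blocks):
        # Step 1: the block strides over the outer domain. For each row it
        # does the inner reduction and accumulates its partial outer result.
        partial = [0.0] * n_cols
        for row in range(block, n_rows, num_blocks):
            inner[row] = sum(x[row])
            for col in range(n_cols):
                partial[col] += x[row][col]
        # Step 2: write the partial outer reduction to the workspace.
        gmem[block] = partial

    # Step 3: on a GPU a grid-wide sync would go here, followed by a reload
    # of the workspace. The sequential loop above makes it unnecessary.

    # Step 4: remap the work so that each column is finalized by reducing
    # across the per-block partials instead of across rows.
    outer = [sum(gmem[b][col] for b in range(num_blocks))
             for col in range(n_cols)]
    return inner, outer


x = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
inner, outer = combined_inner_outer_reduction(x, num_blocks=2)
```

In layer norm backward, the inner reduction corresponds to the per-row sums needed for the input gradient, and the outer reduction to the sums over rows that produce the weight and bias gradients. Fusing them lets the input be read once instead of once per kernel.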

Alternatives

No response

Additional context

No response