Backprop through LSE
abf149 opened this issue · 1 comment
Hello, I would like to submit a PR for a new feature that allows FlashAttention to support backpropagation through the log-sum-exp (LSE) of the attention scores.
In other words, the mha_bwd function signature (in pseudocode) is currently
mha_bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, /* etc. */)
but would become
mha_bwd(dout, dsoftmax_lse, q, k, v, out, softmax_lse, dq, dk, dv, /* etc. */)
Motivation: I am training a language model (LM) and want to add a penalty term against large LSE values to my loss function, to keep attention scores from growing too large. So my loss function (pseudocode) is something like
cross_entropy(LM output) + lambda*sum(LSE over all attention layers)
where cross_entropy(LM output) is the usual LM loss, lambda is a tuning parameter, and sum(LSE over all attention layers) is a regularizer that penalizes large LSE values across all attention layers.
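To see why penalizing LSE controls the scores: the gradient of LSE with respect to each attention score is exactly the softmax probability of that score, so the penalty pushes every score down in proportion to how much attention it receives. A minimal pure-Python sketch (not FlashAttention code; toy scores for one query row) verifying this identity by finite differences:

```python
import math

def lse(scores):
    # numerically stable log-sum-exp of a list of attention scores
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]

scores = [1.0, 2.0, 0.5]    # toy attention scores for one query row
analytic = softmax(scores)  # identity: d(LSE)/d(score_j) = softmax(scores)_j

# finite-difference check of the same gradient
eps = 1e-6
fd = [(lse(scores[:j] + [scores[j] + eps] + scores[j + 1:]) - lse(scores)) / eps
      for j in range(len(scores))]
```

Since the gradient is the softmax itself, a dsoftmax_lse of lambda per row translates into a well-defined gradient on the raw scores, which is exactly what the backward kernel would need to apply.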
The first term introduces a dout gradient that must be backpropagated through each attention layer. The second term introduces a dsoftmax_lse gradient that must likewise be backpropagated through each attention layer. The problem is that there is currently no way to feed dsoftmax_lse to the backward kernel; only dout is accepted.
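For concreteness, here is a pure-Python reference sketch (not the CUDA kernel; names like dq_analytic, w, and c are illustrative) of how dsoftmax_lse would fold into the standard attention backward. With P = softmax(S), the combined score gradient is dS_ij = P_ij * (dP_ij - D_i) + dsoftmax_lse_i * P_ij, where the first term is the usual dout path and the second is the new LSE path; the sketch checks dq against finite differences:

```python
import math

def attn_forward(q, k, v):
    """Naive attention: returns output O, probabilities P, and per-row LSE."""
    n, m, d = len(q), len(k), len(q[0])
    S = [[sum(q[i][t] * k[j][t] for t in range(d)) for j in range(m)]
         for i in range(n)]
    lse, P = [], []
    for row in S:
        mx = max(row)
        exps = [math.exp(s - mx) for s in row]
        z = sum(exps)
        lse.append(mx + math.log(z))
        P.append([e / z for e in exps])
    O = [[sum(P[i][j] * v[j][t] for j in range(m)) for t in range(len(v[0]))]
         for i in range(n)]
    return O, P, lse

def loss(q, k, v, w, c):
    """Scalar loss <w, O> + <c, lse>: w plays the role of dout,
    c plays the role of dsoftmax_lse."""
    O, _, lse = attn_forward(q, k, v)
    return (sum(w[i][t] * O[i][t] for i in range(len(O)) for t in range(len(O[0])))
            + sum(c[i] * lse[i] for i in range(len(lse))))

def dq_analytic(q, k, v, w, c):
    """dL/dq with both the dout (= w) and dsoftmax_lse (= c) contributions.
    Key line: dS_ij = P_ij * (dP_ij - D_i) + c_i * P_ij."""
    _, P, _ = attn_forward(q, k, v)
    n, m, d = len(q), len(k), len(q[0])
    dP = [[sum(w[i][t] * v[j][t] for t in range(len(v[0]))) for j in range(m)]
          for i in range(n)]
    dq = []
    for i in range(n):
        D = sum(P[i][j] * dP[i][j] for j in range(m))
        dS_row = [P[i][j] * (dP[i][j] - D) + c[i] * P[i][j] for j in range(m)]
        dq.append([sum(dS_row[j] * k[j][t] for j in range(m)) for t in range(d)])
    return dq

# tiny finite-difference check
q = [[0.3, -0.2], [0.1, 0.4]]
k = [[0.5, 0.1], [-0.3, 0.2], [0.2, 0.6]]
v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
w = [[0.7, -0.1], [0.2, 0.3]]  # stand-in for dout
c = [0.4, -0.6]                # stand-in for dsoftmax_lse
dq = dq_analytic(q, k, v, w, c)
eps = 1e-6
for i in range(len(q)):
    for t in range(len(q[0])):
        qp = [row[:] for row in q]
        qp[i][t] += eps
        fd = (loss(qp, k, v, w, c) - loss(q, k, v, w, c)) / eps
        assert abs(fd - dq[i][t]) < 1e-4
```

The point of the sketch is that the dsoftmax_lse path is a cheap additive term in dS (the kernel already materializes P implicitly), which is why extending mha_bwd with one extra input should be feasible.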
A year or so ago (probably in the FlashAttention1 era) I implemented mha_bwd() with dsoftmax_lse support as part of a project. I think it would be valuable to port my changes to FlashAttention2 and open a PR, so that others can use this capability to backprop through LSE.
Do you have guidelines for new contributors to this repo?
Looking forward to this PR!