Dao-AILab/flash-attention

Backprop through LSE

abf149 opened this issue · 1 comment

Hello, I would like to PR a new feature that allows FlashAttention to support backpropagation through the log-sum-exp (LSE) output.

In other words, the mha_bwd function signature (in pseudocode) is currently

mha_bwd(dout, q, k, v, out, softmax_lse, dq, dk, dv, /* etc. */)

but would become

mha_bwd(dout, dsoftmax_lse, q, k, v, out, softmax_lse, dq, dk, dv, /* etc. */)
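
For context, here is a minimal pure-PyTorch reference of the intended semantics (not this repo's kernel or API, just an autograd sketch): an attention function that also returns the per-row LSE, so that a loss depending on both outputs produces exactly the two incoming gradients, dout and dsoftmax_lse, that the extended mha_bwd would need to consume.

```python
import torch

def reference_attention_with_lse(q, k, v):
    # q, k, v: (batch, heads, seqlen, head_dim). Reference implementation only,
    # used to illustrate the semantics the kernel change would have to match.
    scale = q.shape[-1] ** -0.5
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale
    lse = torch.logsumexp(scores, dim=-1)          # (batch, heads, seqlen)
    p = torch.exp(scores - lse.unsqueeze(-1))      # softmax recomputed from the LSE
    out = torch.einsum("bhqk,bhkd->bhqd", p, v)
    return out, lse
```

With this reference, a loss that touches both out and lse already delivers both gradients through autograd when loss.backward() is called; the proposed change is about making the fused CUDA backward accept and apply the second one.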

Motivation: I am training an LM and I want a penalty term in my loss function that discourages large LSE values, to prevent the attention scores from getting too large. So my loss function (pseudocode) is something like

cross_entropy(LM output) + lambda*sum(LSE over all attention layers)

where cross_entropy(LM output) is the usual LM loss, lambda is a tuning parameter, and sum(LSE over all attention layers) penalizes large LSE values across all attention layers.
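
In PyTorch terms the loss would look something like the sketch below (lse_coeff plays the role of lambda and is a hypothetical name; how the per-layer LSE tensors are collected is left open):

```python
import torch.nn.functional as F

def penalized_loss(logits, labels, lse_per_layer, lse_coeff=1e-4):
    # logits: (batch * seqlen, vocab); labels: (batch * seqlen,)
    # lse_per_layer: list of (batch, heads, seqlen) LSE tensors, one per attention layer
    lm_loss = F.cross_entropy(logits, labels)
    lse_penalty = sum(lse.sum() for lse in lse_per_layer)
    return lm_loss + lse_coeff * lse_penalty
```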

The first term produces a dout gradient that must be backpropagated through each attention layer. The second term produces a dsoftmax_lse gradient that must also be backpropagated through each attention layer.

The problem is that there is currently no way to feed dsoftmax_lse to the backward kernel, only dout.
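
For reference, the extra contribution is cheap to apply inside the existing backward pass. This is only a derivation sketch in the usual softmax-backward notation ($P$ the attention probabilities, $dP$ the gradient flowing into them, $D_i$ the usual row correction), not code from this repo. Since

$$
\mathrm{LSE}_i = \log\sum_j e^{s_{ij}}, \qquad
\frac{\partial\,\mathrm{LSE}_i}{\partial s_{ij}} = \frac{e^{s_{ij}}}{\sum_k e^{s_{ik}}} = P_{ij},
$$

an incoming gradient $g_i = \partial L / \partial\,\mathrm{LSE}_i$ simply adds $g_i P_{ij}$ to the score gradient, turning the usual update $dS_{ij} = P_{ij}\,(dP_{ij} - D_i)$ into

$$
dS_{ij} = P_{ij}\,\big(dP_{ij} - D_i + g_i\big), \qquad D_i = \sum_j P_{ij}\, dP_{ij}.
$$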

A year or so ago (in the FlashAttention 1 era) I implemented mha_bwd() with support for dsoftmax_lse as part of a project.

I think it could be valuable to port my changes to FlashAttention 2 and open a PR, so that others can use this capability to backprop through the LSE.

Do you have guidelines for new contributors to this repo?

Looking forward to this PR!