Dao-AILab/flash-attention

v1 algorithm typo in v2 paper

andportnoy opened this issue · 0 comments

Hi @tridao,

I believe the last equation in the v1 algorithm statement of the v2 paper needs the following minor fix:

```diff
diff --git a/src/background.tex b/src/background.tex
index 0f496e8..157c818 100644
--- a/src/background.tex
+++ b/src/background.tex
@@ -129,7 +129,7 @@ rescale to get the right output at the end:
   m^{(2)} &= \max(m^{(1)}, \mathrm{rowmax}(\vS^{(2)})) = m \\
   \ell^{(2)} &= e^{m^{(1)} - m^{(2)}} \ell^{(1)} + \mathrm{rowsum}(e^{\vS^{(2)} - m^{(2)}}) = \mathrm{rowsum}(e^{\vS^{(1)} - m}) + \mathrm{rowsum}(e^{\vS^{(2)} - m}) = \ell \\
   \tilde{\vP}^{(2)} &= \diag(\ell^{(2)})^{-1} e^{\vS^{(2)} - m^{(2)}} \\
-  \vO^{(2)} &= \diag(\ell^{(1)} / \ell^{(2)})^{-1} \vO^{(1)} + \tilde{\vP}^{(2)} \vV^{(2)} = \diag(\ell^{(2)})^{-1} e^{s^{(1)} - m} \vV^{(1)} + \diag(\ell^{(2)})^{-1} e^{s^{(2)} - m} \vV^{(2)} = \vO.
+  \vO^{(2)} &= \diag(\ell^{(1)} / \ell^{(2)}) e^{m^{(1)} - m^{(2)}} \vO^{(1)} + \tilde{\vP}^{(2)} \vV^{(2)} = \diag(\ell^{(2)})^{-1} e^{s^{(1)} - m} \vV^{(1)} + \diag(\ell^{(2)})^{-1} e^{s^{(2)} - m} \vV^{(2)} = \vO.
 \end{align*}
 
 We show how \sysnameone uses online softmax to enable tiling
```

I removed the $-1$ exponent on the diagonal term and added the $e^{m^{(1)} - m^{(2)}}$ correction factor.

Before:
$$\mathbf{O}^{(2)} = \text{diag}\left(l^{(1)}/l^{(2)}\right)^{-1} \mathbf{O}^{(1)} + \mathbf{\tilde{P}}^{(2)}\mathbf{V}^{(2)}$$
After:
$$\mathbf{O}^{(2)} = \text{diag}\left(l^{(1)}/l^{(2)}\right) e^{m^{(1)}-m^{(2)}} \mathbf{O}^{(1)} + \mathbf{\tilde{P}}^{(2)}\mathbf{V}^{(2)}$$
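To spell out why both changes are needed (this just expands the definitions already in the block above): since $\mathbf{O}^{(1)} = \text{diag}\left(l^{(1)}\right)^{-1} e^{\mathbf{S}^{(1)} - m^{(1)}} \mathbf{V}^{(1)}$ and $m^{(2)} = m$, the first block's contribution to the final output is

$$\text{diag}\left(l^{(2)}\right)^{-1} e^{\mathbf{S}^{(1)} - m} \mathbf{V}^{(1)} = \text{diag}\left(l^{(1)}/l^{(2)}\right) e^{m^{(1)} - m^{(2)}} \mathbf{O}^{(1)},$$

which is exactly the corrected first term. With the $-1$ exponent kept and the exponential factor dropped, the right-hand side of the printed equation no longer equals $\mathbf{O}$.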
Original rendered algorithm for reference:
[image: rendered screenshot of the FlashAttention-1 algorithm (flash-attention-v1)]

I got the LaTeX source from arXiv.
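Not from the paper, but as a quick numerical sanity check of the corrected update, here is a small NumPy sketch (shapes and variable names are arbitrary, just for illustration) confirming that the two-block online-softmax recurrence with the fixed rescaling reproduces full softmax attention:

```python
# Sanity-check sketch (not from the paper): verify that the corrected
# two-block online-softmax update reproduces full softmax attention.
# S1, S2 are the two score blocks; V1, V2 the matching value blocks;
# m*, l*, O* follow the running max / sum / output notation above.
import numpy as np

rng = np.random.default_rng(0)
S1, S2 = rng.normal(size=(3, 5)), rng.normal(size=(3, 5))  # score blocks
V1, V2 = rng.normal(size=(5, 4)), rng.normal(size=(5, 4))  # value blocks

# Reference: softmax over the concatenated scores, then multiply by V.
S = np.concatenate([S1, S2], axis=1)
V = np.concatenate([V1, V2], axis=0)
P = np.exp(S - S.max(axis=1, keepdims=True))
O_ref = (P / P.sum(axis=1, keepdims=True)) @ V

# Block 1: running statistics as in the paper.
m1 = S1.max(axis=1, keepdims=True)
l1 = np.exp(S1 - m1).sum(axis=1, keepdims=True)
O1 = (np.exp(S1 - m1) / l1) @ V1

# Block 2 with the corrected rescaling:
#   O2 = diag(l1 / l2) * exp(m1 - m2) * O1 + Ptilde2 @ V2
m2 = np.maximum(m1, S2.max(axis=1, keepdims=True))
l2 = np.exp(m1 - m2) * l1 + np.exp(S2 - m2).sum(axis=1, keepdims=True)
O2 = (l1 / l2) * np.exp(m1 - m2) * O1 + (np.exp(S2 - m2) / l2) @ V2

print(np.allclose(O2, O_ref))  # True
```

With the rescaling term as printed ($\text{diag}\left(l^{(1)}/l^{(2)}\right)^{-1}$ and no exponential factor), the same check fails except in the degenerate case where the two blocks happen to share the same row max and row sum.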