tatsu-lab/alpaca_farm

[Reward Model Training] Inconsistent accuracy caused by flash-attention

nbl97 opened this issue · 1 comment

nbl97 commented

Many thanks for your excellent work!
When training the reward model, I found that flash-attn affected the final accuracy. I followed the README exactly to reproduce sft10k and then used it to train a reward model. With all default parameters, enabling or disabling flash-attn changed accuracy by 3.4 percentage points (60% with flash-attn vs. 56.6% without). The results are shown in the figure below: the pink curve is without flash-attn and the blue curve is with flash-attn. Is a gap of this size expected?

I noticed that tests/test_flash_llama.py only checks consistency at inference time (the forward pass). Did you also test back-propagation?

[Figure: reward model accuracy during training; pink = without flash-attn, blue = with flash-attn]
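The question about testing the backward pass, not only the forward pass, can be illustrated with a gradient check: compare an analytic gradient against central finite differences. The sketch below is a toy, pure-Python version for a single query with one scalar value per key; all function names are hypothetical and this is not the test in tests/test_flash_llama.py. In PyTorch, `torch.autograd.gradcheck` performs a similar comparison for autograd functions, typically in double precision.

```python
import math

# Toy setting (assumed for illustration): one query q, a list of key vectors,
# and one scalar value per key. Output o = sum_i softmax(s)_i * v_i,
# with scores s_i = (q . k_i) / sqrt(d).

def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def _probs(q, keys):
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    return scale, softmax(scores)

def attention(q, keys, values):
    _, probs = _probs(q, keys)
    return sum(p * v for p, v in zip(probs, values))

def attention_grad_q(q, keys, values):
    """Analytic gradient: d o / d q = scale * sum_i p_i * (v_i - o) * k_i."""
    scale, probs = _probs(q, keys)
    out = sum(p * v for p, v in zip(probs, values))
    grad = [0.0] * len(q)
    for p, v, k in zip(probs, values, keys):
        for d in range(len(q)):
            grad[d] += scale * p * (v - out) * k[d]
    return grad

def numerical_grad_q(q, keys, values, eps=1e-6):
    """Central finite differences, one coordinate of q at a time."""
    grad = []
    for d in range(len(q)):
        q_plus, q_minus = list(q), list(q)
        q_plus[d] += eps
        q_minus[d] -= eps
        diff = attention(q_plus, keys, values) - attention(q_minus, keys, values)
        grad.append(diff / (2 * eps))
    return grad

def grads_close(a, b, atol=1e-6):
    return all(abs(x - y) <= atol for x, y in zip(a, b))

if __name__ == "__main__":
    q = [0.3, -1.2, 0.5]
    keys = [[1.0, 0.0, 0.5], [0.2, -0.7, 1.1], [-0.4, 0.9, 0.0]]
    values = [1.0, -2.0, 0.5]
    print(grads_close(attention_grad_q(q, keys, values),
                      numerical_grad_q(q, keys, values)))
```

The same pattern extends to comparing two implementations: if both pass a gradient check against the same reference within tolerance, their backward passes agree up to floating-point error.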

Thanks for raising this issue. The accuracy difference you observed is more likely due to randomness in training than to anything else.

Due to finite-precision floating-point arithmetic, results with and without FlashAttention will differ slightly, because FlashAttention and standard attention use different algorithms and implementations. The backward pass of FlashAttention is implemented and tested in the original flash-attn codebase, and PyTorch's autograd engine handles the rest, so a bug in the backward pass is unlikely.
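To see why two correct attention implementations can disagree slightly, the sketch below computes the same softmax-weighted sum two ways: directly, and block by block with a running max and rescaling (the "online softmax" idea that FlashAttention's tiling relies on). It is a simplified float64 illustration with hypothetical function names, not the flash-attn kernel; in fp16/bf16 training the discrepancies would be larger.

```python
import math
import random

def direct_weighted_sum(scores, values):
    """Standard softmax-weighted sum over all scores at once."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)

def online_weighted_sum(scores, values, block_size=4):
    """Same quantity, processed block by block with a running max.

    When the max grows, previously accumulated sums are rescaled by
    exp(old_max - new_max), so earlier blocks never need to be revisited.
    """
    running_max = float("-inf")
    running_sum = 0.0  # sum of exp(s - running_max)
    acc = 0.0          # sum of exp(s - running_max) * v
    for start in range(0, len(scores), block_size):
        block_s = scores[start:start + block_size]
        block_v = values[start:start + block_size]
        new_max = max(running_max, max(block_s))
        correction = math.exp(running_max - new_max)  # 0.0 on the first block
        running_sum *= correction
        acc *= correction
        for s, v in zip(block_s, block_v):
            e = math.exp(s - new_max)
            running_sum += e
            acc += e * v
        running_max = new_max
    return acc / running_sum

if __name__ == "__main__":
    rng = random.Random(0)
    scores = [rng.gauss(0.0, 3.0) for _ in range(1000)]
    values = [rng.uniform(-1.0, 1.0) for _ in range(1000)]
    a = direct_weighted_sum(scores, values)
    b = online_weighted_sum(scores, values)
    # Mathematically identical; numerically equal only up to rounding.
    print(a, b, abs(a - b))
```

The two results agree to many digits but are not guaranteed to be bitwise equal, because the additions and exponentials happen in a different order.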

I ran another ablation with and without FlashAttention. As the figures below show, the difference is small and not statistically significant. Each curve aggregates three independent runs (I fixed the validation set to be the same across all runs). The full experimental log can be found here.

[Figures: accuracy with and without FlashAttention, three independent runs each (screenshots dated 2023-06-11)]
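Whether a gap between two sets of runs is statistically significant can be checked with, for example, Welch's t-test on the per-run final accuracies. The helper below is a hypothetical sketch using only the Python standard library; it does not reproduce the analysis behind the figures above, and its inputs should be your own per-run numbers.

```python
import math
from statistics import mean, variance

def welch_t(a, b):
    """Welch's t statistic and Welch-Satterthwaite degrees of freedom.

    a, b: per-run metrics (e.g. final validation accuracy) from two
    configurations, at least two runs each. Uses sample variance (n - 1).
    """
    se2_a = variance(a) / len(a)  # squared standard error of mean(a)
    se2_b = variance(b) / len(b)
    t = (mean(a) - mean(b)) / math.sqrt(se2_a + se2_b)
    df = (se2_a + se2_b) ** 2 / (
        se2_a ** 2 / (len(a) - 1) + se2_b ** 2 / (len(b) - 1)
    )
    return t, df
```

With three runs per configuration, the Welch degrees of freedom fall between 2 and 4, so |t| must be fairly large (about 2.78 at df = 4 for a two-sided 5% test) before a gap counts as significant. The function raises `ZeroDivisionError` if both inputs have zero variance.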