Azure/MS-AMP

Question about FP8 matmul coverage in FP8-LM

Closed this issue · 2 comments

Hello, I appreciate your pioneering work and believe this is a promising direction for future LLMs.

As far as I can tell from reading the code, msamp provides FP8 training by:

  • torch.nn.functional is overridden by msamp.nn._FP8GemmFunction to execute FP8 matmul via the TransformerEngine API.
  • msamp.te.TeReplacer or msamp.nn.LinearReplacer is called to replace the model's submodules with FP8-training-compatible instances such as FP8Linear (a minimal usage sketch follows after this list).
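For context, here is a minimal usage sketch of how that replacement is typically triggered, based on the msamp.initialize entry point shown in the MS-AMP README; the toy model, optimizer, and opt_level choice are illustrative assumptions, not the FP8-LM setup:

```python
import torch
import msamp

# Hypothetical toy model; any torch.nn.Module containing Linear submodules works.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 1024),
).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# msamp.initialize replaces the Linear submodules with FP8-compatible modules
# (e.g. FP8Linear) and wraps the optimizer; higher opt levels keep more of the
# training state (gradients, optimizer states) in low precision.
model, optimizer = msamp.initialize(model, optimizer, opt_level="O2")
```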

I also read through MS-AMP-Example, but I am unclear about the following points of the FP8-LM implementation:

  1. Are the matmuls in the multi-head attention or flash attention modules executed in FP8?
  2. Do the input and positional embedding layers remain FP16 for the matmul input?

Thank you.

Hi @stakahashy, thanks for your interest in our work!
Regarding your questions:

  1. Are the matmuls in the multi-head attention or flash attention modules executed in FP8?

The two linear projections in multi-head attention (QKV = proj1(x) and output = proj2(SV), where S is the attention score) are executed in FP8, while flash attention itself (SV = flash_attn(QKV)) is executed in BF16 or FP16. In addition, the two linear layers in the MLP are also executed in FP8.
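To make those precision boundaries concrete, below is a rough, self-contained sketch of one transformer block under that layout; the module names, shapes, and the use of torch's scaled_dot_product_attention as a stand-in for flash attention are assumptions for illustration, not the actual FP8-LM code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SketchAttentionBlock(nn.Module):
    """Illustrative precision layout only; not the FP8-LM implementation."""

    def __init__(self, hidden: int, heads: int):
        super().__init__()
        self.heads = heads
        # These four linears are the GEMMs that the replacer turns into
        # FP8Linear, so their matmuls run in FP8.
        self.qkv_proj = nn.Linear(hidden, 3 * hidden)   # FP8 after replacement
        self.out_proj = nn.Linear(hidden, hidden)       # FP8 after replacement
        self.mlp_fc1 = nn.Linear(hidden, 4 * hidden)    # FP8 after replacement
        self.mlp_fc2 = nn.Linear(4 * hidden, hidden)    # FP8 after replacement

    def forward(self, x):
        b, t, h = x.shape
        qkv = self.qkv_proj(x)                                    # FP8 GEMM
        q, k, v = qkv.view(b, t, 3, self.heads, -1).unbind(dim=2)
        # The attention matmuls (S = QK^T and SV) stay in BF16/FP16;
        # scaled_dot_product_attention stands in for flash attention here.
        sv = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        ).transpose(1, 2).reshape(b, t, h)
        x = self.out_proj(sv)                                     # FP8 GEMM
        x = self.mlp_fc2(F.gelu(self.mlp_fc1(x)))                 # FP8 GEMMs
        return x
```

In such a block, only the four nn.Linear GEMMs would run in FP8 after replacement; the score and SV matmuls inside the (flash) attention kernel remain in BF16/FP16, matching the description above.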

  2. Do the input and positional embedding layers remain FP16 for the matmul input?

Yes, they are executed in FP16 or BF16.

@wkcn Thank you very much for the precise answers!