Question about FP8 matmul coverage in FP8-LM
Hello, I appreciate your pioneering work and believe this is a promising direction for future LLMs.
As far as I read the code, msamp provides FP8 training as follows: `torch.nn.functional` is overridden by `msamp.nn._FP8GemmFunction` to execute FP8 matmul via the TransformerEngine API, and `msamp.te.TeReplacer` or `msamp.nn.LinearReplacer` is called to replace the model's submodules with FP8-training-compatible instances such as `FP8Linear`.
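For reference, here is a minimal sketch of how I understand this replacement is triggered from user code, based on the README and the examples (the toy model and the exact `opt_level` value are my assumptions):

```python
import torch
import msamp

# Toy model: its nn.Linear submodules are the ones LinearReplacer would
# swap for FP8Linear so that their GEMMs run in FP8.
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 1024),
).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# msamp.initialize wraps the model and optimizer; internally it performs the
# submodule replacement described above. "O2" is one of the documented
# optimization levels (assumption: available levels may differ by version).
model, optimizer = msamp.initialize(model, optimizer, opt_level="O2")
```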
I also read through MS-AMP-Example, but I am still unclear about the following points of the FP8-LM implementation:
- Are the matmuls in the multi-head attention or flash-attention modules executed in FP8?
- Do the input and positional embedding layers remain FP16 for the matmul input?
Thank you.
Hi @stakahashy, thanks for your attention to our work!
Regarding your questions:
- Are the matmuls in the multi-head attention or flash-attention modules executed in FP8?

  The two linear projections (QKV = proj1(x) and output = proj2(SV), where S is the attention score) in the multi-head attention are executed in FP8, but flash attention (SV = flash_attn(QKV)) is executed in BF16 or FP16. Besides, the two linear layers in the MLP are also executed in FP8 (see the sketch after this list).
- Do the input and positional embedding layers remain FP16 for the matmul input?

  Yes, they are executed in FP16 or BF16.
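To make the precision mapping concrete, here is a rough sketch in plain PyTorch (not actual MS-AMP/FP8-LM code; the class name is made up and `scaled_dot_product_attention` only stands in for the flash-attention kernel):

```python
import torch
import torch.nn as nn


class AttentionPrecisionSketch(nn.Module):
    """Illustrative sketch of which matmuls run in FP8 vs. BF16/FP16."""

    def __init__(self, hidden: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        # After LinearReplacer/TeReplacer, these two projections become
        # FP8Linear, so their GEMMs execute in FP8.
        self.qkv_proj = nn.Linear(hidden, 3 * hidden)  # FP8 after replacement
        self.out_proj = nn.Linear(hidden, hidden)      # FP8 after replacement

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        b, s, h = t.shape
        return t.view(b, s, self.num_heads, h // self.num_heads).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, hidden), activations typically kept in BF16/FP16.
        qkv = self.qkv_proj(x)                          # FP8 GEMM
        q, k, v = qkv.chunk(3, dim=-1)
        # The attention itself (flash attention in FP8-LM) stays in BF16/FP16;
        # scaled_dot_product_attention is only a stand-in for that kernel.
        attn = nn.functional.scaled_dot_product_attention(
            self._split_heads(q), self._split_heads(k), self._split_heads(v))
        attn = attn.transpose(1, 2).reshape(x.shape)
        return self.out_proj(attn)                      # FP8 GEMM
```

The same pattern applies to the MLP: its two linear layers are replaced and run in FP8, while the surrounding elementwise ops stay in BF16/FP16.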
@wkcn Thank you very much for the precise answers!