Why does using MS-AMP decrease throughput?
forevergj opened this issue · 4 comments
Hi @forevergj, thanks for your attention to our work!
When using MS-AMP, the quantization (BF16 -> FP8) and de-quantization (FP8 -> BF16) operations add extra overhead. This decreases training throughput for small models (< 1B parameters).
MS-AMP improves the throughput when training large models with more than 1B parameters.
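For intuition, here is a minimal sketch of the quantize/de-quantize round trip that adds this overhead, assuming a simple per-tensor scaling scheme. The function names `quantize_to_fp8` / `dequantize_to_bf16` are illustrative and not MS-AMP's actual API:

```python
import torch

def quantize_to_fp8(t: torch.Tensor):
    # Per-tensor scaling so the BF16 values fit FP8's narrow dynamic range.
    amax = t.abs().max().float().clamp(min=1e-12)
    scale = torch.finfo(torch.float8_e4m3fn).max / amax
    return (t.float() * scale).to(torch.float8_e4m3fn), scale

def dequantize_to_bf16(t_fp8: torch.Tensor, scale: torch.Tensor):
    # FP8 -> BF16: up-cast and undo the scaling.
    return (t_fp8.float() / scale).to(torch.bfloat16)

x = torch.randn(1024, 1024, dtype=torch.bfloat16)
x_fp8, s = quantize_to_fp8(x)           # extra kernel launch + memory traffic
x_back = dequantize_to_bf16(x_fp8, s)   # and again on the way back
```

For small models, each GEMM is too small for the FP8 compute savings to amortize these extra cast kernels, so throughput drops; for larger models, the savings dominate.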
I am running a 1.8B model and the throughput drops by nearly half. Is this caused by a specific operator?
What is the purpose of the de-quantization (FP8 -> BF16) step? And how can MS-AMP be adapted to the H800?
The de-quantization happens during computation over FP8 tensors, including the FP8 weight gradients (`grad = grad_lp.float()`) and the optimizer states (`exp_avg = exp_avg_lp.float(); exp_avg_sq = exp_avg_sq_lp.float()`, see https://github.com/Azure/MS-AMP/blob/main/msamp/optim/adamw_base.py#L274).
MS-AMP is already adapted for both H100 and H800.
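To make that de-quantization step concrete, below is a hedged sketch of an AdamW update in this style. Only the `.float()` up-casts mirror the lines quoted above; the function signature, the re-quantization at the end, and the omission of weight decay and scaling-factor bookkeeping are simplifying assumptions, not MS-AMP's exact code:

```python
import torch

def adamw_step(param, grad_lp, exp_avg_lp, exp_avg_sq_lp,
               lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, step=1):
    # Low-precision (FP8) -> FP32 de-quantization before the update math.
    grad = grad_lp.float()
    exp_avg = exp_avg_lp.float()
    exp_avg_sq = exp_avg_sq_lp.float()

    # Standard AdamW moment updates, done in FP32 for numerical stability.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    bias1 = 1 - beta1 ** step
    bias2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias2).sqrt_().add_(eps)
    param.data.addcdiv_(exp_avg / bias1, denom, value=-lr)

    # Re-quantize the states for low-precision storage (illustrative; MS-AMP
    # also tracks scaling factors, omitted here for brevity).
    return exp_avg.to(exp_avg_lp.dtype), exp_avg_sq.to(exp_avg_sq_lp.dtype)
```

The key point is that the update arithmetic never runs in FP8; FP8 serves as a storage format for the gradients and optimizer states, so the up-cast (and the matching re-quantization) is an unavoidable cost of keeping the state in FP8.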
Thanks a lot!