Azure/MS-AMP

Why does using MS-AMP decrease throughput?

forevergj opened this issue · 4 comments

Why does the throughput of the CIFAR10 example in this project decrease when DeepSpeed is used with MS-AMP enabled? The slowdown appears in both the forward pass and the backward pass.
MS-AMP enabled: [throughput screenshot]
MS-AMP disabled: [throughput screenshot]

wkcn commented

Hi @forevergj , thanks for your attention to our work!

When using MS-AMP, the quantization (BF16 -> FP8) and de-quantization (FP8 -> BF16) operations add extra overhead, which lowers training throughput for small models (< 1B parameters).

MS-AMP improves the throughput when training large models with more than 1B parameters.
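For intuition, the round-trip looks roughly like this (a minimal sketch, not MS-AMP's internals; it assumes PyTorch >= 2.1 for the `torch.float8_e4m3fn` dtype):

```python
import torch

x = torch.randn(4096, 4096, dtype=torch.bfloat16)

# Quantization: BF16 -> FP8. Values are rounded/clamped into FP8's
# narrow range, and the cast itself costs kernel launches and bandwidth.
x_fp8 = x.to(torch.float8_e4m3fn)

# De-quantization: FP8 -> BF16, needed before ops without FP8 kernels.
x_bf16 = x_fp8.to(torch.bfloat16)
```

For small models these casts dominate, while for large models the savings from FP8 GEMMs outweigh them.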

I am running a 1.8B model and throughput drops by nearly half. Is this related to a specific operator?
What is the purpose of the de-quantization (FP8 -> BF16) step? And how is it adapted to the H800?


wkcn commented

De-quantization happens when computing on FP8 tensors, including the FP8 weight gradients (`grad = grad_lp.float()`) and the optimizer states (`exp_avg = exp_avg_lp.float(); exp_avg_sq = exp_avg_sq_lp.float()`); see https://github.com/Azure/MS-AMP/blob/main/msamp/optim/adamw_base.py#L274.
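As a hedged illustration of that pattern (a simplified sketch, not MS-AMP's actual optimizer; the function name and signature here are hypothetical):

```python
import torch

def adamw_step_sketch(param, grad_lp, exp_avg_lp, exp_avg_sq_lp,
                      lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
                      weight_decay=0.01, step=1):
    # De-quantize the low-precision gradient and optimizer states to
    # FP32 before the update math (the FP8 -> high-precision step).
    grad = grad_lp.float()
    exp_avg = exp_avg_lp.float()
    exp_avg_sq = exp_avg_sq_lp.float()

    # Decoupled weight decay, then standard Adam moment updates.
    param.data.mul_(1 - lr * weight_decay)
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    bias1 = 1 - beta1 ** step
    bias2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias2).sqrt_().add_(eps)
    param.data.addcdiv_(exp_avg, denom, value=-lr / bias1)

    # Re-quantize the states back to their low-precision storage dtypes.
    return exp_avg.to(exp_avg_lp.dtype), exp_avg_sq.to(exp_avg_sq_lp.dtype)
```

Storing the states in low precision saves memory; the `.float()` round-trips are part of the extra cost discussed above.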

MS-AMP supports both the H100 and the H800.

Thanks a lot!