Azure/MS-AMP

Is FP8 used only when the code reaches the FP8 branch?

forevergj opened this issue · 8 comments

I am using deepspeed.ops.adam.FusedAdam with DeepSpeed ZeRO stage 2, and the code never reaches the FP8 branch.

wkcn commented

Hi @forevergj , thanks for your attention to our work!

The reason is that the master weights, weights, weight gradients and optimizer states are all tensors with scaling factors. FusedAdam does not support computation on tensors with scaling factors.
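To make "tensor with a scaling factor" concrete, here is a minimal pure-Python sketch. It is not MS-AMP code: the class name `ToyScaledTensor`, the helper `round_to_e4m3`, and the convention `payload = value * scale` are assumptions made for illustration, and the FP8 rounding is only simulated (E4M3, ignoring subnormals and NaN). It shows why a kernel that is unaware of the scale would read the wrong numbers from the raw payload.

```python
import math
from dataclasses import dataclass
from typing import List

FP8_E4M3_MAX = 448.0  # largest finite E4M3 value


def round_to_e4m3(x: float) -> float:
    """Simulated E4M3 rounding: 3 explicit mantissa bits, clamped to +-448."""
    if x == 0.0:
        return 0.0
    x = max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, x))
    mantissa, exponent = math.frexp(x)      # x == mantissa * 2**exponent
    mantissa = round(mantissa * 16) / 16    # 1 implicit + 3 explicit bits
    return math.ldexp(mantissa, exponent)


@dataclass
class ToyScaledTensor:  # hypothetical name, not the MS-AMP class
    payload: List[float]  # values as they would sit in FP8 storage
    scale: float          # assumed convention: payload = value * scale

    @classmethod
    def quantize(cls, values: List[float]) -> "ToyScaledTensor":
        # Pick the scale so the largest magnitude maps to the FP8 maximum.
        amax = max(abs(v) for v in values)
        scale = FP8_E4M3_MAX / amax if amax > 0 else 1.0
        return cls([round_to_e4m3(v * scale) for v in values], scale)

    def dequantize(self) -> List[float]:
        return [p / self.scale for p in self.payload]


grads = [1e-4, -3e-4, 2.5e-4]
g = ToyScaledTensor.quantize(grads)

raw_view = g.payload          # what a scale-unaware kernel would read
true_view = g.dequantize()    # what a scale-aware kernel reconstructs
print(raw_view)
print(true_view)
```

The raw payload differs from the real gradient values by the factor `scale`, so an optimizer step has to carry the scale through its arithmetic to produce correct updates.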

forevergj commented

Thank you for your reply! Where do the training speedups from MS-AMP come from?

wkcn commented

@forevergj
MS-AMP applies low-bit data formats to the master weights, weights, weight gradients and optimizer states. This saves GPU memory, which allows a larger batch size and therefore faster training.

forevergj commented

1. Does this mean that MS-AMP's FP8 operations only occur after backpropagation, when the optimizer updates the weights, and that forward propagation does not involve FP8 operations?
2. Is the main optimization focus the gradient communication time?

wkcn commented

  1. Both forward propagation and backward propagation involve FP8.
  2. No. The speedup comes from FP8 matrix multiplication in the linear layers. MS-AMP also reduces GPU memory usage through low-precision data formats, which allows a larger batch size and therefore further acceleration.
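The memory argument can be sketched as simple arithmetic over the bytes stored per parameter. The baseline below is the common mixed-precision Adam layout; the low-bit byte widths are assumed purely for illustration and are not quoted from this thread, so substitute the formats your configuration actually uses. The function name `bytes_per_param` and the 1B-parameter model are hypothetical.

```python
def bytes_per_param(weight: int, grad: int, master: int, adam_m: int, adam_v: int) -> int:
    """Total bytes stored per model parameter for one training setup."""
    return weight + grad + master + adam_m + adam_v


# Common mixed-precision baseline: FP16 weight and gradient,
# FP32 master weight and FP32 Adam moments.
baseline = bytes_per_param(weight=2, grad=2, master=4, adam_m=4, adam_v=4)

# Illustrative low-bit layout (assumed for this example only).
low_bit = bytes_per_param(weight=1, grad=1, master=2, adam_m=1, adam_v=2)

n_params = 1_000_000_000  # hypothetical 1B-parameter model
for name, b in [("baseline", baseline), ("low-bit", low_bit)]:
    print(f"{name}: {b} bytes/param, {b * n_params / 1e9:.0f} GB of model states")
```

Whatever memory is freed this way can go to activations, which is what permits the larger batch size mentioned above.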

forevergj commented

[image]
Is it similar to Transformer Engine? What is the difference between MS-AMP and Transformer Engine?

[image: figure from the MS-AMP paper]
Does the figure in the MS-AMP paper mean that the FP8 conversion is performed in two places, namely before the all-gather of the gradient and before executing Linear? Does that mean the output of Linear is high precision?

wkcn commented

  1. MS-AMP can be combined with TransformerEngine. The difference is that MS-AMP applies the FP8 data format to the master weights, weights, weight gradients, optimizer states, and the communication of gradient reduction.

  2. Regarding your second question: the output of LayerNorm is a high-precision tensor (FP16 or BF16). It is converted to an FP8 tensor with a scaling factor, which is the input of the all-gather operation. The output of the all-gather operation is also an FP8 tensor with a scaling factor, and it is the input of the FP8 GEMM (YA). The output of the FP8 GEMM (YA) is a high-precision tensor (FP16 or BF16). Z is also a high-precision tensor, which is then converted to an FP8 tensor with a scaling factor.
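The data flow in point 2 can be traced with a small pure-Python simulation. This is a sketch under stated assumptions, not MS-AMP's implementation: `cast_to_fp8`, `fp8_gemm` and `round_to_e4m3` are hypothetical helpers, FP8 is simulated as E4M3 rounding on Python floats, the all-gather is modelled as every rank simply holding all FP8 shards, and each shard keeps its own scaling factor (how real ranks reconcile scaling factors is not covered in this thread).

```python
import math
from typing import List, Tuple

FP8_E4M3_MAX = 448.0
Matrix = List[List[float]]


def round_to_e4m3(x: float) -> float:
    """Simulated E4M3 rounding: 3 explicit mantissa bits, clamped to +-448."""
    if x == 0.0:
        return 0.0
    x = max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, x))
    mantissa, exponent = math.frexp(x)
    return math.ldexp(round(mantissa * 16) / 16, exponent)


def cast_to_fp8(m: Matrix) -> Tuple[Matrix, float]:
    """High-precision matrix -> (simulated FP8 payload, scaling factor)."""
    amax = max(abs(v) for row in m for v in row)
    scale = FP8_E4M3_MAX / amax if amax > 0 else 1.0
    return [[round_to_e4m3(v * scale) for v in row] for row in m], scale


def fp8_gemm(x_q: Matrix, x_scale: float, a_q: Matrix, a_scale: float) -> Matrix:
    """Multiply FP8 payloads, accumulate in high precision, undo both scales."""
    denom = x_scale * a_scale
    return [[sum(x * a for x, a in zip(row, col)) / denom for col in zip(*a_q)]
            for row in x_q]


# Two hypothetical ranks, each holding one row of the LayerNorm output.
layernorm_shards = [[[0.5, -1.0]], [[2.0, 0.25]]]
weight = [[1.0, 0.5], [-0.5, 2.0]]

# Step 1: cast to FP8 before communication, then "all-gather" the FP8 shards.
gathered = [cast_to_fp8(shard) for shard in layernorm_shards]

# Step 2: FP8 GEMM; its output y is high precision again.
a_q, a_scale = cast_to_fp8(weight)
y: Matrix = []
for x_q, x_scale in gathered:
    y.extend(fp8_gemm(x_q, x_scale, a_q, a_scale))

# Step 3: a later high-precision tensor gets the same cast (y is reused here for brevity).
z_q, z_scale = cast_to_fp8(y)
print(y)
```

Only the payloads cross the simulated communication step and enter the GEMM in FP8; the accumulation and the GEMM output stay in high precision, matching the description above.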

forevergj commented

Thank you for your reply. It's very clear.