LTH14/mar

Faster training with fp16 or bf16


Hi authors,
Thank you for sharing the code! I am wondering whether the training process could be accelerated with fp16 or bf16. Are there any potential risks in doing this?

Additionally, could you clarify the meaning of "# preds" as mentioned in Table 1 of your paper?

Looking forward to your reply!

We currently use torch.amp.autocast() to speed up training, and occasionally observe NaN loss because of it. Using fp16/bf16 could potentially cause a similar issue.
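
For context, here is a minimal sketch of an fp16 loop built around torch.amp.autocast() with gradient scaling. The model, optimizer, and data below are toy placeholders rather than the repo's actual training code; GradScaler skips the optimizer step when gradients overflow, which mitigates but does not eliminate NaN losses:

```python
import torch
import torch.nn as nn

# Toy setup so the sketch runs end-to-end; the real MAR training loop differs.
device = "cuda"
model = nn.Linear(256, 10).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, eps=1e-8)
scaler = torch.cuda.amp.GradScaler()  # loss scaling is only needed for fp16

for step in range(10):
    x = torch.randn(32, 256, device=device)
    y = torch.randint(0, 10, (32,), device=device)
    optimizer.zero_grad(set_to_none=True)
    # Forward pass in reduced precision; switch dtype to torch.bfloat16 on supported GPUs.
    with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
        loss = nn.functional.cross_entropy(model(x), y)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # unscale gradients before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)      # the step is skipped if inf/NaN gradients are found
    scaler.update()
```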

"# preds" refers to the number of tokens predicted in one auto-regressive iteration. Similar to MaskGIT, our framework can predict either one or multiple tokens in each auto-regressive iteration.
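
For illustration only, here is a hypothetical MaskGIT-style cosine schedule showing how the number of tokens predicted per iteration might be chosen; the function name and constants below are illustrative and are not taken from the paper or this repo:

```python
import math

def tokens_per_iteration(total_tokens: int, num_iters: int) -> list[int]:
    """Illustrative MaskGIT-style cosine schedule: how many of `total_tokens`
    image tokens are predicted at each of `num_iters` auto-regressive iterations."""
    counts, revealed = [], 0
    for step in range(1, num_iters + 1):
        # Fraction of tokens that should still be masked after this step.
        mask_ratio = math.cos(math.pi / 2 * step / num_iters)
        target_revealed = total_tokens - int(total_tokens * mask_ratio)
        counts.append(max(1, target_revealed - revealed))
        revealed = target_revealed
    return counts

# e.g. 256 tokens revealed over 8 iterations: few at first, many at the end.
print(tokens_per_iteration(256, 8))
```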

Thank you for your explanation!

Hi, thank you for your excellent work. I've also encountered NaN loss values when using fp16. I've attempted various solutions, such as a smaller gradient clipping norm, reducing the learning rate, and increasing the epsilon value of the optimizer, but none of these resolved the issue. Could you please offer some advice on how to tackle it?

LTH14 commented

@darkliang Thanks for your interest. You could also consider reducing the batch size -- large-batch training is one of the major causes of NaN loss. Also, use bf16 if you can -- although it can also produce NaN, it is less likely to than fp16. If none of the above works, fp32 might be the only solution.
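
As a minimal sketch of the bf16 variant of this advice (again with toy placeholders rather than the repo's code): because bf16 keeps fp32's exponent range, no GradScaler is needed, but it requires hardware support (Ampere or newer):

```python
import torch
import torch.nn as nn

device = "cuda"
model = nn.Linear(256, 10).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# bf16 needs hardware support (e.g. A100/H100); V100 does not support it.
assert torch.cuda.is_bf16_supported()

for step in range(10):
    x = torch.randn(32, 256, device=device)
    y = torch.randint(0, 10, (32,), device=device)
    optimizer.zero_grad(set_to_none=True)
    # bf16 autocast: same exponent range as fp32, so no loss scaling is required.
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
```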

Thank you, I'll give reducing the batch size a try.
I'm unable to use bf16 since I'm training on V100 GPUs.