Loss NaN when running post-training
Thanks for raising this issue! I think the problem is caused by DDP (DistributedDataParallel).
I tried to run the experiments with your environment (Python 3.7, PyTorch 1.11.0) and reproduced the same problem. I am not sure whether the bug is related to the log message "INFO:root: reducer buckets have been rebuilt in this iteration." This blog post says the problem occurs with PyTorch versions >= 1.7.0. After downgrading PyTorch to 1.5.1, I no longer see the NaN loss.
I further tested PyTorch 1.7 and saw the same problem, while PyTorch 1.6 works fine. Previously we had also tried PyTorch 1.7 on a single GPU, which works as well. So here are three suggestions for avoiding this issue; you can choose any one of them:
(1) (Use a lower PyTorch version) Downgrade PyTorch to 1.5 or 1.6. This way, you don't need to modify the training script.
(2) (Gradient accumulation) Since the problem is caused by DDP, you can avoid it by training on only one GPU and using simple gradient accumulation to keep the larger batch size. Specifically, set gradient_accumulation_steps to 2 in the training script.
(3) (Use a larger GPU) If you have a GPU with more than ~20GB of memory, you can run the post-training on 1 GPU and simply double per_device_train_batch_size in the training script.
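Suggestions (2) and (3) both rely on keeping the effective batch size unchanged when dropping to a single GPU. Assuming the script follows the usual convention (effective batch = per-device batch × accumulation steps × number of devices) and divides each micro-batch loss by the number of accumulation steps, the plain-Python sketch below checks that arithmetic with a toy 1-D linear model and MSE loss. The model, data, and helper names are illustrative only, not taken from the repository.

```python
# Plain-Python sketch (no PyTorch): toy model y_hat = w * x with MSE loss.
# Names and data are illustrative only, not taken from the repository.


def grad_mse(w, batch):
    """Gradient of mean((w * x - y) ** 2) with respect to w over one batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w, batch, accumulation_steps):
    """Gradient obtained by splitting `batch` into equal micro-batches.

    Each micro-batch gradient is divided by `accumulation_steps` and summed,
    mirroring a loop that calls backward() on `loss / accumulation_steps`
    several times before a single optimizer step.
    """
    if len(batch) % accumulation_steps != 0:
        raise ValueError("batch must split evenly into micro-batches")
    size = len(batch) // accumulation_steps
    total = 0.0
    for i in range(accumulation_steps):
        micro_batch = batch[i * size:(i + 1) * size]
        total += grad_mse(w, micro_batch) / accumulation_steps
    return total


def effective_batch_size(per_device_batch, accumulation_steps, num_devices):
    """Number of samples contributing to one optimizer step."""
    return per_device_batch * accumulation_steps * num_devices


data = [(1.0, 2.0), (2.0, 3.5), (3.0, 7.0), (4.0, 7.5)]
print(grad_mse(1.5, data))             # full batch of 4 -> -7.5
print(accumulated_grad(1.5, data, 2))  # 2 micro-batches of 2 -> -7.5
```

With a per-device batch of, say, 8 on (presumably) 2 GPUs, option (2) gives `effective_batch_size(8, 2, 1)` and option (3) gives `effective_batch_size(16, 1, 1)`, both equal to the original `effective_batch_size(8, 1, 2) == 16`. The gradients match exactly only when micro-batches are equal-sized and the model has no batch-dependent layers such as BatchNorm.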
Thanks again for raising this issue. We will update the README of this repo to address this bug soon :)
Thanks for the thorough advice! It's helpful for us; we will try again.
Sorry to bother you again, but I have a technical question, maybe just a trick.
When reproducing the CPT code, it is a little troublesome that the bash script passes so many arguments. This makes it hard to enter debugging mode in PyCharm, since all the arguments must be copied into PyCharm's run configuration dialog, as shown in the figure below.
Even minor setting changes then have to be made in that dialog, which doesn't feel very elegant. So I'd like to know how you debug Python programs with long argument lists. Do you use any plugins?
You can run the code with the bash scripts instead of PyCharm's debugging mode. For debugging, I usually use pdb to set breakpoints in the Python code.
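As a minimal sketch of the pdb approach applied to this thread's NaN problem, the snippet below enters the debugger only when the loss stops being finite, so the script can still be launched from the usual bash script with all its arguments. `guard_loss`, `on_bad_loss`, and the hard-coded loss values are hypothetical, not part of the CPT code.

```python
import math


def guard_loss(loss, step, on_bad_loss=breakpoint):
    """Return True if `loss` is finite; otherwise report it and call `on_bad_loss`.

    `loss` is assumed to be a Python float; for a PyTorch tensor, pass
    `loss.item()`. The default hook is the built-in breakpoint() (Python 3.7+),
    which starts pdb inside this function, where `loss` and `step` can be
    inspected.
    """
    if math.isfinite(loss):
        return True
    print(f"non-finite loss {loss} at step {step}")
    on_bad_loss()
    return False


# Hypothetical loop: replace the hard-coded losses with the real training step.
for step, loss in enumerate([0.9, 0.7, 0.6]):
    if not guard_loss(loss, step):
        break
```

Because breakpoint() consults the `PYTHONBREAKPOINT` environment variable, running the bash script with `PYTHONBREAKPOINT=0` disables such breakpoints without editing the code. Alternatively, `python -m pdb` followed by the script name and its usual arguments runs the whole program under pdb from the shell.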
OK, thanks!