TRI-ML/KP2D

Code cannot work

gf9276 opened this issue · 6 comments

Thank you for sharing the code, but when I configure the environment and run it, the loss becomes NaN after a short time. I suspect the problem may be near the triplet loss function. Do you know the reason?

Hi, sorry to hear about your issues - are you using our docker?

Thanks for your reply. This was my mistake: I didn't use your docker. At first I ran the program on PyTorch 1.8.0 without any warnings, but the problem above occurred. Later I switched to PyTorch 1.6.0 and the problem disappeared. I don't know why.

This may not be a problem with the PyTorch version. The error above occurs when I run the program on a 3080 (PyTorch 1.8.0, cudnn111), but everything is normal when I run it on a 2080 Ti (PyTorch 1.8.0, cudnn111). It feels a little strange.

Ahh, that is strange. We only tested this using the environment from the docker, running on a Titan Xp or a V100. Do you get the issue even when running in docker but on the 3080?

Sorry, not yet.

Hi, so I had the same problem, and for me the solution was to add a small value to the calculations in the loss functions. For example, in the io loss I added:

    dmat = torch.sqrt(2 - 2 * torch.clamp(dmat, min=-1, max=1) + 1e-8)

Also in the function build_descriptor_loss:

    ref_desc = ref_desc.div(torch.norm(ref_desc+1e-8, p=2, dim=0))
    tar_desc = tar_desc.div(torch.norm(tar_desc+1e-8, p=2, dim=0))
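
For context, here is a minimal, self-contained sketch (not the repository's code; `sim` is just a stand-in for the cosine-similarity values that feed the distance matrix) of why the epsilon matters: the derivative of sqrt(x) is 0.5/sqrt(x), so a perfectly matched pair (similarity exactly 1) gives a distance of 0 and an infinite gradient, which then poisons the whole loss.

    import torch

    # Toy example, not KP2D code: sim stands in for cosine similarities.
    sim = torch.tensor([1.0, 0.5], requires_grad=True)

    dist = torch.sqrt(2 - 2 * sim)              # no epsilon
    dist.sum().backward()
    print(sim.grad)                             # tensor([-inf, -1.]) -> NaN downstream

    sim.grad = None
    dist_safe = torch.sqrt(2 - 2 * sim + 1e-8)  # with the epsilon suggested above
    dist_safe.sum().backward()
    print(sim.grad)                             # finite gradients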

Hope this helps. If there is still a problem, you can use torch.autograd.set_detect_anomaly(True) to find the calculation where the NaNs are produced.
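
For reference, here is a small self-contained toy (not KP2D code) of what that flag does: with anomaly detection enabled, the backward pass raises a RuntimeError naming the function that produced the NaN, instead of silently writing NaN into the gradients.

    import torch

    torch.autograd.set_detect_anomaly(True)

    # Toy graph whose backward pass produces a NaN: sqrt(0) is fine in the
    # forward pass, but its gradient is inf, and the chain rule multiplies
    # that inf by 0, giving NaN.
    x = torch.zeros(1, requires_grad=True)
    loss = (torch.sqrt(x) * 0.0).sum()

    try:
        loss.backward()        # raises instead of silently leaving NaN in x.grad
    except RuntimeError as err:
        print(err)             # the message points at the NaN-producing backward op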