Getting "nans" during training on UCF101 dataset
akileshbadrinaaraayanan opened this issue · 4 comments
Hi,
Thanks for providing such well-written and nicely commented code. We are trying to run your code on the UCF101 and Sports1m datasets, since the authors of the original paper report results on them.
However, when training with LRATE_G = 0.00004 and LRATE_D = 0.02, we get NaNs for the Global Loss, PSNR Error, and Sharpdiff Error in g_model.py.
Is there a principled way to adjust the generator and discriminator learning rates to prevent this?
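For concreteness, this is the kind of guard we've been adding around the training loop to catch the first non-finite value (a minimal sketch with hypothetical names, not code from the repo):

```python
import numpy as np

def check_finite(step, **metrics):
    """Fail fast when any fetched metric goes NaN/Inf, naming the culprit.

    Hypothetical helper -- call it right after sess.run() on the loss tensors.
    """
    for name, value in metrics.items():
        if not np.all(np.isfinite(value)):
            raise FloatingPointError(
                'step %d: %s is non-finite (%r)' % (step, name, value))

# e.g. check_finite(step, global_loss=g_loss_val, psnr=psnr_val, sharpdiff=sharpdiff_val)
```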
Hi Akilesh,
You may want to try the paper's original learning rates: "The learning rate ρG starts at 0.04 and is reduced over time to 0.005... The network is trained by setting the learning rate ρD to 0.02." That said, I usually find NaNs are a result of the learning rate being too high. I can't imagine you'd want to go any lower than 0.00004, but that may be worth a shot too.
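For what it's worth, the paper doesn't spell out the decay schedule, so a simple polynomial decay from 0.04 down to 0.005 is one way to approximate it. The decay horizon and the SGD optimizer below are assumptions, not the repo's actual setup:

```python
import tensorflow as tf

global_step = tf.Variable(0, trainable=False, name='global_step')

# Decay the generator LR from the paper's 0.04 down to 0.005 over a
# hypothetical 500k steps; tune decay_steps to your own step budget.
lrate_g = tf.train.polynomial_decay(
    learning_rate=0.04,
    global_step=global_step,
    decay_steps=500000,
    end_learning_rate=0.005)

# Shown with plain SGD; swap in whatever optimizer g_model.py actually uses.
# Passing global_step to minimize() is what advances the schedule.
g_optimizer = tf.train.GradientDescentOptimizer(lrate_g)
# g_train_op = g_optimizer.minimize(g_loss, global_step=global_step)
```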
Can I ask how you went about collecting the Sports1m data off of YouTube? I landed on the Ms. Pac-Man data because it seemed unreasonable to try downloading 1M YouTube videos through a script.
Also, make sure you try the gradient-bug branch. The code on the master branch may be killing the gradient from the adversarial network, so I've been trying to fix that on the other branch.
Hi Matt,
Thanks for your suggestion. I started with the gradient-bug branch and noticed a small error in the resize function: it should be tf.image.resize_images(inputs, [scale_height, scale_width]), not tf.image.resize_images(inputs, scale_height, scale_width). With the suggested learning rates I am able to train for roughly 3000 steps, but after that I get NaNs again. Do you have any other practical hacks for getting this GAN to converge?
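For reference, the corrected call looks like this (the placeholder input shape and target sizes are just for illustration):

```python
import tensorflow as tf

inputs = tf.placeholder(tf.float32, [None, 32, 32, 3])  # example frame batch
scale_height, scale_width = 16, 16

# TF >= 1.0 expects the target size as a single 2-element list/tensor,
# not as separate height and width positional arguments:
resized = tf.image.resize_images(inputs, [scale_height, scale_width])
```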
I am currently working only with the UCF101 dataset; I also felt it was unreasonable to download 1M YouTube videos through a script.
Thanks
Akilesh
Did you try reducing the discriminator learning rate to 0.002 or 0.0002? Seems like it's definitely an issue with the discriminator diverging, so that could help.
You may want to check out OpenAI's paper "Improved Techniques for Training GANs". This Reddit thread may have some helpful tips as well.
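One concrete trick from that paper that's easy to try is one-sided label smoothing, i.e. training the discriminator against soft "real" targets instead of 1.0. The tensor names below are hypothetical; wire it into d_model.py's loss however the repo actually structures it:

```python
import tensorflow as tf

# Hypothetical discriminator logits; replace with the real tensors from d_model.py.
d_logits_real = tf.placeholder(tf.float32, [None, 1])
d_logits_fake = tf.placeholder(tf.float32, [None, 1])

# One-sided label smoothing: real targets at 0.9 instead of 1.0, fake targets
# stay at 0.0 (Salimans et al., "Improved Techniques for Training GANs").
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_real, labels=tf.fill(tf.shape(d_logits_real), 0.9)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)))
d_loss = d_loss_real + d_loss_fake
```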