Small damping hyper-parameter leads to divergence of the training loss
Hello @emtiyaz @kazukiosawa! Thank you for this great work and your research!
I am trying to follow your work and apply it to a classification problem, hoping that uncertainty estimation will help me reduce out-of-distribution (OOD) misclassifications.
I first tried VOGN, but after your suggestion in the related issue (#6) I switched to OGN and started tuning its hyper-parameters.
However, I noticed that setting the damping parameter low (1e-5 or lower) makes training diverge: the training loss grows instead of decreasing.
I observed similar behavior when training on ImageNet (which I used as a test bed to verify my assumptions); raising the damping to 1e-3 or even 1e-1 stabilizes training and helps it converge.
From your research I understood that it should be the other way around, i.e. that a small damping value helps to control close-to-zero eigenvalues.
I think I am missing something. Could you please advise?
My parameters for OGN that keep the training loss slowly decreasing:
"optim_name": "DistributedSecondOrderOptimizer",
"optim_args": {
    "curv_type": "Cov",
    "curv_shapes": {
        "Conv2d": "Diag",
        "Linear": "Diag",
        "BatchNorm1d": "Diag",
        "BatchNorm2d": "Diag"
    },
    "lr": 1.6e-3,
    "momentum": 0.9,
    "momentum_type": "raw",
    "non_reg_for_bn": true
},
"curv_args": {
    "damping": 1e-1,
    "ema_decay": 0.999
},
@uryuk Thanks for your question! (and sorry for the late response)
Yes, increasing the damping parameter stabilizes the training.
To stabilize the training, you can also increase l2_reg (the coefficient of L2 regularization). Both the damping and the (exponential moving average of) l2_reg are added to the diagonal elements of the curvature before it is inverted in SecondOrderOptimizer.
See these parts of the code:
https://github.com/cybertronai/pytorch-sso/blob/master/torchsso/curv/curvature.py#L103-L104
https://github.com/cybertronai/pytorch-sso/blob/master/torchsso/optim/secondorder.py#L246
https://github.com/cybertronai/pytorch-sso/blob/ab71354440600d14cfa276b3decbc8ec54122ce8/torchsso/curv/curvature.py#L230-L232
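A minimal sketch of why a tiny damping can blow up the update, assuming a plain diagonal preconditioner with no momentum and no moving averages. The helper damped_diag_step is hypothetical and not part of pytorch-sso. With a diagonal curvature each diagonal entry is an eigenvalue, so an entry near zero produces a step of roughly lr * grad / damping: the damping acts as a floor on the denominator, and it only "controls" near-zero eigenvalues if it is large enough. Adding l2_reg raises the same floor.

```python
def damped_diag_step(grad, curv_diag, damping, l2_reg=0.0, lr=1.0):
    """Hypothetical helper (not pytorch-sso API).

    Preconditions each gradient entry by the damped diagonal curvature:
    step_i = -lr * grad_i / (curv_i + damping + l2_reg).
    """
    return [-lr * g / (c + damping + l2_reg) for g, c in zip(grad, curv_diag)]


grad = [0.5, 0.5]
curv_diag = [1.0, 1e-8]  # the second entry mimics a near-zero eigenvalue

for damping in (1e-5, 1e-3, 1e-1):
    step = damped_diag_step(grad, curv_diag, damping, lr=1.6e-3)
    print(f"damping={damping:g}  step={step}")

# Raising l2_reg has the same stabilizing effect as raising the damping.
print(damped_diag_step(grad, curv_diag, damping=1e-5, l2_reg=1e-1, lr=1.6e-3))
```

With damping=1e-5 the step along the flat direction is about five orders of magnitude larger than along the well-curved one, which is one plausible explanation for the divergence reported above.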
Also, see here to understand the relationship between l2_reg of SecondOrderOptimizer and the parameters of VIOptimizer if you want to try VOGN instead of OGN:
https://github.com/cybertronai/pytorch-sso/blob/master/torchsso/optim/vi.py#L77
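As a rough illustration of why such a relationship exists, here is the generic correspondence between a Gaussian prior and L2 regularization. This is not a transcription of vi.py#L77. It assumes a per-example objective mean_loss + (kl_weighting / dataset_size) * KL(q || p) with an isotropic prior p = N(0, prior_variance * I), and an L2 term whose gradient is l2_reg * w. The prior then contributes a gradient of w * kl_weighting / (dataset_size * prior_variance). The function name equivalent_l2_reg is hypothetical.

```python
def equivalent_l2_reg(prior_variance, dataset_size, kl_weighting=1.0):
    """Hypothetical helper: L2 coefficient implied by a Gaussian prior.

    Assumes the objective  mean_loss + (kl_weighting / dataset_size) * KL(q || p)
    with p = N(0, prior_variance * I); the KL term's gradient w.r.t. the mean w
    is then  w * kl_weighting / (dataset_size * prior_variance).
    """
    if prior_variance <= 0:
        raise ValueError("prior_variance must be positive")
    if dataset_size <= 0:
        raise ValueError("dataset_size must be positive")
    return kl_weighting / (dataset_size * prior_variance)


# Example with a hypothetical dataset of 50,000 examples and unit prior variance.
print(equivalent_l2_reg(prior_variance=1.0, dataset_size=50_000))
```

For that example the function returns 2e-05. Under these assumptions, a larger prior variance or a larger dataset gives a smaller effective l2_reg, and therefore less stabilization of the curvature diagonal.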
I hope this helps you.
@kazukiosawa Thank you very much for the answer. I will experiment more with l2_reg; the relationship is now clear.