Stable-Weight-Decay-Regularization

The PyTorch implementation of Stable Weight Decay (SWD). The algorithms are proposed in the paper "Stable Weight Decay Regularization".

Why Stable Weight Decay?

We propose the Stable Weight Decay (SWD) method to fix weight decay in modern deep learning libraries.

  • SWD usually yields significant improvements over both L2 regularization and decoupled weight decay (the conceptual sketch after this list illustrates the difference between these two baselines).

  • Simply fixing weight decay in Adam via SWD, with no extra hyperparameter, usually outperforms more complex Adam variants that introduce additional hyperparameters.

  • SGD with Stable Weight Decay (SGDS) also often outperforms SGD with L2 regularization.
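
For context, the sketch below contrasts the two baselines mentioned above for an Adam-style update: L2 regularization folds the decay term into the gradient before the adaptive rescaling, whereas decoupled weight decay (as in AdamW) subtracts the decay from the weights directly. SWD modifies how the weight decay is applied; see the paper for its exact rule. All tensors and hyperparameters below are illustrative and are not taken from this repository.

import torch

# Illustrative parameters, gradient, and hyperparameters (not from this repository).
theta = torch.randn(10)                 # model parameters
grad = torch.randn(10)                  # gradient of the loss w.r.t. theta
lr, beta1, beta2, eps, lam = 1e-3, 0.9, 0.999, 1e-8, 5e-4
step = 1                                # current step, for bias correction

def adam_direction(g):
    # One bias-corrected Adam update direction for gradient g (first step,
    # so the moment buffers start at zero).
    m = (1 - beta1) * g
    v = (1 - beta2) * g ** 2
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    return m_hat / (v_hat.sqrt() + eps)

# (a) L2 regularization: lam * theta is added to the gradient, so the decay
#     is rescaled by Adam's adaptive denominator.
theta_l2 = theta - lr * adam_direction(grad + lam * theta)

# (b) Decoupled weight decay (AdamW-style): the adaptive step sees only the
#     raw gradient, and lr * lam * theta is subtracted separately.
theta_decoupled = theta - lr * adam_direction(grad) - lr * lam * theta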

The required environment is as follows:

Python 3.7.3

PyTorch >= 1.4.0

Usage

You may use it as a standard PyTorch optimizer.

import swd_optim

optimizer = swd_optim.AdamS(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)
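
As a quick illustration, the sketch below runs one training step with AdamS. The model, loss, and data (net, criterion, inputs, targets) are placeholder names for this example, not part of this repository.

import torch
import torch.nn as nn
import swd_optim

# Placeholder model and data; any PyTorch model and loss work the same way.
net = nn.Linear(32, 10)
criterion = nn.CrossEntropyLoss()
inputs, targets = torch.randn(8, 32), torch.randint(0, 10, (8,))

optimizer = swd_optim.AdamS(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)

# One step of the usual PyTorch training loop.
optimizer.zero_grad()                   # clear gradients from the previous step
loss = criterion(net(inputs), targets)  # forward pass and loss
loss.backward()                         # backward pass
optimizer.step()                        # AdamS parameter update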

Test performance

Test errors (%) on CIFAR-10 and CIFAR-100, reported in the format mean±std.

| Dataset   | Model       | AdamS      | SGD M      | Adam       | AMSGrad    | AdamW      | AdaBound   | Padam      | Yogi       | RAdam      |
|-----------|-------------|------------|------------|------------|------------|------------|------------|------------|------------|------------|
| CIFAR-10  | ResNet18    | 4.91±0.04  | 5.01±0.03  | 6.53±0.03  | 6.16±0.18  | 5.08±0.07  | 5.65±0.08  | 5.12±0.04  | 5.87±0.12  | 6.01±0.10  |
| CIFAR-10  | VGG16       | 6.09±0.11  | 6.42±0.02  | 7.31±0.25  | 7.14±0.14  | 6.48±0.13  | 6.76±0.12  | 6.15±0.06  | 6.90±0.22  | 6.56±0.04  |
| CIFAR-100 | DenseNet121 | 20.52±0.26 | 19.81±0.33 | 25.11±0.15 | 24.43±0.09 | 21.55±0.14 | 22.69±0.15 | 21.10±0.23 | 22.15±0.36 | 22.27±0.22 |
| CIFAR-100 | GoogLeNet   | 21.05±0.18 | 21.21±0.29 | 26.12±0.33 | 25.53±0.17 | 21.29±0.17 | 23.18±0.31 | 21.82±0.17 | 24.24±0.16 | 22.23±0.15 |

Citing

If you use Stable Weight Decay in your work, please cite "Stable Weight Decay Regularization".

@article{xie2020stable,
  title={Stable Weight Decay Regularization},
  author={Xie, Zeke and Sato, Issei and Sugiyama, Masashi},
  journal={arXiv preprint arXiv:2011.11152},
  year={2020}
}