The PyTorch implementation of Stable Weight Decay. The algorithms are proposed in the paper "Stable Weight Decay Regularization".
We propose the Stable Weight Decay (SWD) method to fix weight decay in modern deep learning libraries.

- SWD usually yields significant improvements over both L2 regularization and decoupled weight decay.
- Simply fixing weight decay in Adam via SWD, with no extra hyperparameter, usually outperforms complex Adam variants that have more hyperparameters.
- SGD with Stable Weight Decay (SGDS) also often outperforms SGD with L2 regularization.
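To make the idea concrete, here is a minimal numpy sketch of a single AdamS-style update. The Adam moment updates are standard; the only change is that the weight-decay term is rescaled by the global mean of the second-moment scale, which is the gist of stable weight decay. The function name and the exact normalization used here are illustrative assumptions; see the paper for the precise update rule.

```python
import numpy as np

def adams_like_step(theta, grad, m, v, t, lr=1e-3, betas=(0.9, 0.999),
                    eps=1e-8, weight_decay=5e-4):
    """One illustrative AdamS-style step (sketch, not the reference code).

    Standard Adam moment estimates, but the weight-decay term is divided
    by the global mean of sqrt(v_hat), so that the effective decay
    strength stays stable as the adaptive step size changes.
    """
    beta1, beta2 = betas
    # Standard Adam first/second moment updates with bias correction.
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # Stable weight decay: normalize the decay by the mean adaptive scale
    # (an assumed form of the paper's normalization).
    v_bar = np.sqrt(v_hat).mean()
    theta = (theta
             - lr * m_hat / (np.sqrt(v_hat) + eps)
             - lr * weight_decay * theta / v_bar)
    return theta, m, v
```

With plain decoupled weight decay (AdamW), the decay term would be `lr * weight_decay * theta` with no `v_bar` normalization; the division by the mean scale is what distinguishes the stable variant in this sketch.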
- Python 3.7.3
- PyTorch >= 1.4.0
You may use it as a standard PyTorch optimizer.

```python
import swd_optim

optimizer = swd_optim.AdamS(net.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=5e-4, amsgrad=False)
```
Test error (%, mean±std):

Dataset | Model | AdamS | SGD M | Adam | AMSGrad | AdamW | AdaBound | Padam | Yogi | RAdam
---|---|---|---|---|---|---|---|---|---|---
CIFAR-10 | ResNet18 | 4.91±0.04 | 5.01±0.03 | 6.53±0.03 | 6.16±0.18 | 5.08±0.07 | 5.65±0.08 | 5.12±0.04 | 5.87±0.12 | 6.01±0.10
CIFAR-10 | VGG16 | 6.09±0.11 | 6.42±0.02 | 7.31±0.25 | 7.14±0.14 | 6.48±0.13 | 6.76±0.12 | 6.15±0.06 | 6.90±0.22 | 6.56±0.04
CIFAR-100 | DenseNet121 | 20.52±0.26 | 19.81±0.33 | 25.11±0.15 | 24.43±0.09 | 21.55±0.14 | 22.69±0.15 | 21.10±0.23 | 22.15±0.36 | 22.27±0.22
CIFAR-100 | GoogLeNet | 21.05±0.18 | 21.21±0.29 | 26.12±0.33 | 25.53±0.17 | 21.29±0.17 | 23.18±0.31 | 21.82±0.17 | 24.24±0.16 | 22.23±0.15
If you use Stable Weight Decay in your work, please cite "Stable Weight Decay Regularization".
```
@article{xie2020stable,
  title={Stable Weight Decay Regularization},
  author={Xie, Zeke and Sato, Issei and Sugiyama, Masashi},
  journal={arXiv preprint arXiv:2011.11152},
  year={2020}
}
```