- "Penalizing Gradient Norm for Efficiently Improving Generalization in Deep Learning"[ICML2022], by Yang Zhao, Hao Zhang and Xiuyuan Hu.
- "When Will Gradient Regularization Be Harmful?"[ICML2024], by Yang Zhao, Hao Zhang and Xiuyuan Hu.
- JAX Framework Update: Upgraded the training framework to the latest version (JAX 0.4.28).
- New Paper Implementation: Integrated the implementation of our latest research paper into this repository.
- Additional Model Architectures: Included Swin and CaiT Transformer architectures in the model list.
- Environment Setup: This repository is built using the JAX framework. Begin by setting up the Python environment specified in the `requirements.txt` file.
- Configuration: The `config` folder contains all the configuration flags and their default values. You can add custom flags if needed by modifying these files.
- Model Architectures: The `model` folder includes various model architectures such as VGG, ResNet, WideResNet, PyramidNet, ViT, Swin and CaiT. To add custom models, follow the Flax model template and register your model using the `_register_model` function in this folder (a sketch is shown after this list).
- Dataset Pipeline: The `ds_pipeline` folder provides the dataset pipeline, based primarily on the SAM repository. Unlike SAM, this repo uses local ImageNet data instead of `tensorflow_datasets`. Specify the path to your local dataset folders, ensuring the folder structure is:

  ```
  ImageNet folder
  └───n01440764
  │   │   *.JPEG
  │
  └───n01443537
  │   │   *.JPEG
  ...
  ```
- Optimizers: The `optimizer` folder contains the optimizers, including SGD (Momentum), AdamW and RMSProp. You can add custom optimizers by modifying these files.
- Training Recipes: The `recipe` folder contains `.sh` files, each corresponding to a specific model's training script. To run a training script, use the following command:

  ```bash
  bash wideresnet-cifar.sh
  ```

  Alternatively, to deploy configurations directly (ensuring the config flag is in the config file), use:

  ```bash
  python3 -m gnp.main.main --config=the-train-config-py-file --working_dir=your-output-dir --config.config-anything-else-here
  ```
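For reference, below is a minimal sketch of what a custom model following the Flax module template might look like. The model name, the `MyCNN` module, and the registration call style are assumptions for illustration only; check the actual signature of `_register_model` in the `model` folder before registering your model.

```python
import flax.linen as nn
import jax.numpy as jnp


class MyCNN(nn.Module):
    """A tiny illustrative model following the Flax module template."""
    num_classes: int = 10

    @nn.compact
    def __call__(self, x, train: bool = True):
        x = nn.Conv(features=32, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = jnp.mean(x, axis=(1, 2))  # global average pooling
        return nn.Dense(features=self.num_classes)(x)


# Hypothetical registration -- the call style and name string below are assumed;
# consult `_register_model` in the `model` folder for the actual signature:
# _register_model("my_cnn", MyCNN)
```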
Basically, gradient regularization (GR) can be understood as a gradient norm penalty, where an additional term on the gradient norm of the loss is added to the training objective:

$$\mathcal{L}_{\mathrm{GR}}(\theta) = \mathcal{L}(\theta) + \lambda \, \lVert \nabla_{\theta}\mathcal{L}(\theta) \rVert_2 ,$$

where $\lambda$ is the penalty coefficient.
The gradient norm is considered a key property characterizing the flatness of the loss surface. By penalizing the gradient norm, the optimization is encouraged to converge to flatter minima on the loss surface, which improves model generalization.
Based on the chain rule, the gradient of the gradient norm term is given by:

$$\nabla_{\theta} \lVert \nabla_{\theta}\mathcal{L}(\theta) \rVert_2 = \nabla_{\theta}^{2}\mathcal{L}(\theta)\, \frac{\nabla_{\theta}\mathcal{L}(\theta)}{\lVert \nabla_{\theta}\mathcal{L}(\theta) \rVert_2} .$$
Computing the gradient of this gradient norm term directly involves the full Hessian matrix. To address this, we use a Taylor expansion to approximate the multiplication between the Hessian matrix and a vector, $\nabla_{\theta}^{2}\mathcal{L}(\theta)\,v \approx \frac{\nabla_{\theta}\mathcal{L}(\theta + r v) - \nabla_{\theta}\mathcal{L}(\theta)}{r}$, resulting in:

$$\nabla_{\theta}\mathcal{L}_{\mathrm{GR}}(\theta) \approx (1 - \alpha)\, \nabla_{\theta}\mathcal{L}(\theta) + \alpha \, \nabla_{\theta}\mathcal{L}\!\left(\theta + r\, \frac{\nabla_{\theta}\mathcal{L}(\theta)}{\lVert \nabla_{\theta}\mathcal{L}(\theta) \rVert_2}\right),$$

where $r$ is a small perturbation radius and $\alpha = \lambda / r$ is the balancing coefficient. Notably, the SAM algorithm is a special implementation of this scheme where $\alpha = 1$ (i.e., $\lambda = r$).
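As a concrete illustration, below is a minimal JAX sketch of this approximate GR gradient, computed as a convex combination of the plain gradient and the gradient at a perturbed point. The function name `gr_grad` and the arguments `loss_fn`, `r`, and `alpha` are illustrative and do not correspond to the repository's actual API.

```python
import jax
import jax.numpy as jnp


def gr_grad(loss_fn, params, r=0.01, alpha=0.5):
    """Approximate the gradient of L(theta) + lambda * ||grad L(theta)||_2.

    Uses the finite-difference approximation of the Hessian-vector product,
    giving (1 - alpha) * g(theta) + alpha * g(theta + r * g / ||g||),
    with alpha = lambda / r. Setting alpha = 1 recovers SAM's update direction.
    """
    grads = jax.grad(loss_fn)(params)
    # Global L2 norm of the gradient pytree.
    g_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(grads)))
    # Perturb the parameters along the normalized gradient direction.
    perturbed = jax.tree_util.tree_map(
        lambda p, g: p + r * g / (g_norm + 1e-12), params, grads)
    grads_perturbed = jax.grad(loss_fn)(perturbed)
    # Convex combination of the plain and perturbed gradients.
    return jax.tree_util.tree_map(
        lambda g, gp: (1.0 - alpha) * g + alpha * gp, grads, grads_perturbed)
```

The returned gradient pytree can then be passed to any standard optimizer update (e.g., an Optax optimizer) in place of the plain gradient.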
However, GR can lead to serious performance degradation in certain scenarios of adaptive optimization.
| Model | Adam | Adam + GR | Adam + GR + Zero-GR-Warmup |
|---|---|---|---|
| ViT-Ti | 14.82 | 13.92 | 13.61 |
| ViT-S | 12.07 | 12.40 | 10.68 |
| ViT-B | 10.83 | 12.36 | 9.42 |
Through both empirical observations and theoretical analysis, we find that the biased gradient estimation introduced by GR can induce instability and divergence in the gradient statistics of adaptive optimizers at the initial stage of training, especially when combined with learning-rate warmup, a technique originally intended to benefit those gradient statistics.
To mitigate this issue, we draw inspiration from warmup techniques and propose three GR warmup strategies, such as the zero-GR warmup reported in the table above.
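As one hedged example, a zero-GR warmup can be implemented by disabling the GR term (setting the mixing coefficient to zero) for the first training steps, so the adaptive optimizer accumulates its gradient statistics without the biased GR estimate. The sketch below illustrates this under that assumption; the function name and arguments are illustrative, not the repository's API.

```python
import jax.numpy as jnp


def zero_gr_warmup_alpha(step, warmup_steps, alpha):
    """GR mixing coefficient under a zero-GR warmup schedule (illustrative).

    During the first `warmup_steps` the coefficient is 0, so the optimizer
    takes plain adaptive-gradient steps; afterwards the full coefficient
    `alpha` is applied.
    """
    return jnp.where(step < warmup_steps, 0.0, alpha)
```

This schedule can be combined with the `gr_grad` sketch above, e.g. `gr_grad(loss_fn, params, r=0.01, alpha=zero_gr_warmup_alpha(step, 1000, 0.5))`.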
If you find this repository helpful, please cite the papers as follows:
@inproceedings{zhao2022penalizing,
title={Penalizing gradient norm for efficiently improving generalization in deep learning},
author={Zhao, Yang and Zhang, Hao and Hu, Xiuyuan},
booktitle={International Conference on Machine Learning},
pages={26982--26992},
year={2022},
organization={PMLR}
}
@inproceedings{zhaowill,
title={When Will Gradient Regularization Be Harmful?},
author={Zhao, Yang and Zhang, Hao and Hu, Xiuyuan},
booktitle={Forty-first International Conference on Machine Learning},
year={2024}
}