This is the official code for "Improving SAM Requires Rethinking its Optimization Formulation", accepted at ICML 2024.
The lower-bound loss function for the maximizer is defined in the paper in terms of a function $\phi$, chosen as follows:
- For BiSAM (tanh), we set $\phi(x)=\tanh(\alpha x)$.
- For BiSAM (-log), we set $\phi(x) = -\log(1 + e^{(\gamma-x)}) + 1$, where $\gamma=\log(e-1)$.
Both functions are implemented in loss_fnc.py.
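As a quick illustration of the two choices of $\phi$ above, here is a minimal scalar sketch in plain Python. It is not the code in loss_fnc.py: the function names `phi_tanh` and `phi_log` are hypothetical, and the default `alpha=1.0` is a placeholder rather than the value used in the paper.

```python
import math

# gamma = log(e - 1), chosen so that phi_log(0) == 0
GAMMA = math.log(math.e - 1.0)


def phi_tanh(x: float, alpha: float = 1.0) -> float:
    """phi(x) = tanh(alpha * x); alpha=1.0 is an illustrative default."""
    return math.tanh(alpha * x)


def _softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    if z > 0:
        return z + math.log1p(math.exp(-z))
    return math.log1p(math.exp(z))


def phi_log(x: float) -> float:
    """phi(x) = -log(1 + exp(gamma - x)) + 1."""
    return 1.0 - _softplus(GAMMA - x)


if __name__ == "__main__":
    for x in (-2.0, 0.0, 2.0):
        print(x, phi_tanh(x), phi_log(x))
```

Both functions pass through the origin and are increasing. $\tanh(\alpha x)$ is bounded in $(-1, 1)$, while the -log variant is bounded above by 1 and unbounded below.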
conda create -n bisam python=3.8
conda activate bisam
# On GPU
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
pip install -r requirements.txt
python train.py
- This code contains SAM, BiSAM (-log), and BiSAM (tanh) on CIFAR-10/CIFAR-100, as listed in Tables 1 and 2 of the paper. Use --opt to select the optimizer (for example, --opt bisam_log):

  | --opt | sam | bisam_log | bisam_tanh |
  | --- | --- | --- | --- |
  | optimizer | SAM | BiSAM (-log) | BiSAM (tanh) |

- Example scripts:
python train.py --optim bisam_log --rho 0.05 --epochs 200 --learning_rate 0.1 --model resnet56 --dataset cifar10
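
To compare the three optimizers under the same settings, the example command can be generated once per optimizer. The sketch below only builds and prints the command lines; it does not launch training. It assumes that `--optim` accepts `sam` and `bisam_tanh` in the same way as the `bisam_log` value shown above, and the helper `build_command` is hypothetical.

```python
import shlex

# Optimizer names as listed in this README.
OPTIMIZERS = ["sam", "bisam_log", "bisam_tanh"]


def build_command(optim: str, dataset: str = "cifar10") -> list:
    """Return the example training command as an argument list."""
    return [
        "python", "train.py",
        "--optim", optim,
        "--rho", "0.05",
        "--epochs", "200",
        "--learning_rate", "0.1",
        "--model", "resnet56",
        "--dataset", dataset,
    ]


# One shell-ready command string per optimizer.
commands = [shlex.join(build_command(name)) for name in OPTIMIZERS]

if __name__ == "__main__":
    for command in commands:
        print(command)
```

Each printed line can be pasted into a shell, or the argument list from `build_command` can be passed to `subprocess.run` to launch the runs in sequence.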