
TVLARS - A Fast Convergence Optimizer for Large Batch Training

Primary language: Python. License: MIT.

TVLARS - Official PyTorch Implementation

Abstract

LARS and LAMB have emerged as prominent techniques in Large Batch Learning (LBL) for ensuring the stability of AI training. One of the primary challenges in LBL is convergence stability, where the AI agent usually gets trapped in a sharp minimizer. To address this challenge, a relatively recent technique known as warm-up has been employed. However, warm-up lacks a strong theoretical foundation, leaving the door open for further exploration of more efficacious algorithms. In light of this situation, we conducted empirical experiments to analyze the behaviors of the two most popular optimizers in the LARS family, LARS and LAMB, with and without a warm-up strategy. Our analysis gives us a clearer understanding of LARS, LAMB, and the necessity of a warm-up technique in LBL. Building upon these insights, we propose a novel algorithm called Time Varying LARS (TVLARS), which facilitates robust training in the initial phase without the need for warm-up. Experimental evaluation demonstrates that TVLARS achieves results competitive with LARS and LAMB when warm-up is utilized, while surpassing their performance when warm-up is not used.

Experiment

Setup

The experiments can be run on any platform: Windows, Ubuntu, or Google Colab. On Windows or Ubuntu, use the following commands to clone the repository and create a virtual environment.

git clone https://github.com/KhoiDOO/tvlars.git
cd path/to/tvlars
python -m venv .env

The Python packages used in this project are listed below. Crucially, parquet and pyarrow are used for writing and reading .parquet files, a compressed columnar format used here to store DataFrames. All the packages can be installed with the command pip install -r requirements.txt. If parquet does not work on your machine, consider using fastparquet instead.

matplotlib==3.7.1
numpy==1.24.3
pandas==2.0.1
parquet==1.3.1
pyarrow==12.0.0
seaborn==0.12.2
tqdm==4.65.0

PyTorch (version 2.0.1) is the main package used for the optimization computations.

Available Settings

Use python main.py -h to print out all available settings of this project. The table below shows each tag, its accepted options, and its description.

TAG | OPTIONS | DESCRIPTION
-h, --help | | show this help message and exit
--bs | BS | batch size
--workers | WORKERS | number of worker processes used by the data loader
--epochs | EPOCHS | number of epochs used in training
--lr | LR | initial learning rate
--seed | SEED | seed for initializing training
--port | PORT | multi-GPU training port
--wd | W | weight decay
--ds | cifar10, cifar100, tinyimagenet | data set name
--model | resnet18, resnet34, resnet50, effb0 | model used in training
--opt | adam, adamw, adagrad, rmsprop, lars, tvlars, lamb | optimizer used in training
--sd | None, cosine, lars-warm | learning rate scheduler used in training
--dv | DV [DV ...] | list of devices used in training
--lmbda | LMBDA | delay factor used in TVLARS
--cl_epochs | CL_EPOCHS | number of epochs used in the Barlow Twins feature redundancy removal stage
--btlmbda | BTLMBDA | lambda factor used in Barlow Twins
--projector | PROJECTOR | dimensions of the top multilayer perceptron used in Barlow Twins
--lr_classifier | LR_CLASSIFIER | classifier learning rate used in Barlow Twins
--lr_backbone | LR_BACKBONE | backbone learning rate used in Barlow Twins
--mode | clf, bt | experiment mode: clf for classification, bt for the Barlow Twins experiment
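
The settings above describe a command-line interface of the kind argparse produces. The sketch below rebuilds a small subset of it to show how the restricted choices and the list-valued --dv option behave. It is not the repository's main.py: the function name build_parser, the argument types, and all default values are assumptions made for illustration.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical subset of the documented options; types and defaults are assumed."""
    parser = argparse.ArgumentParser(description="TVLARS experiment settings (sketch)")
    parser.add_argument("--bs", type=int, default=512, help="batch size")
    parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
    parser.add_argument("--lr", type=float, default=1.0, help="initial learning rate")
    parser.add_argument("--wd", type=float, default=5e-4, help="weight decay")
    parser.add_argument(
        "--ds",
        choices=["cifar10", "cifar100", "tinyimagenet"],
        default="cifar10",
        help="data set name",
    )
    parser.add_argument(
        "--opt",
        choices=["adam", "adamw", "adagrad", "rmsprop", "lars", "tvlars", "lamb"],
        default="tvlars",
        help="optimizer used in training",
    )
    # The scheduler value "None" is parsed as a plain string, not Python's None.
    parser.add_argument(
        "--sd",
        choices=["None", "cosine", "lars-warm"],
        default="None",
        help="learning rate scheduler",
    )
    # nargs="+" collects one or more device indices into a list.
    parser.add_argument("--dv", type=int, nargs="+", default=[0], help="devices")
    parser.add_argument("--lmbda", type=float, default=1e-3, help="TVLARS delay factor")
    parser.add_argument("--mode", choices=["clf", "bt"], default="clf", help="experiment mode")
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(
    "--bs 512 --lr 2.0 --opt tvlars --sd None --lmbda 1e-05 --dv 0 1 2 3".split()
)
print(args.opt, args.lr, args.lmbda, args.dv, args.mode)
```

Passing a value outside the listed choices, for example --opt sgd, makes argparse exit with a usage error, which is the behavior the OPTIONS column implies.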

Running

For instance, the TVLARS experiments with a batch size ($\mathcal{B}$) of 512 and various delay factor ($\lambda$) values can be run with the following commands.

Classification Experiment

python main.py --bs 512 --epochs 100 --lr 1.0 --port 7046 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-06 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 1.0 --port 3675 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-05 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 1.0 --port 6162 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.0001 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 1.0 --port 3930 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.001 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 1.0 --port 7644 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.005 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 1.0 --port 5794 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.01 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 2.0 --port 3976 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-06 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 2.0 --port 5895 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-05 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 2.0 --port 5014 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.0001 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 2.0 --port 6423 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.001 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 2.0 --port 5228 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.005 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 2.0 --port 6169 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.01 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 3.0 --port 5466 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-06 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 3.0 --port 7422 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-05 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 3.0 --port 6373 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.0001 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 3.0 --port 6592 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.001 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 3.0 --port 4802 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.005 --dv 0 1 2 3
python main.py --bs 512 --epochs 100 --lr 3.0 --port 7327 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.01 --dv 0 1 2 3

Barlow Twins Experiment

python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 1.0 --port 7186 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-06 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 1.0 --port 4111 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-05 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 1.0 --port 4356 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.0001 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 1.0 --port 7782 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.001 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 1.0 --port 4353 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.005 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 1.0 --port 6524 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.01 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 2.0 --port 3979 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-06 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 2.0 --port 4969 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-05 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 2.0 --port 3517 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.0001 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 2.0 --port 7895 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.001 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 2.0 --port 4434 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.005 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 2.0 --port 7770 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.01 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 3.0 --port 5348 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-06 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 3.0 --port 4362 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 1e-05 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 3.0 --port 6193 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.0001 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 3.0 --port 6442 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.001 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 3.0 --port 7169 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.005 --dv 0 1 2 3 --mode bt
python main.py --bs 512 --epochs 100 --cl_epochs 1000 --lr 3.0 --port 7954 --wd 0.0005 --ds cifar10 --model resnet18 --opt tvlars --sd None --lmbda 0.01 --dv 0 1 2 3 --mode bt

Citation

@misc{do2023revisiting,
      title={Revisiting LARS for Large Batch Training Generalization of Neural Networks}, 
      author={Khoi Do and Duong Nguyen and Hoa Nguyen and Long Tran-Thanh and Quoc-Viet Pham},
      year={2023},
      eprint={2309.14053},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}