Code for the paper: On the training dynamics of deep networks with L_2 regularization
This code trains a Wide ResNet on different datasets and includes the AutoL2
algorithm described in the paper.
Implemented by Aitor Lewkowycz, based on code by Sam Schoenholz.
Requirements can be installed from requirements.txt. The code is designed to run on TPUs; it can also run on GPUs by passing -noTPU and installing the GPU jaxlib package from https://github.com/google/jax.
Figure 1a:
for L2 in $L2LIST; do
  python3 jax_wideresnet_exps.py -L2=$L2 -epochs=200 -std_wrn_sch
  python3 jax_wideresnet_exps.py -L2=$L2 -physicalL2 -epochs=0.02 -std_wrn_sch  # evolved for a physical time 0.02, i.e. 0.02/eta/lambda = 0.1/lambda epochs
done
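As a quick sanity check on the comment above, the conversion from physical time to epochs can be worked out in a few lines. Here eta = 0.2 is an assumption inferred from 0.02/eta = 0.1, and the helper name is illustrative, not part of the repo:

```python
# Convert a "physical" training time t into SGD epochs using the relation
# quoted above: epochs = t / (eta * lambda). With t = 0.02 and eta = 0.2
# (assumed), this reproduces the 0.1 / lambda epochs in the comment.
def physical_time_to_epochs(t, eta, lam):
    return t / (eta * lam)

for lam in [1e-2, 1e-3, 1e-4]:
    epochs = physical_time_to_epochs(0.02, eta=0.2, lam=lam)
    print(f"lambda={lam:g}: {epochs:g} epochs")  # equals 0.1 / lambda
```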
Figure 1b is generated by comparing the performance of the trained models with our prediction:
for L2 in $L2LIST; do
  python3 jax_wideresnet_exps.py -L2=$L2 -epochs=2000
done
To obtain the t* prediction, we run the following:
python3 jax_wideresnet_exps.py -L2=0.01 -epochs=2
Figure 1c: Evolve with lr=0.2 for 200 epochs with L0=0.1, comparing L2_sch vs L2=0.0001:
python3 jax_wideresnet_exps.py -L2=0.1 -L2_sch
python3 jax_wideresnet_exps.py -L2=0.0001 -noL2_sch
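For reference, the AutoL2 idea behind -L2_sch (start with a large L2 coefficient and decay it as training progresses) can be sketched as below. The trigger and factor used here, halving lambda once the monitored loss has failed to improve for `patience` evaluations, are illustrative assumptions; the actual schedule lives inside jax_wideresnet_exps.py:

```python
# Illustrative AutoL2-style schedule: begin with a large L2 coefficient
# and decay it when the monitored metric plateaus. The halving factor and
# patience below are assumptions, not the paper's exact choices.
class AutoL2Schedule:
    def __init__(self, lam0=0.1, decay=0.5, patience=3):
        self.lam = lam0          # current L2 coefficient
        self.decay = decay       # multiplicative decay on a plateau
        self.patience = patience # evaluations without improvement tolerated
        self.best = float("inf")
        self.stall = 0

    def update(self, loss):
        """Call once per evaluation; returns the current L2 coefficient."""
        if loss < self.best:
            self.best = loss
            self.stall = 0
        else:
            self.stall += 1
            if self.stall >= self.patience:
                self.lam *= self.decay  # loss plateaued: shrink L2
                self.stall = 0
        return self.lam

sched = AutoL2Schedule(lam0=0.1)
for loss in [1.0, 0.8, 0.8, 0.8, 0.8, 0.7]:
    lam = sched.update(loss)  # lambda halves after the plateau at 0.8
```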
The Wide ResNet experiments in Figure 2 are similar:
for lr in $LRLIST; do
  for L2 in $L2LIST; do
    python3 jax_wideresnet_exps.py -lr=$lr -L2=$L2 -physicalL2 -epochs=0.1 -nomomentum -noaugment  # -lr flag name assumed
  done
done