liegroups

Code for The Lie-Group Bayesian Learning Rule, E. M. Kiral, T. Möllenhoff, M. E. Khan, AISTATS 2023.

installation and requirements

The code requires JAX and other standard dependencies such as matplotlib and numpy; see the 'requirements.txt' file.
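
The dependencies can typically be installed with pip from the repository root:

pip install -r requirements.txt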

To train on TinyImageNet, you will need to download the dataset and extract it into the 'datasetfolder' directory (see the 'data.py' file).

examples

tanh-MLP on MNIST

To run the additive and multiplicative learning rules proposed in the paper on a tanh-MLP and the MNIST dataset, you can use the following commands:

Running the additive rule:

python3 train.py --optim additive --model mlp --alpha1 0.05 --epochs 25 --noise gaussian --noiseconfig 0.001 --batchsize 50 --priorprec 0 --mc 32 --warmup 0

This should train to around 98% test accuracy.

Running the multiplicative rule (the code currently only supports Rayleigh noise):

python3 train.py --optim multiplicative --temperature 0.006 --alpha1 50 --beta1 0.9 --model mlp --noise rayleigh --batchsize 50 --mc 32 --epochs 25 --priorprec 0 --warmup 0

This should also train to around 98% test accuracy.
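
For intuition, the two noise models above differ in how a network sample is formed from the current weights. The following is only an illustrative sketch (variable names, shapes, and the exact scaling are made up for illustration; the actual learning rules live in the repo code):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k_w, k_g, k_u = jax.random.split(key, 3)
w = jax.random.normal(k_w, (784, 100))   # stand-in for a first-layer weight matrix

# additive rule: Gaussian noise is added to the weights (--noise gaussian)
sigma = 0.001                            # cf. --noiseconfig 0.001
w_additive = w + sigma * jax.random.normal(k_g, w.shape)

# multiplicative rule: positive Rayleigh noise multiplies the weights
# (--noise rayleigh); Rayleigh(1) sampled via the inverse CDF of a uniform draw
u = jax.random.uniform(k_u, w.shape)
w_multiplicative = w * jnp.sqrt(-2.0 * jnp.log1p(-u))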

first-layer filter visualizations

To reproduce the figures visualizing the filters, run the following (after training the tanh-MLP networks using the above commands):

python3 plot_filters.py --resultsfolder results/mnist_mlp/additive/run_0 

(figure: additive first-layer filters)

python3 plot_filters.py --resultsfolder results/mnist_mlp/multiplicative/run_0 

(figures: multiplicative excitatory and multiplicative inhibitory filters)

The filter images are saved by the script as PNG files in the specified 'resultsfolder'.
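
To inspect filters outside of 'plot_filters.py', a generic matplotlib recipe along these lines can be used (the weight file name here is hypothetical; 'plot_filters.py' is the reference implementation):

import numpy as np
import matplotlib.pyplot as plt

w = np.load("first_layer_weights.npy")   # hypothetical file; shape (784, hidden)
fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for ax, col in zip(axes.ravel(), w.T):   # each column of w is one 28x28 filter
    ax.imshow(col.reshape(28, 28), cmap="gray")
    ax.axis("off")
fig.savefig("filters.png", dpi=150)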

multiplicative updates for a CNN architecture

Training a LeNet-like CNN on CIFAR-10 with the multiplicative updates:

python3 train.py --optim multiplicative --temperature 0.001 --alpha1 100 --beta1 0.9 --model cnn --noise rayleigh --batchsize 100 --mc 10 --epochs 180 --priorprec 1 --dataset cifar10 --multinitoffset 0.001 --dafactor 4 --warmup 3

Testing:

python3 test.py --resultsfolder results/cifar10_cnn/multiplicative/run_0 --testbatchsize 2000

This should give the following results:

results at g:
  > testacc=86.12%, nll=0.7622, ece=0.1010
results at model average (32 samples):
  > testacc=87.24%, nll=0.4500, ece=0.0151
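
The "model average" line averages predicted class probabilities over networks sampled from the learned distribution; schematically it amounts to something like this ('predict_fn' and 'sampled_params' are placeholders, the actual evaluation is in 'test.py'):

import jax.numpy as jnp

def model_average_probs(predict_fn, sampled_params, x):
    # predict_fn(params, x) -> class probabilities; sampled_params is a list of
    # parameter draws from the learned distribution (e.g. 32 samples, as above)
    probs = jnp.stack([predict_fn(p, x) for p in sampled_params])
    return probs.mean(axis=0)            # averaged predictive distribution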

CIFAR and TinyImageNet

To run the affine and additive learning rules on the CIFAR and TinyImageNet datasets, you can use the following commands:

Running the affine update rule (w/ Gaussian noise):

python3 train.py --optim affine --temperature 1 --alpha1 1.0 --alpha2 0.05 --beta1 0.8 --beta2 0.999 --dataset cifar10 --model resnet20 --noise gaussian --batchsize 200 --mc 1 --noiseconfig 0.005 --batchsplit 1 --epochs 180 --priorprec 25

Running the additive update rule (w/ Gaussian noise):

python3 train.py --optim additive --alpha1 0.5 --beta1 0.8 --dataset cifar10 --model resnet20 --noise gaussian --batchsize 200 --mc 1 --noiseconfig 0.005 --batchsplit 1 --priorprec 25

To evaluate the ECE, NLL, and accuracy of the trained models, run the following command, specifying the folder where the results have been saved:

python3 test.py --resultsfolder results/cifar10_resnet20/affine/run_0

This produces output like the following, cf. also Table 2 in the paper:

results at g:
  > testacc=91.96%, nll=0.2887, ece=0.0363
results at model average (32 samples):
  > testacc=92.02%, nll=0.2661, ece=0.0247

We can also evaluate our additive learning rule:

python3 test.py --resultsfolder results/cifar10_resnet20/additive/run_0

This produces output like the following, cf. also Table 2 in the paper:

results at g:
  > testacc=92.07%, nll=0.3014, ece=0.0420
results at model average (32 samples):
  > testacc=92.21%, nll=0.2688, ece=0.0268
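
For reference, the reported ece is an expected calibration error computed from the predicted probabilities and labels; a standard equal-width binning version looks as follows (the exact binning in 'test.py' may differ):

import numpy as np

def expected_calibration_error(probs, labels, n_bins=15):
    # probs: (N, C) predicted probabilities, labels: (N,) true class indices
    conf = probs.max(axis=1)             # confidence of each prediction
    correct = probs.argmax(axis=1) == labels
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(bins[:-1], bins[1:]):
        in_bin = (conf > lo) & (conf <= hi)
        if in_bin.any():
            ece += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())
    return ece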

training with multiple MC samples

We can run the affine learning rule using multiple Monte Carlo samples as follows:

python3 train.py --optim affine --temperature 1 --alpha1 1.0 --alpha2 0.05 --beta1 0.8 --beta2 0.999 --dataset cifar10 --model resnet20 --noise gaussian --batchsize 200 --mc 3 --noiseconfig 0.01 --batchsplit 1 --epochs 180 --priorprec 25

This will be more computationally expensive but leads to improved results:

results at g:
  > testacc=92.13%, nll=0.2747, ece=0.0348
results at model average (32 samples):
  > testacc=92.42%, nll=0.2403, ece=0.0099
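
Conceptually, a larger --mc value averages the stochastic objective over several sampled networks per minibatch before a gradient step is taken. A toy sketch of such an MC-averaged objective (with a stand-in quadratic loss, not the repo's training loop):

import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    # toy quadratic stand-in for the network's training loss
    return jnp.mean((x @ w - y) ** 2)

def mc_objective(mu, key, x, y, sigma=0.005, num_mc=3):
    # average the loss over num_mc Gaussian weight samples around mu (cf. --mc 3)
    def one_sample(k):
        return loss_fn(mu + sigma * jax.random.normal(k, mu.shape), x, y)
    return jnp.mean(jnp.stack([one_sample(k) for k in jax.random.split(key, num_mc)]))

key = jax.random.PRNGKey(0)
mu = jnp.zeros((10, 1))
x, y = jax.random.normal(key, (32, 10)), jnp.zeros((32, 1))
grad_mu = jax.grad(mc_objective)(mu, key, x, y)   # gradient of the MC-averaged loss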

troubleshooting

Please contact Thomas if there are issues or questions about the code, or open an issue in this GitHub repository.