This repo contains a JAX implementation of the IVON optimizer from our paper
Variational Learning is Effective for Large Deep Networks
Y. Shen*, N. Daheim*, B. Cong, P. Nickl, G.M. Marconi, C. Bazan, R. Yokota, I. Gurevych, D. Cremers, M.E. Khan, T. Möllenhoff
International Conference on Machine Learning (ICML), 2024 (spotlight)
ArXiv: https://arxiv.org/abs/2402.17641
Blog: https://team-approx-bayes.github.io/blog/ivon/
Tutorial: https://ysngshn.github.io/research/why-ivon/
The JAX implementation of IVON is self-contained in a single file, `ivon.py`. It makes use of components from the optax library and can be `jax.jit`-ted for maximum efficiency.
We also provide an official PyTorch implementation of the IVON optimizer. The experiments in our paper were run with the PyTorch implementation, and their source code can be found here.
The IVON optimizer can be defined via `ivon.ivon()` and initialized with its `.init()` method:
```python
optimizer = ivon(
    lr,            # learning rate
    ess,           # effective sample size, typically the training set size
    hess_init,     # initial value of the Hessian estimate
    beta1,         # momentum for the gradient estimate
    beta2,         # momentum for the Hessian estimate
    weight_decay,  # weight decay coefficient
    # ...
)
optstate = optimizer.init(params)
```
Appendix A of our paper provides practical guidelines for choosing IVON hyperparameters.
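As an illustration, here is one plausible configuration for CIFAR-10-scale training. The values mirror the `resnet-cifar` example below, the keyword names assume the constructor signature shown above, and the `weight_decay` mapping from `--priorprec` is an assumption; treat this as a starting point rather than a recommendation.

```python
n_train = 50_000  # CIFAR-10 training set size

optimizer = ivon(
    lr=0.5,          # matches --alpha 0.5 in the example below
    ess=n_train,     # effective sample size, typically the dataset size
    hess_init=0.5,   # matches --hess-init 0.5
    beta1=0.9,
    beta2=0.99999,
    weight_decay=25 / n_train,  # assumption: --priorprec 25 = weight_decay * ess
)
```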
Similar to the usual Optax interface, the `ivon.IVON` optimizer instance returned by the construction function `ivon.ivon()` has an `.init()` method for initializing the optimizer state and a `.step()` method for computing the gradient updates.
Additionally, `ivon.IVON` provides two methods to support its stochastic variational inference updates: `.sampled_params()` draws a sample from the weight posterior and also returns the generated noise, which is then used by the `.accumulate()` method to estimate and accumulate the expected gradient and Hessian for the current step.
A typical training step can then be carried out as follows:
```python
train_mcsamples = 1  # 1 MC sample is usually enough; more can be even better
rngkey, *mc_keys = jax.random.split(rngkey, train_mcsamples + 1)

# Stochastic natural-gradient VI with IVON
for key in mc_keys:
    # draw an IVON weight posterior sample
    psample, noise = optimizer.sampled_params(key, params, optstate)
    # get the gradient for this MC sample from feed-forward + backprop
    mc_updates = ff_backprop(psample, inputs, target, ...)
    # accumulate the estimates for the current sample
    optstate = optimizer.accumulate(mc_updates, optstate, noise)

# compute the updates from the accumulated estimates
updates, optstate = optimizer.step(optstate, params)
params = optax.apply_updates(params, updates)
```
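Since the whole update is functional, the step can be compiled with `jax.jit` as mentioned above. Below is a minimal sketch of a single-MC-sample jitted step; `ff_backprop` and `model_apply` are hypothetical placeholders for your own model and loss code, and only the IVON methods shown above are assumed.

```python
import jax
import optax

def ff_backprop(psample, inputs, target):
    # Placeholder loss/backprop: `model_apply` stands in for your model's
    # forward pass; replace the loss with whatever fits your task.
    def loss_fn(p):
        logits = model_apply(p, inputs)
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, target).mean()
    return jax.grad(loss_fn)(psample)

@jax.jit
def train_step(rngkey, params, optstate, inputs, target):
    rngkey, key = jax.random.split(rngkey)
    # single-MC-sample IVON update, as in the loop above
    psample, noise = optimizer.sampled_params(key, params, optstate)
    grads = ff_backprop(psample, inputs, target)
    optstate = optimizer.accumulate(grads, optstate, noise)
    updates, optstate = optimizer.step(optstate, params)
    params = optax.apply_updates(params, updates)
    return rngkey, params, optstate
```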
Simply copy the `ivon.py` file to your project folder and import it. This file only depends on `jax` and `optax`; see their docs for installation instructions.
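For a typical CPU-only setup, both are available from PyPI (see the JAX docs for accelerator-specific builds):

```
pip install jax optax
```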
We also provide an example that has additional dependencies; see below.
The folder resnet-cifar contains an example showing how the IVON optimizer can tackle image classification tasks. This code base is adapted from the Bayesian-SAM repo with permission from the author. Check out its package dependencies here.
Go to the resnet-cifar folder and run the following commands:
SGD

```
python train.py --alpha 0.03 --beta1 0.9 --priorprec 25 --optim sgd --dataset cifar10
```

This should train to around 94.8% test accuracy.
IVON

```
python train.py --alpha 0.5 --beta1 0.9 --beta2 0.99999 --priorprec 25 --optim ivon --hess-init 0.5 --dataset cifar10
```

This should train to around 95.2% test accuracy.
In resnet-cifar, run the following commands, adapting `--resultsfolder` to the folder that stores the training results.
SGD

```
python test.py --resultsfolder /results/cifar10_resnet18/sgd/run_0/
```
IVON

```
python test.py --resultsfolder /results/cifar10_resnet18/ivon/run_0/ --testmc 64
```
The tests should show that IVON with 64-sample Bayesian prediction achieves significantly better uncertainty metrics.
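For reference, multi-sample Bayesian prediction amounts to averaging the predictive distribution over posterior weight samples. A minimal sketch using the interface above (not the actual `test.py` implementation; `model_apply` is again a placeholder for your network's forward pass):

```python
def bayes_predict(key, params, optstate, inputs, n_samples=64):
    # Average class probabilities over posterior weight samples.
    probs = 0.0
    for k in jax.random.split(key, n_samples):
        psample, _ = optimizer.sampled_params(k, params, optstate)
        probs += jax.nn.softmax(model_apply(psample, inputs), axis=-1)
    return probs / n_samples
```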